Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion pykokkos/core/visitors/pykokkos_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,17 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr:
)
elif name in ["PerTeam", "PerThread", "fence"]:
name = "Kokkos::" + name
elif name in {"complex32", "complex64"}:
name = "Kokkos::complex"
if "32" in name:
name += "<float>"
else:
name += "<double>"

function = cppast.DeclRefExpr(name)
args: List[cppast.Expr] = [self.visit(a) for a in node.args]

if visitors_util.is_math_function(name) or name in ["printf", "abs", "Kokkos::PerTeam", "Kokkos::PerThread", "Kokkos::fence"]:
if visitors_util.is_math_function(name) or name in ["printf", "abs", "Kokkos::PerTeam", "Kokkos::PerThread", "Kokkos::fence", "Kokkos::complex<float>", "Kokkos::complex<double>"]:
return cppast.CallExpr(function, args)

if function in self.kokkos_functions:
Expand Down
7 changes: 5 additions & 2 deletions pykokkos/core/visitors/visitors_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def pretty_print(node):
"double": "double",
"bool": "bool",
"TeamMember": f"Kokkos::TeamPolicy<{Keywords.DefaultExecSpace.value}>::member_type",
"cpp_auto": "auto"
"cpp_auto": "auto",
"complex32": "Kokkos::complex<float>",
"complex64": "Kokkos::complex<double>"
}

# Maps from the DataType enum to cppast
Expand Down Expand Up @@ -311,7 +313,8 @@ def parse_view_template_params(

if parameter in ("int", "double", "float",
"int8_t", "int16_t", "int32_t", "int64_t",
"uint8_t", "uint16_t", "uint32_t", "uint64_t"):
"uint8_t", "uint16_t", "uint32_t", "uint64_t",
"Kokkos::complex<float>", "Kokkos::complex<double>"):
datatype: str = parameter + "*" * rank
params["dtype"] = datatype

Expand Down
1 change: 1 addition & 0 deletions pykokkos/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
uint16, uint32, uint64,
float, double, real,
float32, float64, bool,
complex32, complex64
)
from .decorators import (
callback, classtype, Decorator, function, functor, main,
Expand Down
114 changes: 114 additions & 0 deletions pykokkos/interface/data_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from enum import Enum
from typing import Union
import builtins

from pykokkos.bindings import kokkos
import pykokkos.kokkos_manager as km

import numpy as np


Expand All @@ -26,6 +30,8 @@ class DataType(Enum):
float64 = kokkos.double
real = None
bool = np.bool_
complex32 = getattr(kokkos, 'complex_float32_dtype', None)
complex64 = getattr(kokkos, 'complex_float64_dtype', None)


class DataTypeClass:
Expand Down Expand Up @@ -91,3 +97,111 @@ class float64(DataTypeClass):
class bool(DataTypeClass):
value = kokkos.uint8
np_equiv = np.bool_


class complex(DataTypeClass):
def __add__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot add '{}' and '{}'".format(type(other), type(self)))

if isinstance(self, complex32):
return complex32(self.kokkos_complex + other.kokkos_complex)
elif isinstance(self, complex64):
return complex64(self.kokkos_complex + other.kokkos_complex)

def __iadd__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot add '{}' and '{}'".format(type(other), type(self)))

self.kokkos_complex += other.kokkos_complex
return self

def __sub__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot subtract '{}' and '{}'".format(type(other), type(self)))

if isinstance(self, complex32):
return complex32(self.kokkos_complex - other.kokkos_complex)
elif isinstance(self, complex64):
return complex64(self.kokkos_complex - other.kokkos_complex)

def __isub__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot subtract '{}' and '{}'".format(type(other), type(self)))

self.kokkos_complex -= other.kokkos_complex
return self

def __mul__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot multiply '{}' and '{}'".format(type(other), type(self)))

if isinstance(self, complex32):
return complex32(self.kokkos_complex * other.kokkos_complex)
elif isinstance(self, complex64):
return complex64(self.kokkos_complex * other.kokkos_complex)

def __imul__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot multiply '{}' and '{}'".format(type(other), type(self)))

self.kokkos_complex *= other.kokkos_complex
return self

def __truediv__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot divide '{}' and '{}'".format(type(other), type(self)))

if isinstance(self, complex32):
return complex32(self.kokkos_complex / other.kokkos_complex)
elif isinstance(self, complex64):
return complex64(self.kokkos_complex / other.kokkos_complex)

def __itruediv__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot divide '{}' and '{}'".format(type(other), type(self)))

self.kokkos_complex /= other.kokkos_complex
return self

def __repr__(self):
return f"({self.kokkos_complex.real_const()}, {self.kokkos_complex.imag_const()})"

@property
def real(self):
return self.kokkos_complex.real_const()

@property
def imag(self):
return self.kokkos_complex.imag_const()

class complex32(complex):
value = getattr(kokkos, 'complex_float32_dtype', None)
np_equiv = np.complex64 # 32 bits from real + 32 from imaginary

def __init__(self, real: Union[builtins.int, builtins.float], imaginary: Union[builtins.int, builtins.float] = 0.0):
complex_float32_type = getattr(kokkos, 'complex_float32', None)
if complex_float32_type is not None and isinstance(real, complex_float32_type):
self.kokkos_complex = real
else:
kokkos_module = km.get_kokkos_module(is_cpu=True)
if hasattr(kokkos_module, 'complex_float32'):
self.kokkos_complex = kokkos_module.complex_float32(real, imaginary)
else:
raise NotImplementedError("complex_float32 not available in kokkos bindings")


class complex64(complex):
value = getattr(kokkos, 'complex_float64_dtype', None)
np_equiv = np.complex128 # 64 bits from real + 64 from imaginary

def __init__(self, real: Union[builtins.int, builtins.float], imaginary: Union[builtins.int, builtins.float] = 0.0):
complex_float64_type = getattr(kokkos, 'complex_float64', None)
if complex_float64_type is not None and isinstance(real, complex_float64_type):
self.kokkos_complex = real
else:
kokkos_module = km.get_kokkos_module(is_cpu=True)
if hasattr(kokkos_module, 'complex_float64'):
self.kokkos_complex = kokkos_module.complex_float64(real, imaginary)
else:
raise NotImplementedError("complex_float64 not available in kokkos bindings")
Loading
Loading