def loopy_target():
    """Return the loopy target matching the active offloading device.

    Returns
    -------
    A ``lp.CWithGNULibcTarget`` for a CPU device, ``lp.CudaTarget`` for a
    CUDA device and ``lp.PyOpenCLTarget`` for an OpenCL device.

    Raises
    ------
    AssertionError
        If the active device is none of the known device types.
    """
    # Read the device through the module at call time.  A module-level
    # ``from pyop3.device import offloading_device`` binds whatever object was
    # current at import time, so it never observes the rebinding performed by
    # the ``pyop3.device.offloading`` context manager.
    from pyop3 import device as _device

    active_device = _device.offloading_device

    if isinstance(active_device, CPUDevice):
        return lp.CWithGNULibcTarget()
    elif isinstance(active_device, CUDADevice):
        return lp.CudaTarget()
    elif isinstance(active_device, OpenCLDevice):
        return lp.PyOpenCLTarget()
    else:
        raise AssertionError(f"Unrecognised offloading device: {active_device!r}")
def loopy_target():
    """Select the loopy code generation target for the current device.

    The device is looked up on ``pyop3.device`` every time this function is
    called, so that kernels compiled inside an ``offloading(...)`` block are
    generated for the device that block selected.

    Raises
    ------
    AssertionError
        If the active device is not a CPU, CUDA or OpenCL device.
    """
    # NOTE: do not use the ``offloading_device`` name imported at module
    # level here; that name is bound once at import time and does not follow
    # the ``global`` rebinding done inside ``pyop3.device.offloading``.
    from pyop3 import device as _device

    active_device = _device.offloading_device

    if isinstance(active_device, CPUDevice):
        return lp.CWithGNULibcTarget()
    elif isinstance(active_device, CUDADevice):
        return lp.CudaTarget()
    elif isinstance(active_device, OpenCLDevice):
        return lp.PyOpenCLTarget()
    else:
        raise AssertionError(f"Unrecognised offloading device: {active_device!r}")
@contextlib.contextmanager
def within_inames(self, inames) -> None:
    """Temporarily extend the set of inames that new instructions sit within.

    Inside the ``with`` block ``self._within_inames`` is the union of its
    previous value and ``inames``.  The previous value is restored when the
    block exits, including when the block raises.
    """
    orig_within_inames = self._within_inames
    self._within_inames |= inames
    try:
        yield
    finally:
        # without the ``finally`` an exception raised in the body would leave
        # the extra inames active for every instruction emitted afterwards
        self._within_inames = orig_within_inames
def compile(expr: LoopExpr, name="mykernel"):
    """Lower a loop expression to a loopy translation unit.

    Parameters
    ----------
    expr
        The expression to generate code for.
    name
        Name given to the generated kernel.  It is also used as the
        entrypoint of the returned translation unit.

    Returns
    -------
    The translation unit produced by ``lp.make_kernel``, merged with every
    subkernel registered during code generation.
    """
    ctx = LoopyCodegenContext()
    _compile(expr, pmap(), ctx)

    # add a no-op instruction touching all of the kernel arguments so they are
    # not silently dropped
    noop = lp.CInstruction(
        (),
        "",
        read_variables=frozenset({a.name for a in ctx.arguments}),
        within_inames=frozenset(),
        within_inames_is_final=True,
        depends_on=ctx._depends_on,
    )
    ctx._insns.append(noop)

    translation_unit = lp.make_kernel(
        ctx.domains,
        ctx.instructions,
        ctx.arguments,
        name=name,
        target=loopy_target(),
        lang_version=loopy_lang_version(),
        # options=lp.Options(check_dep_resolution=False),
    )
    tu = lp.merge((translation_unit, *ctx.subkernels))
    # The entrypoint must be the kernel that was just created.  It used to be
    # hard-coded to "mykernel", which only matched the default ``name``.
    return tu.with_entrypoints(name)
component.count, iname_replace_map | jname_replace_map, codegen_context + ) + codegen_context.add_domain(iname, extent_var) + + new_source_path = source_path | {axis.label: component.label} + new_target_path = target_path | axes.target_path_per_component.get( + (axis.id, component.label), {} + ) + new_iname_replace_map = iname_replace_map | {axis.label: pym.var(iname)} + + # these aren't jnames! + my_index_exprs = axes.index_exprs_per_component.get( + (axis.id, component.label), {} + ) + jname_extras = {} + for axis_label, index_expr in my_index_exprs.items(): + jname_expr = JnameSubstitutor( + new_iname_replace_map | jname_replace_map, codegen_context + )(index_expr) + jname_extras[axis_label] = jname_expr + + new_jname_replace_map = jname_replace_map | jname_extras + + with codegen_context.within_inames({iname}): + if subaxis := axes.child(axis, component): + parse_loop_properly_this_time( + loop, + axes, + loop_indices, + codegen_context, + axis=subaxis, + source_path=new_source_path, + target_path=new_target_path, + iname_replace_map=new_iname_replace_map, + jname_replace_map=new_jname_replace_map, + ) + else: + for stmt in loop.statements: + _compile( + stmt, + loop_indices + | { + loop.index: ( + new_source_path, + new_target_path, + new_jname_replace_map, + new_iname_replace_map, + ) + }, + codegen_context, + ) + + +@_compile.register +def _(call: FunctionCall, loop_indices, ctx: LoopyCodegenContext) -> None: + """ + Turn an exprs.FunctionCall into a series of assignment instructions etc. + Handles packing/accessor logic. 
+ """ + + temporaries = [] + subarrayrefs = {} + extents = {} + + # loopy args can contain ragged params too + loopy_args = call.function.code.default_entrypoint.args[: len(call.arguments)] + for loopy_arg, arg, spec in checked_zip(loopy_args, call.arguments, call.argspec): + # create an appropriate temporary + # we need the indices here because the temporary shape needs to be indexed + # by the same indices as the original array + # is this definitely true??? think so. because it gives us the right loops + # but we only really need it to determine "within" or not... + # if not isinstance(arg, MultiArray): + # # think PetscMat etc + # raise NotImplementedError( + # "Need to handle indices to create temp shape differently" + # ) + + loop_context = {} + for loop_index, (source_path, target_path, _, _) in loop_indices.items(): + loop_context[loop_index] = source_path, target_path + loop_context = pmap(loop_context) + + axes = arg.axes.with_context(loop_context).copy( + index_exprs=None, layout_exprs=None + ) + temporary = MultiArray( + axes, + name=ctx.unique_name("t"), + dtype=arg.dtype, + ) + indexed_temp = temporary + + if loopy_arg.shape is None: + shape = (temporary.alloc_size,) + else: + if np.prod(loopy_arg.shape, dtype=int) != temporary.alloc_size: + raise RuntimeError("Shape mismatch between inner and outer kernels") + shape = loopy_arg.shape + + temporaries.append((arg, indexed_temp, spec.access, shape)) + + # Register data + if not isinstance(arg, CalledAxisTree): + ctx.add_argument(arg.name, arg.dtype) + + ctx.add_temporary(temporary.name, temporary.dtype, shape) + + # subarrayref nonsense/magic + indices = [] + for s in shape: + iname = ctx.unique_name("i") + ctx.add_domain(iname, s) + indices.append(pym.var(iname)) + indices = tuple(indices) + + subarrayrefs[arg.name] = lp.symbolic.SubArrayRef( + indices, pym.subscript(pym.var(temporary.name), indices) + ) + + # we need to pass sizes through if they are only known at runtime (ragged) + # NOTE: If we 
register an extent to pass through loopy will complain + # unless we register it as an assumption of the local kernel (e.g. "n <= 3") + + # FIXME ragged is broken since I commented this out! determining shape of + # ragged things requires thought! + # for cidx in range(indexed_temp.index.root.degree): + # extents |= self.collect_extents( + # indexed_temp.index, + # indexed_temp.index.root, + # cidx, + # within_indices, + # within_inames, + # depends_on, + # ) + + # TODO this is pretty much the same as what I do in fix_intents in loopexpr.py + # probably best to combine them - could add a sensible check there too. + assignees = tuple( + subarrayrefs[arg.name] + for arg, spec in checked_zip(call.arguments, call.argspec) + if spec.access in {WRITE, RW, INC, MIN_RW, MIN_WRITE, MAX_RW, MAX_WRITE} + ) + expression = pym.primitives.Call( + pym.var(call.function.code.default_entrypoint.name), + tuple( + subarrayrefs[arg.name] + for arg, spec in checked_zip(call.arguments, call.argspec) + if spec.access in {READ, RW, INC, MIN_RW, MAX_RW} + ) + + tuple(extents.values()), + ) + + # gathers + for arg, temp, access, shape in temporaries: + if access in {READ, RW, MIN_RW, MAX_RW}: + gather = Read(arg, temp, shape) + else: + assert access in {WRITE, INC, MIN_WRITE, MAX_WRITE} + gather = Zero(arg, temp, shape) + build_assignment(gather, loop_indices, ctx) + + ctx.add_function_call(assignees, expression) + ctx.add_subkernel(call.function.code) + + # scatters + for arg, temp, access, shape in temporaries: + if access == READ: + continue + elif access in {WRITE, RW, MIN_RW, MIN_WRITE, MAX_RW, MAX_WRITE}: + scatter = Write(arg, temp, shape) + else: + assert access == INC + scatter = Increment(arg, temp, shape) + build_assignment(scatter, loop_indices, ctx) + + +# FIXME this is practically identical to what we do in build_loop +# parse_assignment? 
def build_assignment(
    assignment,
    loop_indices,
    codegen_ctx,
):
    """Emit the instructions for one assignment between an array and its temporary.

    Design notes carried over from the original implementation:

    * Each application of an index tree takes an input axis tree plus the
      jnames applying to each axis component, filters/transforms the tree and
      determines the instructions that generate those jnames.  The resulting
      axis tree again has unspecified jnames, which are turned into actual
      loops in a final step.  So the first step is to generate the initial
      jnames and the last is to emit the loops for the final tree.
    * Iterating over ``map0(map1(p)).index()`` may emit several loops but
      yields a single "jname expr", whereas iterating over ``axes.index()``
      may yield several jnames.  The distinction is whether the object is
      being indexed (jnames wanted) or used to index something else ("jname
      expr" wanted).
    * In both cases the pattern is "loop over this object as if it were a
      tree", which is what the traversal called below does.
    """
    # Pick the index tree for the current loop context: only the
    # (source path, target path) pair of each enclosing loop index is needed.
    paths_by_loop_index = pmap(
        {
            index: (src_path, tgt_path)
            for index, (src_path, tgt_path, _, _) in loop_indices.items()
        }
    )
    contextual_axes = assignment.array.axes.with_context(paths_by_loop_index)

    parse_assignment_properly_this_time(
        assignment,
        contextual_axes,
        loop_indices,
        codegen_ctx,
    )
# TODO I should disable emitting instructions for things like zero where we
# don't want insns for the array
def add_leaf_assignment(
    assignment,
    axes,
    source_path,
    target_path,
    iname_replace_map,
    jname_replace_map,
    codegen_context,
):
    """Emit the assignment instruction for a single leaf of the axis tree.

    The array side is addressed through ``axes.orig_layout_fn[target_path]``
    using both the iname and jname replacements; the temporary side is
    addressed through ``source_path`` using the inames only.
    """
    # not used below; kept so that importing behaviour is unchanged
    from pyop3.distarray.multiarray import IndexExpressionReplacer

    array_layouts = axes.orig_layout_fn[target_path]
    all_replacements = iname_replace_map | jname_replace_map

    # the array expression is built first, then the temporary expression
    # (which registers an offset temporary on the context)
    array_side = make_array_expr(
        assignment, array_layouts, target_path, all_replacements, codegen_context
    )
    temp_side = make_temp_expr(
        assignment, source_path, iname_replace_map, codegen_context
    )

    _shared_assignment_insn(assignment, array_side, temp_side, codegen_context)
def _shared_assignment_insn(assignment, array_expr, temp_expr, ctx):
    """Register the instruction implementing ``assignment`` on ``ctx``.

    * ``Read``: temporary <- array
    * ``Write``: array <- temporary
    * ``Increment``: array <- array + temporary
    * ``Zero``: temporary <- 0

    Raises ``NotImplementedError`` for any other assignment type.
    """
    # The checks are kept in their original order and the right-hand side is
    # only built inside the branch that needs it.
    if isinstance(assignment, Read):
        operands = (temp_expr, array_expr)
    elif isinstance(assignment, Write):
        operands = (array_expr, temp_expr)
    elif isinstance(assignment, Increment):
        operands = (array_expr, array_expr + temp_expr)
    elif isinstance(assignment, Zero):
        operands = (temp_expr, 0)
    else:
        raise NotImplementedError

    target, value = operands
    ctx.add_assignment(target, value)
+ # def map_subscript(self, subscript): + # index = self.rec(subscript.index) + # + # trimmed_path = {} + # trimmed_jnames = {} + # axes = subscript.aggregate.axes + # axis = axes.root + # while axis: + # trimmed_path[axis.label] = self._path[axis.label] + # trimmed_jnames[axis.label] = self._labels_to_jnames[axis.label] + # cpt = just_one(axis.components) + # axis = axes.child(axis, cpt) + # trimmed_path = pmap(trimmed_path) + # trimmed_jnames = pmap(trimmed_jnames) + # + # insns, varname = _scalar_assignment( + # subscript.aggregate, + # trimmed_path, + # trimmed_jnames, + # self._codegen_context, + # ) + # for insn in insns: + # self._codegen_context.add_assignment(*insn) + # return varname + + # this is cleaner if I do it as a single line expression + # rather than register assignments for things. + def map_multi_array(self, array): + # must be single-component here + path = array.axes.path(*array.axes.leaf) + + trimmed_jnames = {} + axes = array.axes + axis = axes.root + while axis: + trimmed_jnames[axis.label] = self._labels_to_jnames[axis.label] + cpt = just_one(axis.components) + axis = axes.child(axis, cpt) + trimmed_jnames = pmap(trimmed_jnames) + + varname = _scalar_assignment( + array, + path, + trimmed_jnames, + self._codegen_context, + ) + return varname + + def map_called_map(self, expr): + if not isinstance(expr.function.map_component.array, MultiArray): + raise NotImplementedError("Affine map stuff not supported yet") + + inner_expr = [self.rec(param) for param in expr.parameters] + map_array = expr.function.map_component.array + + # handle [map0(p)][map1(p)] where map0 does not have an associated loop + try: + jname = self._labels_to_jnames[expr.function.full_map.name] + except KeyError: + jname = self._codegen_context.unique_name("j") + self._codegen_context.add_temporary(jname) + jname = pym.var(jname) + + # ? 
def register_extent(extent, jnames, ctx):
    """Return something usable as the upper bound of a loop domain.

    Integral extents are returned untouched.  A ``MultiArray`` extent is read
    into a freshly named temporary (prefix ``"p"``) registered on ``ctx`` and
    the name of that temporary is returned.  Anything else raises
    ``NotImplementedError``.
    """
    if not isinstance(extent, numbers.Integral):
        # actually a pymbolic expression
        if not isinstance(extent, MultiArray):
            raise NotImplementedError("need to tidy up assignment logic")

        leaf_path = extent.axes.path(*extent.axes.leaf)
        extent_expr = _scalar_assignment(extent, leaf_path, jnames, ctx)

        temp_name = ctx.unique_name("p")
        ctx.add_temporary(temp_name)
        ctx.add_assignment(pym.var(temp_name), extent_expr)
        return temp_name

    return extent
def find_axis(axes, path, target, current_axis=None):
    """Return the axis matching ``target`` along ``path``.

    ``path`` is a mapping between axis labels and the selected component indices.

    The search starts at ``current_axis`` (the root of ``axes`` if not given)
    and repeatedly descends into the child selected by ``path``.

    Raises
    ------
    AssertionError
        If the bottom of the tree is reached without finding ``target``.
    KeyError
        If ``path`` has no entry for an axis visited during the search.
    """
    axis = current_axis or axes.root

    while axis.label != target:
        axis = axes.child(axis, path[axis.label])
        if not axis:
            # Raised explicitly rather than via ``assert False``: asserts are
            # stripped under ``python -O``, after which the search would
            # restart from the root and never terminate.
            raise AssertionError(
                f"axis {target!r} not found along the given path"
            )

    return axis
def _make_tv_array_arg(tv):
    """Build a kernel array argument mirroring the temporary variable ``tv``.

    Name, dtype, shape, dim tags, offset, dim names, order, alignment and
    address space are copied from ``tv``.  The argument is an input if ``tv``
    is read-only and an output otherwise.

    ``tv`` must not live in the private address space.
    """
    # This module has no module-level ``import loopy as lp``, so ``lp`` would
    # otherwise be an undefined name when this function runs.
    import loopy as lp

    assert (
        tv.address_space != lp.AddressSpace.PRIVATE
    ), "cannot turn a private temporary into a kernel argument"
    arg = lp.ArrayArg(
        name=tv.name,
        dtype=tv.dtype,
        shape=tv.shape,
        dim_tags=tv.dim_tags,
        offset=tv.offset,
        dim_names=tv.dim_names,
        order=tv.order,
        alignment=tv.alignment,
        address_space=tv.address_space,
        is_output=not tv.read_only,
        is_input=tv.read_only,
    )
    return arg
class OffloadingDevice(abc.ABC):
    """Base class for every device that computation can be offloaded to."""


class CPUDevice(OffloadingDevice):
    """A CPU device."""


class GPUDevice(OffloadingDevice, abc.ABC):
    """Base class for GPU devices."""

    def __init__(self, num_threads=32):
        self.num_threads = num_threads


class CUDADevice(GPUDevice):
    """A GPU device driven through CUDA."""


class OpenCLDevice(GPUDevice):
    """A GPU device driven through OpenCL."""


host_device = CPUDevice()
# The currently active device.  This name is *rebound* by ``offloading``, so
# other modules must read it as ``pyop3.device.offloading_device`` at the
# point of use; ``from pyop3.device import offloading_device`` captures the
# value at import time and never sees later changes.
offloading_device = host_device


@contextlib.contextmanager
def offloading(device: OffloadingDevice):
    """Make ``device`` the active offloading device inside a ``with`` block.

    The previously active device is restored when the block exits, including
    when the block raises.

    Raises
    ------
    NotImplementedError
        If the currently active device is not a CPU device.
    """
    global offloading_device

    orig_offloading_device = offloading_device
    if not isinstance(orig_offloading_device, CPUDevice):
        raise NotImplementedError("Not sure what to do when offloading from not a CPU")

    offloading_device = device
    try:
        yield
    finally:
        # without the ``finally`` an exception in the body would leave the
        # GPU device active and every later ``offloading`` call would fail
        offloading_device = orig_offloading_device
pyop3.indices.tree import collect_loop_indices +from pyop3.mirrored_array import MirroredArray from pyop3.utils import ( PrettyTuple, UniqueNameGenerator, @@ -109,24 +110,13 @@ def __init__( # # must be context-free # raise TypeError() - if isinstance(data, np.ndarray): - if dtype: - data = np.asarray(data, dtype=dtype) - else: - dtype = data.dtype - elif isinstance(data, Sequence): - data = np.asarray(data, dtype=dtype) - dtype = data.dtype - elif data is None: - if not dtype: - raise ValueError("Must either specify a dtype or provide an array") - dtype = np.dtype(dtype) - data = np.zeros(axes.size, dtype=dtype) + if data is not None: + data = MirroredArray(data, dtype) else: - raise TypeError("data argument not recognised") + data = MirroredArray((axes.size,), dtype) self._data = data - self.dtype = dtype + self.dtype = data.dtype self.temporary_axes = as_axis_tree(axes).freeze() # used for the temporary self.axes = layout_axes(axes) @@ -250,17 +240,17 @@ def data(self): @property def data_rw(self): - return self._data + return self._data.data_rw @property def data_ro(self): # TODO - return self._data + return self._data.data_ro @property def data_wo(self): # TODO - return self._data + return self._data.data_wo @functools.cached_property def datamap(self) -> dict[str:DistributedArray]: diff --git a/pyop3/lang.py b/pyop3/lang.py index 3887bff5..02236987 100644 --- a/pyop3/lang.py +++ b/pyop3/lang.py @@ -419,7 +419,7 @@ def _as_pointer(array: DistributedArray) -> int: @_as_pointer.register def _(array: MultiArray): - return array.data.ctypes.data + return array._data.ptr_rw @_as_pointer.register diff --git a/pyop3/mirrored_array.py b/pyop3/mirrored_array.py new file mode 100644 index 00000000..ed541262 --- /dev/null +++ b/pyop3/mirrored_array.py @@ -0,0 +1,279 @@ +import collections +import functools +import weakref + +import numpy as np + +from pyop3.device import ( + CPUDevice, + CUDADevice, + OpenCLDevice, + host_device, + offloading_device, +) +from 
pyop3.dtypes import ScalarType + +try: + import pycuda +except ImportError: + pycuda = None + +try: + import pyopencl +except ImportError: + pyopencl = None + + +class NoValidDevicesError(RuntimeError): + pass + + +class InvalidDeviceError(RuntimeError): + pass + + +class MirroredArray: + """An array that is transparently available on different devices.""" + + def __init__(self, data_or_shape, dtype=None): + # option 1: passed shape + if isinstance(data_or_shape, tuple): + shape = data_or_shape + if not dtype: + dtype = ScalarType + data_per_device = weakref.WeakKeyDictionary() + device_validity = collections.defaultdict(bool) + # option 2: passed a numpy array + elif isinstance(data_or_shape, np.ndarray): + data = data_or_shape + shape = data.shape + if not dtype: + dtype = data.dtype + data_per_device = weakref.WeakKeyDictionary( + {host_device: np.asarray(data, dtype)} + ) + device_validity = collections.defaultdict(bool, {host_device: True}) + # option 3: passed a CUDA array + elif pycuda and isinstance(data_or_shape, pycuda.gpuarray.GPUArray): + # TODO I suppose the device could be passed as a kwarg + if not isinstance(offloading_device, CUDADevice): + raise InvalidDeviceError( + "Cannot pass a CUDA array to the MirroredArray constructor " + "outside a CUDA offloading context" + ) + data = data_or_shape + shape = data.shape + if not dtype: + dtype = data.dtype + data_per_device = weakref.WeakKeyDictionary( + {offloading_device: data.astype(dtype)} + ) + device_validity = collections.defaultdict(bool, {offloading_device: True}) + # option 4: passed an OpenCL array + elif pyopencl and isinstance(data_or_shape, pyopencl.array.Array): + if not isinstance(offloading_device, OpenCLDevice): + raise InvalidDeviceError( + "Cannot pass an OpenCL array to the MirroredArray constructor " + "outside an OpenCL offloading context" + ) + data = data_or_shape + shape = data.shape + if not dtype: + dtype = data.dtype + data_per_device = weakref.WeakKeyDictionary( + 
{offloading_device: data.astype(dtype)} + ) + device_validity = collections.defaultdict(bool, {offloading_device: True}) + else: + raise ValueError("Unexpected arguments encountered") + + self.shape = shape + self.dtype = dtype + self._data_per_device = data_per_device + self._device_validity = device_validity + + # counter used to keep track of modifications + self.state = 0 + + @property + def data(self): + return self.data_rw + + @property + def data_rw(self): + self.state += 1 + # FIXME this needs to come before ensuring validity for now + array = self._device_array(offloading_device) + self._ensure_valid_on_device(offloading_device) + self._invalidate_other_devices(offloading_device) + return self._as_rw_array(array) + + @property + def data_ro(self): + # FIXME this needs to come before ensuring validity for now + array = self._device_array(offloading_device) + self._ensure_valid_on_device(offloading_device) + return self._as_ro_array(array) + + @property + def data_wo(self): + self.state += 1 + # FIXME this needs to come before ensuring validity for now + array = self._device_array(offloading_device) + self._invalidate_other_devices(offloading_device) + self._device_validity[offloading_device] = True + return self._as_wo_array(array) + + @property + def ptr_rw(self): + return self._as_ptr(self.data_rw) + + @property + def ptr_ro(self): + return self._as_ptr(self.data_ro) + + @property + def ptr_wo(self): + return self._as_ptr(self.data_wo) + + @property + def size(self): + return np.prod(self.shape, dtype=int) + + def _ensure_valid_on_device(self, device): + if not self._device_validity[device]: + if self._device_validity[host_device]: + self._host_to_device_copy(device) + else: + self._device_to_host_copy(self._first_valid_device) + self._device_validity[host_device] = True + self._host_to_device_copy(device) + self._device_validity[device] = True + + def _invalidate_other_devices(self, device): + for dev in self._device_validity.keys(): + if dev is not 
device: + self._device_validity[dev] = False + + @functools.singledispatchmethod + def _host_to_device_copy(self, device): + raise TypeError(f"No handler registered for {type(device).__name__}") + + @_host_to_device_copy.register + def _(self, device: CPUDevice): + if device is host_device: + return + else: + raise NotImplementedError("Cannot offload to other CPUs") + + @_host_to_device_copy.register + def _(self, device: CUDADevice): + self._data_per_device[device].set(self._data_per_device[host_device]) + + @_host_to_device_copy.register + def _(self, device: OpenCLDevice): + self._data_per_device[device].set(self._data_per_device[host_device]) + + @functools.singledispatchmethod + def _device_to_host_copy(self, device): + raise TypeError(f"No handler registered for {type(device).__name__}") + + @_device_to_host_copy.register + def _(self, device: CPUDevice): + if device is host_device: + return + else: + raise NotImplementedError("Cannot offload to other CPUs") + + @_device_to_host_copy.register + def _(self, device: CUDADevice): + self._data_per_device[device].get(self._data_per_device[host_device]) + + @_device_to_host_copy.register + def _(self, device: OpenCLDevice): + self._data_per_device[device].get(self._data_per_device[host_device]) + + @property + def _first_valid_device(self): + for device, valid in self._device_validity.items(): + if valid: + return device + raise NoValidDevicesError("No valid devices found") + + def _device_array(self, device): + try: + return self._data_per_device[device] + except KeyError: + if isinstance(device, CPUDevice): + data = self._alloc_cpu() + elif isinstance(device, CUDADevice): + data = self._alloc_cuda() + elif isinstance(device, OpenCLDevice): + data = self._alloc_opencl(device.queue) + else: + raise AssertionError + + # this is valid if nothing is already there + if not self._device_validity: + self._device_validity[device] = True + + return self._data_per_device.setdefault(offloading_device, data) + + def 
_alloc_cpu(self): + return np.zeros(self.shape, self.dtype) + + def _alloc_cuda(self): + return pycuda.gpuarray.zeros(shape=self.shape, dtype=self.dtype) + + def _alloc_opencl(self, queue): + return pyopencl.array.zeros(queue, shape=self.shape, dtype=self.dtype) + + @staticmethod + def _as_rw_array(array): + # pycuda and pyopencl are optional dependencies so can't be singledispatch-ed + if isinstance(array, np.ndarray): + rw_array = array.view() + rw_array.setflags(write=True) + elif ( + pycuda + and isinstance(array, pycuda.gpuarray.GPUArray) + or pyopencl + and isinstance(array, pyopencl.array.Array) + ): + rw_array = array + else: + raise TypeError(f"No handler provided for {type(array).__name__}") + return rw_array + + @staticmethod + def _as_ro_array(array): + # pycuda and pyopencl are optional dependencies so can't be singledispatch-ed + if isinstance(array, np.ndarray): + ro_array = array.view() + ro_array.setflags(write=False) + elif ( + pycuda + and isinstance(array, pycuda.gpuarray.GPUArray) + or pyopencl + and isinstance(array, pyopencl.array.Array) + ): + # don't have specific readonly arrays + ro_array = array + else: + raise TypeError(f"No handler provided for {type(array).__name__}") + return ro_array + + @staticmethod + def _as_wo_array(array): + return self._as_rw_array(array) + + @staticmethod + def _as_ptr(array): + if isinstance(array, np.ndarray): + return array.ctypes.data + elif pycuda and isinstance(array, pycuda.gpuarray.GPUArray): + return array.gpudata + elif pyopencl and isinstance(array, pyopencl.array.Array): + return array.data + else: + raise TypeError(f"No handler provided for {type(array).__name__}") diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index c1b1ac7a..e6f3d41e 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -2,7 +2,7 @@ import pytest from pyop3 import INC, READ, WRITE, Function, IntType, ScalarType -from pyop3.codegen.ir import LOOPY_LANG_VERSION, LOOPY_TARGET 
+from pyop3.codegen.ir import loopy_lang_version, loopy_target @pytest.fixture @@ -15,8 +15,8 @@ def scalar_copy_kernel(): lp.GlobalArg("y", ScalarType, (1,), is_input=False, is_output=True), ], name="scalar_copy", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + lang_version=loopy_lang_version(), ) return Function(code, [READ, WRITE]) @@ -31,8 +31,8 @@ def scalar_copy_kernel_int(): lp.GlobalArg("y", IntType, (1,), is_input=False, is_output=True), ], name="scalar_copy_int", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + lang_version=loopy_lang_version(), ) return Function(code, [READ, WRITE]) @@ -47,7 +47,7 @@ def scalar_inc_kernel(): lp.GlobalArg("y", ScalarType, (1,), is_input=True, is_output=True), ], name="scalar_inc", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + lang_version=loopy_lang_version(), ) return Function(lpy_kernel, [READ, INC]) diff --git a/tests/integration/test_access_descriptors.py b/tests/integration/test_access_descriptors.py index 76dd28eb..8a0db44d 100644 --- a/tests/integration/test_access_descriptors.py +++ b/tests/integration/test_access_descriptors.py @@ -16,7 +16,7 @@ ScalarType, do_loop, ) -from pyop3.codegen.ir import LOOPY_LANG_VERSION, LOOPY_TARGET +from pyop3.codegen.ir import loopy_lang_version, loopy_target # NOTE: It is only meaningful to test min/max in parallel as otherwise they behave the @@ -30,8 +30,8 @@ def min_rw_kernel(): lp.GlobalArg("y", ScalarType, (1,), is_input=True, is_output=False), ], name="min_rw", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + lang_version=loopy_lang_version(), ) return Function(code, [MIN_RW, READ]) @@ -46,8 +46,8 @@ def min_write_kernel(): lp.GlobalArg("z", ScalarType, (1,), is_input=True, is_output=False), ], name="min_write", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + 
lang_version=loopy_lang_version(), ) return Function(code, [MIN_WRITE, READ, READ]) @@ -61,8 +61,8 @@ def max_rw_kernel(): lp.GlobalArg("y", ScalarType, (1,), is_input=True, is_output=False), ], name="max_rw", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + lang_version=loopy_lang_version(), ) return Function(code, [MAX_RW, READ]) @@ -77,8 +77,8 @@ def max_write_kernel(): lp.GlobalArg("z", ScalarType, (1,), is_input=True, is_output=False), ], name="max_write", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + lang_version=loopy_lang_version(), ) return Function(code, [MAX_WRITE, READ, READ]) @@ -104,8 +104,8 @@ def test_pointwise_accesses_descriptors_fail_with_vector_shape(access): "", kernel_data, name="dummy", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + lang_version=loopy_lang_version(), ) with pytest.raises(ValueError): diff --git a/tests/integration/test_axis_ordering.py b/tests/integration/test_axis_ordering.py index a0964f88..15ddd419 100644 --- a/tests/integration/test_axis_ordering.py +++ b/tests/integration/test_axis_ordering.py @@ -23,7 +23,7 @@ do_loop, loop, ) -from pyop3.codegen.ir import LOOPY_LANG_VERSION, LOOPY_TARGET +from pyop3.codegen.ir import loopy_lang_version, loopy_target from pyop3.utils import flatten, just_one @@ -38,9 +38,9 @@ def test_different_axis_orderings_do_not_change_packing_order(): lp.GlobalArg("x", np.float64, (m1, m2), is_input=True, is_output=False), lp.GlobalArg("y", np.float64, (m1, m2), is_input=False, is_output=True), ], - target=LOOPY_TARGET, name="copy", - lang_version=(2018, 2), + target=loopy_target(), + lang_version=loopy_lang_version(), ) copy_kernel = Function(lpy_kernel, [READ, WRITE]) diff --git a/tests/integration/test_basics.py b/tests/integration/test_basics.py index 2c39d0dc..f873d42d 100644 --- a/tests/integration/test_basics.py +++ b/tests/integration/test_basics.py @@ -24,7 +24,7 @@ do_loop, 
loop, ) -from pyop3.codegen.ir import LOOPY_LANG_VERSION, LOOPY_TARGET +from pyop3.codegen.ir import loopy_lang_version, loopy_target from pyop3.utils import flatten @@ -37,9 +37,9 @@ def scalar_copy_kernel(): lp.GlobalArg("x", ScalarType, (1,), is_input=True, is_output=False), lp.GlobalArg("y", ScalarType, (1,), is_input=False, is_output=True), ], - target=LOOPY_TARGET, name="scalar_copy", - lang_version=(2018, 2), + target=loopy_target(), + lang_version=loopy_lang_version(), ) return Function(code, [READ, WRITE]) @@ -53,9 +53,9 @@ def vector_copy_kernel(): lp.GlobalArg("x", ScalarType, (3,), is_input=True, is_output=False), lp.GlobalArg("y", ScalarType, (3,), is_input=False, is_output=True), ], - target=LOOPY_TARGET, name="vector_copy", - lang_version=(2018, 2), + target=loopy_target(), + lang_version=loopy_lang_version(), ) return Function(code, [READ, WRITE]) @@ -203,9 +203,9 @@ def test_inc_with_shared_global_value(): "{ [i]: 0 <= i < 3 }", "x[i] = x[i] + 1", [lp.GlobalArg("x", ScalarType, (3,), is_input=True, is_output=True)], - target=LOOPY_TARGET, name="plus_one", - lang_version=(2018, 2), + target=loopy_target(), + lang_version=loopy_lang_version(), ) plus_one = Function(knl, [INC]) diff --git a/tests/integration/test_maps.py b/tests/integration/test_maps.py index daab6491..2926da14 100644 --- a/tests/integration/test_maps.py +++ b/tests/integration/test_maps.py @@ -23,7 +23,7 @@ do_loop, loop, ) -from pyop3.codegen.ir import LOOPY_LANG_VERSION, LOOPY_TARGET +from pyop3.codegen.ir import loopy_lang_version, loopy_target from pyop3.utils import flatten @@ -37,8 +37,8 @@ def vector_inc_kernel(): lp.GlobalArg("y", ScalarType, (1,), is_input=True, is_output=True), ], name="vector_inc", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + lang_version=loopy_lang_version(), ) return Function(lpy_kernel, [READ, INC]) @@ -53,8 +53,8 @@ def vec2_inc_kernel(): lp.GlobalArg("y", ScalarType, (2,), is_input=True, is_output=True), ], 
name="vec2_inc", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + lang_version=loopy_lang_version(), ) return Function(lpy_kernel, [READ, INC]) @@ -69,8 +69,8 @@ def vec6_inc_kernel(): lp.GlobalArg("y", ScalarType, (1,), is_input=True, is_output=True), ], name="vector_inc", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + lang_version=loopy_lang_version(), ) return Function(code, [READ, INC]) @@ -85,8 +85,8 @@ def vec12_inc_kernel(): lp.GlobalArg("y", ScalarType, (2,), is_input=True, is_output=True), ], name="vector_inc", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + lang_version=loopy_lang_version(), ) return Function(code, [READ, INC]) @@ -414,7 +414,7 @@ def test_map_composition(vec2_inc_kernel): mapaxes0 = iterset.add_node(Axis(arity0), *iterset.leaf) mapdata0 = np.asarray([[2, 4, 0], [6, 7, 1]], dtype=int) - maparray0 = MultiArray(mapaxes0, name="map0", data=flatten(mapdata0)) + maparray0 = MultiArray(mapaxes0, name="map0", data=mapdata0.flatten()) map0 = Map( { pmap({iterset.root.label: "cpt0"}): [ @@ -543,8 +543,8 @@ def test_recursive_multi_component_maps(): lp.GlobalArg("y", ScalarType, (1,), is_input=False, is_output=True), ], name="sum_kernel", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + lang_version=loopy_lang_version(), ) sum_kernel = Function(lpy_kernel, [READ, WRITE]) @@ -620,8 +620,8 @@ def test_sum_with_consecutive_maps(): lp.GlobalArg("y", ScalarType, (1,), is_input=False, is_output=True), ], name="sum", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + lang_version=loopy_lang_version(), ) sum_kernel = Function(lpy_kernel, [READ, WRITE]) diff --git a/tests/integration/test_permuted.py b/tests/integration/test_permuted.py index 18bcef0c..6a4d9026 100644 --- a/tests/integration/test_permuted.py +++ b/tests/integration/test_permuted.py @@ -20,7 +20,7 @@ do_loop, 
loop, ) -from pyop3.codegen.ir import LOOPY_LANG_VERSION, LOOPY_TARGET +from pyop3.codegen.ir import loopy_lang_version, loopy_target from pyop3.utils import flatten @@ -33,9 +33,9 @@ def scalar_copy_kernel(): lp.GlobalArg("x", ScalarType, (1,), is_input=True, is_output=False), lp.GlobalArg("y", ScalarType, (1,), is_input=False, is_output=True), ], - target=LOOPY_TARGET, + target=loopy_target(), name="scalar_copy", - lang_version=(2018, 2), + lang_version=loopy_lang_version(), ) return Function(code, [READ, WRITE]) @@ -49,9 +49,9 @@ def vector_copy_kernel(): lp.GlobalArg("x", ScalarType, (3,), is_input=True, is_output=False), lp.GlobalArg("y", ScalarType, (3,), is_input=False, is_output=True), ], - target=LOOPY_TARGET, + target=loopy_target(), name="vector_copy", - lang_version=(2018, 2), + lang_version=loopy_lang_version(), ) return Function(code, [READ, WRITE]) diff --git a/tests/integration/test_petscmat.py b/tests/integration/test_petscmat.py index e8c96c4c..9dded3ca 100644 --- a/tests/integration/test_petscmat.py +++ b/tests/integration/test_petscmat.py @@ -4,7 +4,7 @@ from pyrsistent import pmap import pyop3 as op3 -from pyop3.codegen.ir import LOOPY_LANG_VERSION, LOOPY_TARGET +from pyop3.codegen.ir import loopy_lang_version, loopy_target from pyop3.utils import flatten @@ -147,8 +147,8 @@ def test_read_matrix_values(): lp.GlobalArg("array", array.dtype, (1,), is_input=False, is_output=True), ], name="inc", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + lang_version=loopy_lang_version(), ) inc = op3.Function(lpy_kernel, [op3.READ, op3.INC]) op3.do_loop( diff --git a/tests/integration/test_ragged.py b/tests/integration/test_ragged.py index 2b41fa0e..0de51dd8 100644 --- a/tests/integration/test_ragged.py +++ b/tests/integration/test_ragged.py @@ -20,7 +20,7 @@ do_loop, loop, ) -from pyop3.codegen.ir import LOOPY_LANG_VERSION, LOOPY_TARGET +from pyop3.codegen.ir import loopy_lang_version, loopy_target from pyop3.utils import 
flatten @@ -33,9 +33,9 @@ def scalar_copy_kernel(): lp.GlobalArg("x", ScalarType, (1,), is_input=True, is_output=False), lp.GlobalArg("y", ScalarType, (1,), is_input=False, is_output=True), ], - target=LOOPY_TARGET, name="scalar_copy", - lang_version=(2018, 2), + target=loopy_target(), + lang_version=loopy_lang_version(), ) return Function(code, [READ, WRITE]) diff --git a/tests/integration/test_slice_composition.py b/tests/integration/test_slice_composition.py index ae65e1f8..e4dd900e 100644 --- a/tests/integration/test_slice_composition.py +++ b/tests/integration/test_slice_composition.py @@ -23,7 +23,7 @@ do_loop, loop, ) -from pyop3.codegen.ir import LOOPY_LANG_VERSION, LOOPY_TARGET +from pyop3.codegen.ir import loopy_lang_version, loopy_target from pyop3.utils import flatten @@ -37,8 +37,8 @@ def vec2_copy_kernel(): lp.GlobalArg("y", ScalarType, (2,), is_input=False, is_output=True), ], name="copy", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=loopy_target(), + lang_version=loopy_lang_version(), ) return Function(lpy_kernel, [READ, WRITE]) diff --git a/tests/integration/test_subsets.py b/tests/integration/test_subsets.py index ee973eb3..a92976e3 100644 --- a/tests/integration/test_subsets.py +++ b/tests/integration/test_subsets.py @@ -23,7 +23,7 @@ do_loop, loop, ) -from pyop3.codegen.ir import LOOPY_LANG_VERSION, LOOPY_TARGET +from pyop3.codegen.ir import loopy_lang_version, loopy_target from pyop3.utils import flatten