Support for Nested or Dynamically Defined @pk.workunits #299

@kennykos

Description

I am trying to define a generic outer-level @pk.workunit that calls a user-defined @pk.function, similar to how you can pass functors in Kokkos C++. The goal is to enable modular code where users define their own computation kernels as @pk.functions and plug them into a larger parallel loop.

This pattern is standard in Kokkos C++ via templates. For example:

struct Printer {
  KOKKOS_INLINE_FUNCTION
  void operator()(int i) const {
    printf("Index: %d\n", i);
  }
};

template <typename Functor>
struct Hello {
  Functor fun;
  KOKKOS_INLINE_FUNCTION
  void operator()(int i) const {
    fun(i);
  }
};

When trying to mimic this in PyKokkos:

import pykokkos as pk

@pk.function
def printer(i: int):
    print(f"Index: {i}")

def make_hello(functor):
    @pk.workunit
    def hello(i: int):
        functor(i)
    return hello

def main():
    pk.initialize()
    hello = make_hello(printer)
    pk.parallel_for("hello_loop", 10, hello)
    pk.finalize()

I encounter a runtime error:

RuntimeError: Entity 'hello' not found by parser

This suggests that the PyKokkos parser cannot locate workunits defined inside another function, such as hello above, which is created in the local scope of make_hello and closes over functor.
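One possible explanation for the error can be sketched with the standard ast module. This is only an illustration under the assumption that a parser collects function definitions from the module's top level; I have not verified how the PyKokkos parser actually walks the tree. The helper names top_level_functions and all_functions are hypothetical.

```python
import ast

# The module from the report, plus one top-level workunit for comparison.
# ast.parse only reads the text, so pykokkos does not need to be installed.
SOURCE = '''
import pykokkos as pk

@pk.workunit
def top_level(i: int):
    pass

def make_hello(functor):
    @pk.workunit
    def hello(i: int):
        functor(i)
    return hello
'''

def top_level_functions(tree: ast.Module) -> set[str]:
    """Names of functions defined directly in the module body."""
    return {node.name for node in tree.body if isinstance(node, ast.FunctionDef)}

def all_functions(tree: ast.Module) -> set[str]:
    """Names of every function in the module, including nested ones."""
    return {node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)}

tree = ast.parse(SOURCE)
found_top = top_level_functions(tree)
found_all = all_functions(tree)
print(sorted(found_top))  # ['make_hello', 'top_level']
print(sorted(found_all))  # ['hello', 'make_hello', 'top_level']
```

A scan of the module body alone never sees hello, while a full walk does. If the parser behaves like the first helper, that would be consistent with the "not found by parser" message.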

Request

Enable PyKokkos to support @pk.workunits that are defined within functions or dynamically returned. This would enable patterns like:

from typing import Callable

import pykokkos as pk

def make_workunit(fn: Callable[[int], None]):
    @pk.workunit
    def wrapper(i: int):
        fn(i)
    return wrapper

Questions

  • Is it feasible for the PyKokkos parser to support parsing such internally-defined workunits?
  • Would it be possible to register workunits dynamically with a name/AST hook?
  • Could a workaround be supported in the short term — e.g., explicitly registering the inner workunit?
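To make the second and third questions concrete, here is a rough sketch of what an explicit registration hook might record. Everything in it is hypothetical: register_workunit and WORKUNIT_REGISTRY are illustrative names, not PyKokkos APIs, and the sketch uses plain Python functions rather than pk.workunit. It shows that a nested function can be told apart from a top-level one by its __qualname__, and that the captured functor is reachable through the closure, which a translator would need in order to resolve the call inside the body.

```python
from typing import Any, Callable, Dict

# Hypothetical registry keyed by qualified name; not part of PyKokkos.
WORKUNIT_REGISTRY: Dict[str, Dict[str, Any]] = {}

def register_workunit(fn: Callable) -> Callable:
    """Record fn under its qualified name together with its captured variables."""
    closure = fn.__closure__ or ()
    # co_freevars and __closure__ are parallel: one cell per captured name.
    captured = {
        name: cell.cell_contents
        for name, cell in zip(fn.__code__.co_freevars, closure)
    }
    WORKUNIT_REGISTRY[fn.__qualname__] = {"function": fn, "captured": captured}
    return fn

def printer(i: int) -> None:
    print(f"Index: {i}")

def make_hello(functor: Callable[[int], None]) -> Callable[[int], None]:
    @register_workunit
    def hello(i: int) -> None:
        functor(i)
    return hello

hello = make_hello(printer)
entry = WORKUNIT_REGISTRY[hello.__qualname__]
print(hello.__name__)                           # hello
print(hello.__qualname__)                       # make_hello.<locals>.hello
print(entry["captured"]["functor"] is printer)  # True
```

Looking up by __name__ alone gives "hello", the same key the parser reports as missing; the qualified name and the closure contents are the extra information a dynamic registration would have to carry.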

Reference code and errors

Kokkos:

#include <Kokkos_Core.hpp>
#include <cstdio>

struct Printer {
  KOKKOS_INLINE_FUNCTION
  void operator()(int i) const {
    printf("Index: %d\n", i);
  }
};

// Define a work unit that takes a callable
template <typename Functor>
struct Hello {
  Functor fun;

  Hello(Functor f) : fun(f) {}

  KOKKOS_INLINE_FUNCTION
  void operator()(int i) const {
    fun(i);
  }
};

int main(int argc, char* argv[]) {
  Kokkos::initialize(argc, argv);
  {
    Printer printer;
    Hello<Printer> hello(printer);

    Kokkos::parallel_for("hello_loop", 10, hello);
  }
  Kokkos::finalize();
}

PyKokkos

import pykokkos as pk

# Define the printer function
@pk.function
def printer(i: int):
    print(f"Index: {i}")

# Define the hello wrapper that accepts a functor
def make_hello(functor):
    @pk.workunit
    def hello(i: int):
        functor(i)
    return hello

def main():
    pk.initialize()

    # Wrap printer inside hello
    hello = make_hello(printer)

    pk.parallel_for("hello_loop", 10, hello)

    pk.finalize()

if __name__ == "__main__":
    main()

stderr

Traceback (most recent call last):
  File "/work/09661/gkk345/vista/kokkos/example/tutorial/01_hello_world/hello_world.py", line 26, in <module>
    main()
  File "/work/09661/gkk345/vista/kokkos/example/tutorial/01_hello_world/hello_world.py", line 21, in main
    pk.parallel_for("hello_loop", 10, hello)
  File "/work/09661/gkk345/vista/pykokkos/pykokkos/interface/parallel_dispatch.py", line 158, in parallel_for
    runtime_singleton.runtime.run_workunit(
  File "/work/09661/gkk345/vista/pykokkos/pykokkos/core/runtime.py", line 153, in run_workunit
    return self.execute_workunit(name, policy, workunit, operation, parser, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/vista/pykokkos/pykokkos/core/runtime.py", line 181, in execute_workunit
    updated_types, updated_decorator, types_signature = get_type_info(operation, parser, policy, workunit, kwargs)
                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/vista/pykokkos/pykokkos/core/type_inference/args_type_inference.py", line 419, in get_type_info
    this_tree = this_parser.get_entity(this_metadata.name).AST
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/vista/pykokkos/pykokkos/core/parsers/parser.py", line 107, in get_entity
    raise RuntimeError(f"Entity '{name}' not found by parser")
RuntimeError: Entity 'hello' not found by parser
