Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 111 additions & 39 deletions pykokkos/core/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self):

# caches the result of CppSetup.is_compiled(path)
self.is_compiled_cache: Dict[str, bool] = {}
self.is_current_cache: Dict[str, bool] = {}
self.parser_cache: Dict[str, Parser] = {}

self.functor_file: str = "functor.hpp"
Expand Down Expand Up @@ -172,26 +173,6 @@ def compile_object(
}:
raise Exception(f"Types are required for style: {entity.style}")

if self.is_compiled(module_setup.output_dir):
if hash not in self.members: # True if pre-compiled
if len(metadata) > 1:
entity, classtypes = self.fuse_objects(
metadata, fuse_ASTs=True, **kwargs
)

if types_inferred:
entity.AST = parser.fix_types(entity, updated_types)
if decorator_inferred:
entity.AST = parser.fix_decorator(entity, updated_decorator)
self.members[hash] = self.extract_members(entity, classtypes)

return self.members[hash]

if len(metadata) > 1:
entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=True, **kwargs)

self.is_compiled_cache[module_setup.output_dir] = True

members: PyKokkosMembers

if types_inferred:
Expand All @@ -205,6 +186,32 @@ def compile_object(
members = self.extract_members(entity, classtypes)
self.members[hash] = members

if len(metadata) > 1:
entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=True, **kwargs)

if self.is_compiled(module_setup.output_dir):
if hash not in self.members: # True if pre-compiled
if types_inferred:
entity.AST = parser.fix_types(entity, updated_types)
if decorator_inferred:
entity.AST = parser.fix_decorator(entity, updated_decorator)
self.members[hash] = self.extract_members(entity, classtypes)

if self.is_current(
module_setup, members, entity, classtypes, restrict_views
):
return self.members[hash]
else:
# remove out-of-date directories
for child in module_setup.output_dir.iterdir():
child.unlink()
module_setup.output_dir.rmdir()

# reset is_current_cache
self.is_current_cache[module_setup.output_dir] = True

self.is_compiled_cache[module_setup.output_dir] = True

self.compile_entity(
module_setup.main,
module_setup,
Expand Down Expand Up @@ -232,7 +239,7 @@ def compile_entity(
Compile the entity

:param main: the path to the main file in the current PyKokkos application
:param metadata: the metadata of the entity being compiled
:param module_setup: the module_setup object containing module info
:param entity: the parsed entity being compiled
:param classtypes: the list of parsed classtypes being compiled
:param space: the execution space to compile for
Expand All @@ -251,26 +258,10 @@ def compile_entity(
return

cpp_setup = CppSetup(module_setup.module_file, module_setup.gpu_module_files)
translator = StaticTranslator(
module_setup.name, self.functor_file, self.functor_cast_file, members
)
t_start: float = time.perf_counter()
functor: List[str]
bindings: List[str]
cast: List[str]

if entity.style in {PyKokkosStyles.workunit, PyKokkosStyles.fused}:
if "PK_LOOP_FUSE" in os.environ:
loop_fuse(entity.AST)
if "PK_MEM_FUSE" in os.environ:
memory_ops_fuse(entity.AST, entity.pk_import)
functor, bindings, cast = translator.translate(
entity, classtypes, restrict_views
functor, bindings, cast = self.translate_entity(
module_setup, members, entity, classtypes, restrict_views
)

t_end: float = time.perf_counter() - t_start
self.logger.info(f"translation {t_end}")

output_dir: Path = module_setup.get_output_dir(
main,
module_setup.metadata,
Expand All @@ -294,6 +285,44 @@ def compile_entity(
c_end: float = time.perf_counter() - c_start
self.logger.info(f"compilation {c_end}")

def translate_entity(
self,
module_setup: ModuleSetup,
members: PyKokkosMembers,
entity: PyKokkosEntity,
classtypes: List[PyKokkosEntity],
restrict_views: Set[str],
):
"""
Translate the entity

:param module_setup: the module_setup object containing module info
:param members: the PyKokkos related members of the entity
:param entity: the parsed entity being compiled
:param classtypes: the list of parsed classtypes being compiled
:param restrict_views: a set of view names that do not alias any other views
"""
translator = StaticTranslator(
module_setup.name, self.functor_file, self.functor_cast_file, members
)
t_start: float = time.perf_counter()
functor: List[str]
bindings: List[str]
cast: List[str]

if entity.style in {PyKokkosStyles.workunit, PyKokkosStyles.fused}:
if "PK_LOOP_FUSE" in os.environ:
loop_fuse(entity.AST)
if "PK_MEM_FUSE" in os.environ:
memory_ops_fuse(entity.AST, entity.pk_import)
functor, bindings, cast = translator.translate(
entity, classtypes, restrict_views
)

t_end: float = time.perf_counter() - t_start
self.logger.info(f"translation {t_end}")
return functor, bindings, cast

def compile_raw_source(
self,
output_dir: Path,
Expand Down Expand Up @@ -443,6 +472,49 @@ def is_compiled(self, output_dir: str) -> bool:

return is_compiled

def is_current(
self,
module_setup: ModuleSetup,
members: PyKokkosMembers,
entity: PyKokkosEntity,
classtypes: List[PyKokkosEntity],
restrict_views: Set[str],
) -> bool:
"""
Check if the entity has changed since the last
compilation. The result will be cached,
as accessing the filesystem is costly.

:param module_setup: the module_setup object containing module info
:param members: the PyKokkos related members of the entity
:param entity: the parsed entity being compiled
:param classtypes: the list of parsed classtypes being compiled
:param restrict_views: a set of view names that do not alias any other views
:returns: True if current functor matches compiled functor.
"""
is_current: bool = False

output_dir = module_setup.output_dir
if output_dir in self.is_current_cache:
return self.is_current_cache[output_dir]

# get current functor
current_functor_file = output_dir.parent / self.functor_file
current_functor_string = current_functor_file.read_text()

# get translated functor
functor, _, _ = self.translate_entity(
module_setup, members, entity, classtypes, restrict_views
)
functor_string = "\n".join(functor)

# check that they are equal
is_current = current_functor_string == functor_string

self.is_current_cache[output_dir] = is_current

return is_current

def get_parser(self, path: str) -> Parser:
"""
Get the parser for a particular file
Expand Down
Loading