diff --git a/pykokkos/core/compiler.py b/pykokkos/core/compiler.py index 7f9c86f5..ead7f1fa 100644 --- a/pykokkos/core/compiler.py +++ b/pykokkos/core/compiler.py @@ -42,6 +42,7 @@ def __init__(self): # caches the result of CppSetup.is_compiled(path) self.is_compiled_cache: Dict[str, bool] = {} + self.is_current_cache: Dict[str, bool] = {} self.parser_cache: Dict[str, Parser] = {} self.functor_file: str = "functor.hpp" @@ -172,26 +173,6 @@ def compile_object( }: raise Exception(f"Types are required for style: {entity.style}") - if self.is_compiled(module_setup.output_dir): - if hash not in self.members: # True if pre-compiled - if len(metadata) > 1: - entity, classtypes = self.fuse_objects( - metadata, fuse_ASTs=True, **kwargs - ) - - if types_inferred: - entity.AST = parser.fix_types(entity, updated_types) - if decorator_inferred: - entity.AST = parser.fix_decorator(entity, updated_decorator) - self.members[hash] = self.extract_members(entity, classtypes) - - return self.members[hash] - - if len(metadata) > 1: - entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=True, **kwargs) - - self.is_compiled_cache[module_setup.output_dir] = True - members: PyKokkosMembers if types_inferred: @@ -205,6 +186,32 @@ def compile_object( members = self.extract_members(entity, classtypes) self.members[hash] = members + if len(metadata) > 1: + entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=True, **kwargs) + + if self.is_compiled(module_setup.output_dir): + if hash not in self.members: # True if pre-compiled + if types_inferred: + entity.AST = parser.fix_types(entity, updated_types) + if decorator_inferred: + entity.AST = parser.fix_decorator(entity, updated_decorator) + self.members[hash] = self.extract_members(entity, classtypes) + + if self.is_current( + module_setup, members, entity, classtypes, restrict_views + ): + return self.members[hash] + else: + # remove out-of-date directories + for child in module_setup.output_dir.iterdir(): + child.unlink() + module_setup.output_dir.rmdir() + + # reset is_current_cache + self.is_current_cache[module_setup.output_dir] = True + + self.is_compiled_cache[module_setup.output_dir] = True + self.compile_entity( module_setup.main, module_setup, @@ -232,7 +239,7 @@ def compile_entity( Compile the entity :param main: the path to the main file in the current PyKokkos application - :param metadata: the metadata of the entity being compiled + :param module_setup: the module_setup object containing module info :param entity: the parsed entity being compiled :param classtypes: the list of parsed classtypes being compiled :param space: the execution space to compile for @@ -251,26 +258,10 @@ def compile_entity( return cpp_setup = CppSetup(module_setup.module_file, module_setup.gpu_module_files) - translator = StaticTranslator( - module_setup.name, self.functor_file, self.functor_cast_file, members - ) - t_start: float = time.perf_counter() - functor: List[str] - bindings: List[str] - cast: List[str] - - if entity.style in {PyKokkosStyles.workunit, PyKokkosStyles.fused}: - if "PK_LOOP_FUSE" in os.environ: - loop_fuse(entity.AST) - if "PK_MEM_FUSE" in os.environ: - memory_ops_fuse(entity.AST, entity.pk_import) - functor, bindings, cast = translator.translate( - entity, classtypes, restrict_views + functor, bindings, cast = self.translate_entity( + module_setup, members, entity, classtypes, restrict_views ) - t_end: float = time.perf_counter() - t_start - self.logger.info(f"translation {t_end}") - output_dir: Path = module_setup.get_output_dir( main, module_setup.metadata, @@ -294,6 +285,44 @@ def compile_entity( c_end: float = time.perf_counter() - c_start self.logger.info(f"compilation {c_end}") + def translate_entity( + self, + module_setup: ModuleSetup, + members: PyKokkosMembers, + entity: PyKokkosEntity, + classtypes: List[PyKokkosEntity], + restrict_views: Set[str], + ): + """ + Translate the entity + + :param module_setup: the module_setup object containing module info + :param members: the PyKokkos related members of the entity + :param entity: the parsed entity being compiled + :param classtypes: the list of parsed classtypes being compiled + :param restrict_views: a set of view names that do not alias any other views + """ + translator = StaticTranslator( + module_setup.name, self.functor_file, self.functor_cast_file, members + ) + t_start: float = time.perf_counter() + functor: List[str] + bindings: List[str] + cast: List[str] + + if entity.style in {PyKokkosStyles.workunit, PyKokkosStyles.fused}: + if "PK_LOOP_FUSE" in os.environ: + loop_fuse(entity.AST) + if "PK_MEM_FUSE" in os.environ: + memory_ops_fuse(entity.AST, entity.pk_import) + functor, bindings, cast = translator.translate( + entity, classtypes, restrict_views + ) + + t_end: float = time.perf_counter() - t_start + self.logger.info(f"translation {t_end}") + return functor, bindings, cast + def compile_raw_source( self, output_dir: Path, @@ -443,6 +472,49 @@ def is_compiled(self, output_dir: str) -> bool: return is_compiled + def is_current( + self, + module_setup: ModuleSetup, + members: PyKokkosMembers, + entity: PyKokkosEntity, + classtypes: List[PyKokkosEntity], + restrict_views: Set[str], + ) -> bool: + """ + Check if the entity has changed since the last + compilation. The result will be cached, + as accessing the filesystem is costly. + + :param module_setup: the module_setup object containing module info + :param members: the PyKokkos related members of the entity + :param entity: the parsed entity being compiled + :param classtypes: the list of parsed classtypes being compiled + :param restrict_views: a set of view names that do not alias any other views + :returns: True if current functor matches compiled functor. + """ + is_current: bool = False + + output_dir = module_setup.output_dir + if output_dir in self.is_current_cache: + return self.is_current_cache[output_dir] + + # get current functor + current_functor_file = output_dir.parent / self.functor_file + current_functor_string = current_functor_file.read_text() + + # get translated functor + functor, _, _ = self.translate_entity( + module_setup, members, entity, classtypes, restrict_views + ) + functor_string = "\n".join(functor) + + # check that they are equal + is_current = current_functor_string == functor_string + + self.is_current_cache[output_dir] = is_current + + return is_current + def get_parser(self, path: str) -> Parser: """ Get the parser for a particular file