1- import difflib
1+ import importlib
22import inspect
33import json
44import os
55import pkgutil
66import re
7+ import subprocess
78import warnings
89from abc import abstractmethod
910from typing import Any , Dict , List , Optional , Tuple , Union , final
2223 separate_inside_and_outside_square_brackets ,
2324)
2425from .settings_utils import get_constants , get_settings
25- from .text_utils import camel_to_snake_case , is_camel_case
26+ from .text_utils import camel_to_snake_case , is_camel_case , snake_to_camel_case
2627from .type_utils import isoftype , issubtype
2728from .utils import (
2829 artifacts_json_cache ,
3637constants = get_constants ()
3738
3839
40+ def import_module_from_file (file_path ):
41+ # Get the module name (file name without extension)
42+ module_name = os .path .splitext (os .path .basename (file_path ))[0 ]
43+ # Create a module specification
44+ spec = importlib .util .spec_from_file_location (module_name , file_path )
45+ # Create a new module based on the specification
46+ module = importlib .util .module_from_spec (spec )
47+ spec .loader .exec_module (module )
48+ return module
49+
50+
51+ # type is read from a catelog entry, the value of a key "__type__"
52+ def get_class_from_artifact_type (type : str ):
53+ if type in Artifact ._class_register :
54+ return Artifact ._class_register [type ]
55+
56+ module_path , class_name = find_unitxt_module_and_class_by_classname (
57+ snake_to_camel_case (type )
58+ )
59+ if module_path == "class_register" :
60+ if class_name not in Artifact ._class_register :
61+ raise ValueError (
62+ f"Can not instantiate a class from type { type } , because { class_name } is currently not registered in Artifact._class_register."
63+ )
64+ return Artifact ._class_register [class_name ]
65+
66+ module = importlib .import_module (module_path )
67+
68+ if "." not in class_name :
69+ if hasattr (module , class_name ) and inspect .isclass (getattr (module , class_name )):
70+ return getattr (module , class_name )
71+ if class_name in Artifact ._class_register :
72+ return Artifact ._class_register [class_name ]
73+ module_file = module .__file__ if hasattr (module , "__file__" ) else None
74+ if module_file :
75+ module = import_module_from_file (module_file )
76+
77+ assert class_name in Artifact ._class_register
78+ return Artifact ._class_register [class_name ]
79+
80+ class_name_components = class_name .split ("." )
81+ klass = getattr (module , class_name_components [0 ])
82+ for i in range (1 , len (class_name_components )):
83+ klass = getattr (klass , class_name_components [i ])
84+ return klass
85+
86+
3987def is_name_legal_for_catalog (name ):
4088 return re .match (r"^[\w" + constants .catalog_hierarchy_sep + "]+$" , name )
4189
@@ -133,21 +181,10 @@ def maybe_recover_artifacts_structure(obj):
133181 return obj
134182
135183
136- def get_closest_artifact_type (type ):
137- artifact_type_options = list (Artifact ._class_register .keys ())
138- matches = difflib .get_close_matches (type , artifact_type_options )
139- if matches :
140- return matches [0 ] # Return the closest match
141- return None
142-
143-
144184class UnrecognizedArtifactTypeError (ValueError ):
145185 def __init__ (self , type ) -> None :
146186 maybe_class = "" .join (word .capitalize () for word in type .split ("_" ))
147187 message = f"'{ type } ' is not a recognized artifact 'type'. Make sure a the class defined this type (Probably called '{ maybe_class } ' or similar) is defined and/or imported anywhere in the code executed."
148- closest_artifact_type = get_closest_artifact_type (type )
149- if closest_artifact_type is not None :
150- message += f"\n \n Did you mean '{ closest_artifact_type } '?"
151188 super ().__init__ (message )
152189
153190
@@ -200,8 +237,6 @@ def verify_artifact_dict(cls, d):
200237 )
201238 if "__type__" not in d :
202239 raise MissingArtifactTypeError (d )
203- if not cls .is_registered_type (d ["__type__" ]):
204- raise UnrecognizedArtifactTypeError (d ["__type__" ])
205240
206241 @classmethod
207242 def get_artifact_type (cls ):
@@ -218,13 +253,6 @@ def register_class(cls, artifact_class):
218253
219254 snake_case_key = camel_to_snake_case (artifact_class .__name__ )
220255
221- if cls .is_registered_type (snake_case_key ):
222- assert (
223- str (cls ._class_register [snake_case_key ]) == str (artifact_class )
224- ), f"Artifact class name must be unique, '{ snake_case_key } ' already exists for { cls ._class_register [snake_case_key ]} . Cannot be overridden by { artifact_class } ."
225-
226- return snake_case_key
227-
228256 cls ._class_register [snake_case_key ] = artifact_class
229257
230258 return snake_case_key
@@ -241,19 +269,6 @@ def is_artifact_file(cls, path):
241269 d = json .load (f )
242270 return cls .is_artifact_dict (d )
243271
244- @classmethod
245- def is_registered_type (cls , type : str ):
246- return type in cls ._class_register
247-
248- @classmethod
249- def is_registered_class_name (cls , class_name : str ):
250- snake_case_key = camel_to_snake_case (class_name )
251- return cls .is_registered_type (snake_case_key )
252-
253- @classmethod
254- def is_registered_class (cls , clz : object ):
255- return clz in set (cls ._class_register .values ())
256-
257272 @classmethod
258273 def _recursive_load (cls , obj ):
259274 if isinstance (obj , dict ):
@@ -267,7 +282,7 @@ def _recursive_load(cls, obj):
267282 pass
268283 if cls .is_artifact_dict (obj ):
269284 cls .verify_artifact_dict (obj )
270- artifact_class = cls . _class_register [ obj .pop ("__type__" )]
285+ artifact_class = get_class_from_artifact_type ( obj .pop ("__type__" ))
271286 obj = artifact_class .process_data_after_load (obj )
272287 return artifact_class (** obj )
273288
@@ -684,3 +699,29 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]:
684699 return None
685700
686701 return data_classification .get (artifact )
702+
703+
704+ def find_unitxt_module_and_class_by_classname (camel_case_class_name : str ):
705+ """Find a module, a member of src/unitxt, that contains the definition of the class."""
706+ dir = os .path .dirname (__file__ ) # dir src/unitxt
707+ try :
708+ result = subprocess .run (
709+ ["grep" , "-irwE" , "^class +" + camel_case_class_name , dir ],
710+ capture_output = True ,
711+ ).stdout .decode ("ascii" )
712+ results = result .split ("\n " )
713+ assert len (results ) == 2 , f"returned: { results } "
714+ assert results [- 1 ] == "" , f"last result is { results [- 1 ]} rather than ''"
715+ to_return_module = (
716+ results [0 ].split (":" )[0 ][:- 3 ].replace ("/" , "." )
717+ ) # trim the .py and replace
718+ to_return_class_name = results [0 ].split (":" )[1 ][
719+ 6 : 6 + len (camel_case_class_name )
720+ ]
721+ return to_return_module [
722+ to_return_module .rfind ("unitxt." ) :
723+ ], to_return_class_name
724+ except Exception as e :
725+ raise ValueError (
726+ f"Could not find the unitxt module, under unitxt/src/unitxt, in which class { camel_case_class_name } is defined"
727+ ) from e
0 commit comments