Skip to content

Commit 5af1edf

Browse files
committed
light fast removal of register_all_artifacts for unitxt classes
Signed-off-by: dafnapension <[email protected]>
1 parent 95ad743 commit 5af1edf

File tree

12 files changed

+187
-79
lines changed

12 files changed

+187
-79
lines changed

docs/catalog.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pygments import highlight
1111
from pygments.formatters import HtmlFormatter
1212
from pygments.lexers import PythonLexer
13-
from unitxt.artifact import Artifact
13+
from unitxt.artifact import get_class_from_artifact_type
1414
from unitxt.text_utils import print_dict_as_python
1515
from unitxt.utils import load_json
1616

@@ -51,7 +51,7 @@ def imports_to_syntax_highlighted_html(subtypes: List[str]) -> str:
5151
return ""
5252
module_to_class_names = defaultdict(list)
5353
for subtype in subtypes:
54-
subtype_class = Artifact._class_register.get(subtype)
54+
subtype_class = get_class_from_artifact_type(subtype)
5555
module_to_class_names[subtype_class.__module__].append(subtype_class.__name__)
5656

5757
imports_txt = ""
@@ -150,7 +150,7 @@ def recursive_search(d):
150150

151151
@lru_cache(maxsize=None)
152152
def artifact_type_to_link(artifact_type):
153-
artifact_class = Artifact._class_register.get(artifact_type)
153+
artifact_class = get_class_from_artifact_type(artifact_type)
154154
type_class_name = artifact_class.__name__
155155
artifact_class_id = f"{artifact_class.__module__}.{type_class_name}"
156156
return f'<a class="reference internal" href="../{artifact_class.__module__}.html#{artifact_class_id}" title="{artifact_class_id}"><code class="xref py py-class docutils literal notranslate"><span class="pre">{type_class_name}</span></code></a>'
@@ -159,7 +159,7 @@ def artifact_type_to_link(artifact_type):
159159
# flake8: noqa: C901
160160
def make_content(artifact, label, all_labels):
161161
artifact_type = artifact["__type__"]
162-
artifact_class = Artifact._class_register.get(artifact_type)
162+
artifact_class = get_class_from_artifact_type(artifact_type)
163163
type_class_name = artifact_class.__name__
164164
catalog_id = label.replace("catalog.", "")
165165

@@ -243,7 +243,7 @@ def make_content(artifact, label, all_labels):
243243
result += artifact_class.__doc__ + "\n"
244244

245245
for subtype in subtypes:
246-
subtype_class = Artifact._class_register.get(subtype)
246+
subtype_class = get_class_from_artifact_type(subtype)
247247
subtype_class_name = subtype_class.__name__
248248
if subtype_class.__doc__:
249249
explanation_str = f"Explanation about `{subtype_class_name}`"

prepare/cards/mtrag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@
105105
card = TaskCard(
106106
loader=LoadJsonFile(
107107
files={
108-
"test": f"https://github.com/IBM/mt-rag-benchmark/raw/refs/heads/main/corpora/{subset}.jsonl.zip"
108+
"test": f"https://github.com/IBM/mt-rag-benchmark/raw/refs/heads/main/corpora/document_level/{subset}.jsonl.zip"
109109
},
110110
compression="zip",
111111
lines=True,

src/unitxt/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
)
1212
from .catalog import add_to_catalog, get_from_catalog
1313
from .logging_utils import get_logger
14-
from .register import register_all_artifacts, register_local_catalog
14+
from .register import ProjectArtifactRegisterer, register_local_catalog
1515
from .settings_utils import get_constants, get_settings
1616

17-
register_all_artifacts()
17+
ProjectArtifactRegisterer()
1818
random.seed(0)
1919

2020
constants = get_constants()

src/unitxt/artifact.py

Lines changed: 77 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
import difflib
1+
import importlib
22
import inspect
33
import json
44
import os
55
import pkgutil
66
import re
7+
import subprocess
78
import warnings
89
from abc import abstractmethod
910
from typing import Any, Dict, List, Optional, Tuple, Union, final
@@ -22,7 +23,7 @@
2223
separate_inside_and_outside_square_brackets,
2324
)
2425
from .settings_utils import get_constants, get_settings
25-
from .text_utils import camel_to_snake_case, is_camel_case
26+
from .text_utils import camel_to_snake_case, is_camel_case, snake_to_camel_case
2627
from .type_utils import isoftype, issubtype
2728
from .utils import (
2829
artifacts_json_cache,
@@ -36,6 +37,53 @@
3637
constants = get_constants()
3738

3839

40+
def import_module_from_file(file_path):
41+
# Get the module name (file name without extension)
42+
module_name = os.path.splitext(os.path.basename(file_path))[0]
43+
# Create a module specification
44+
spec = importlib.util.spec_from_file_location(module_name, file_path)
45+
# Create a new module based on the specification
46+
module = importlib.util.module_from_spec(spec)
47+
spec.loader.exec_module(module)
48+
return module
49+
50+
51+
# type is read from a catelog entry, the value of a key "__type__"
52+
def get_class_from_artifact_type(type: str):
53+
if type in Artifact._class_register:
54+
return Artifact._class_register[type]
55+
56+
module_path, class_name = find_unitxt_module_and_class_by_classname(
57+
snake_to_camel_case(type)
58+
)
59+
if module_path == "class_register":
60+
if class_name not in Artifact._class_register:
61+
raise ValueError(
62+
f"Can not instantiate a class from type {type}, because {class_name} is currently not registered in Artifact._class_register."
63+
)
64+
return Artifact._class_register[class_name]
65+
66+
module = importlib.import_module(module_path)
67+
68+
if "." not in class_name:
69+
if hasattr(module, class_name) and inspect.isclass(getattr(module, class_name)):
70+
return getattr(module, class_name)
71+
if class_name in Artifact._class_register:
72+
return Artifact._class_register[class_name]
73+
module_file = module.__file__ if hasattr(module, "__file__") else None
74+
if module_file:
75+
module = import_module_from_file(module_file)
76+
77+
assert class_name in Artifact._class_register
78+
return Artifact._class_register[class_name]
79+
80+
class_name_components = class_name.split(".")
81+
klass = getattr(module, class_name_components[0])
82+
for i in range(1, len(class_name_components)):
83+
klass = getattr(klass, class_name_components[i])
84+
return klass
85+
86+
3987
def is_name_legal_for_catalog(name):
4088
return re.match(r"^[\w" + constants.catalog_hierarchy_sep + "]+$", name)
4189

@@ -133,21 +181,10 @@ def maybe_recover_artifacts_structure(obj):
133181
return obj
134182

135183

136-
def get_closest_artifact_type(type):
137-
artifact_type_options = list(Artifact._class_register.keys())
138-
matches = difflib.get_close_matches(type, artifact_type_options)
139-
if matches:
140-
return matches[0] # Return the closest match
141-
return None
142-
143-
144184
class UnrecognizedArtifactTypeError(ValueError):
145185
def __init__(self, type) -> None:
146186
maybe_class = "".join(word.capitalize() for word in type.split("_"))
147187
message = f"'{type}' is not a recognized artifact 'type'. Make sure a the class defined this type (Probably called '{maybe_class}' or similar) is defined and/or imported anywhere in the code executed."
148-
closest_artifact_type = get_closest_artifact_type(type)
149-
if closest_artifact_type is not None:
150-
message += f"\n\nDid you mean '{closest_artifact_type}'?"
151188
super().__init__(message)
152189

153190

@@ -200,8 +237,6 @@ def verify_artifact_dict(cls, d):
200237
)
201238
if "__type__" not in d:
202239
raise MissingArtifactTypeError(d)
203-
if not cls.is_registered_type(d["__type__"]):
204-
raise UnrecognizedArtifactTypeError(d["__type__"])
205240

206241
@classmethod
207242
def get_artifact_type(cls):
@@ -218,13 +253,6 @@ def register_class(cls, artifact_class):
218253

219254
snake_case_key = camel_to_snake_case(artifact_class.__name__)
220255

221-
if cls.is_registered_type(snake_case_key):
222-
assert (
223-
str(cls._class_register[snake_case_key]) == str(artifact_class)
224-
), f"Artifact class name must be unique, '{snake_case_key}' already exists for {cls._class_register[snake_case_key]}. Cannot be overridden by {artifact_class}."
225-
226-
return snake_case_key
227-
228256
cls._class_register[snake_case_key] = artifact_class
229257

230258
return snake_case_key
@@ -241,19 +269,6 @@ def is_artifact_file(cls, path):
241269
d = json.load(f)
242270
return cls.is_artifact_dict(d)
243271

244-
@classmethod
245-
def is_registered_type(cls, type: str):
246-
return type in cls._class_register
247-
248-
@classmethod
249-
def is_registered_class_name(cls, class_name: str):
250-
snake_case_key = camel_to_snake_case(class_name)
251-
return cls.is_registered_type(snake_case_key)
252-
253-
@classmethod
254-
def is_registered_class(cls, clz: object):
255-
return clz in set(cls._class_register.values())
256-
257272
@classmethod
258273
def _recursive_load(cls, obj):
259274
if isinstance(obj, dict):
@@ -267,7 +282,7 @@ def _recursive_load(cls, obj):
267282
pass
268283
if cls.is_artifact_dict(obj):
269284
cls.verify_artifact_dict(obj)
270-
artifact_class = cls._class_register[obj.pop("__type__")]
285+
artifact_class = get_class_from_artifact_type(obj.pop("__type__"))
271286
obj = artifact_class.process_data_after_load(obj)
272287
return artifact_class(**obj)
273288

@@ -684,3 +699,29 @@ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]:
684699
return None
685700

686701
return data_classification.get(artifact)
702+
703+
704+
def find_unitxt_module_and_class_by_classname(camel_case_class_name: str):
705+
"""Find a module, a member of src/unitxt, that contains the definition of the class."""
706+
dir = os.path.dirname(__file__) # dir src/unitxt
707+
try:
708+
result = subprocess.run(
709+
["grep", "-irwE", "^class +" + camel_case_class_name, dir],
710+
capture_output=True,
711+
).stdout.decode("ascii")
712+
results = result.split("\n")
713+
assert len(results) == 2, f"returned: {results}"
714+
assert results[-1] == "", f"last result is {results[-1]} rather than ''"
715+
to_return_module = (
716+
results[0].split(":")[0][:-3].replace("/", ".")
717+
) # trim the .py and replace
718+
to_return_class_name = results[0].split(":")[1][
719+
6 : 6 + len(camel_case_class_name)
720+
]
721+
return to_return_module[
722+
to_return_module.rfind("unitxt.") :
723+
], to_return_class_name
724+
except Exception as e:
725+
raise ValueError(
726+
f"Could not find the unitxt module, under unitxt/src/unitxt, in which class {camel_case_class_name} is defined"
727+
) from e

src/unitxt/catalog/cards/rag/mtrag/documents/clapnq.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"loader": {
44
"__type__": "load_json_file",
55
"files": {
6-
"test": "https://github.com/IBM/mt-rag-benchmark/raw/refs/heads/main/corpora/clapnq.jsonl.zip"
6+
"test": "https://github.com/IBM/mt-rag-benchmark/raw/refs/heads/main/corpora/document_level/clapnq.jsonl.zip"
77
},
88
"compression": "zip",
99
"lines": true,

src/unitxt/catalog/cards/rag/mtrag/documents/cloud.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"loader": {
44
"__type__": "load_json_file",
55
"files": {
6-
"test": "https://github.com/IBM/mt-rag-benchmark/raw/refs/heads/main/corpora/cloud.jsonl.zip"
6+
"test": "https://github.com/IBM/mt-rag-benchmark/raw/refs/heads/main/corpora/document_level/cloud.jsonl.zip"
77
},
88
"compression": "zip",
99
"lines": true,

src/unitxt/catalog/cards/rag/mtrag/documents/fiqa.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"loader": {
44
"__type__": "load_json_file",
55
"files": {
6-
"test": "https://github.com/IBM/mt-rag-benchmark/raw/refs/heads/main/corpora/fiqa.jsonl.zip"
6+
"test": "https://github.com/IBM/mt-rag-benchmark/raw/refs/heads/main/corpora/document_level/fiqa.jsonl.zip"
77
},
88
"compression": "zip",
99
"lines": true,

src/unitxt/catalog/cards/rag/mtrag/documents/govt.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"loader": {
44
"__type__": "load_json_file",
55
"files": {
6-
"test": "https://github.com/IBM/mt-rag-benchmark/raw/refs/heads/main/corpora/govt.jsonl.zip"
6+
"test": "https://github.com/IBM/mt-rag-benchmark/raw/refs/heads/main/corpora/document_level/govt.jsonl.zip"
77
},
88
"compression": "zip",
99
"lines": true,

src/unitxt/register.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import importlib
2-
import inspect
31
import os
42
from pathlib import Path
53

6-
from .artifact import Artifact, Catalogs
4+
from .artifact import Catalogs
75
from .catalog import EnvironmentLocalCatalog, GithubCatalog, LocalCatalog
86
from .error_utils import Documentation, UnitxtError, UnitxtWarning
97
from .settings_utils import get_constants, get_settings
@@ -89,36 +87,13 @@ def _reset_env_local_catalogs():
8987
_register_catalog(EnvironmentLocalCatalog(location=path))
9088

9189

92-
def _register_all_artifacts():
93-
dir = os.path.dirname(__file__)
94-
file_name = os.path.basename(__file__)
95-
96-
for file in os.listdir(dir):
97-
if (
98-
file.endswith(".py")
99-
and file not in constants.non_registered_files
100-
and file != file_name
101-
):
102-
module_name = file.replace(".py", "")
103-
104-
module = importlib.import_module("." + module_name, __package__)
105-
106-
for _name, obj in inspect.getmembers(module):
107-
# Make sure the object is a class
108-
if inspect.isclass(obj):
109-
# Make sure the class is a subclass of Artifact (but not Artifact itself)
110-
if issubclass(obj, Artifact) and obj is not Artifact:
111-
Artifact.register_class(obj)
112-
113-
11490
class ProjectArtifactRegisterer(metaclass=Singleton):
11591
def __init__(self):
11692
if not hasattr(self, "_registered"):
11793
self._registered = False
11894

11995
if not self._registered:
12096
_register_all_catalogs()
121-
_register_all_artifacts()
12297
self._registered = True
12398

12499

src/unitxt/text_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,19 @@ def camel_to_snake_case(s):
7171
return s.lower()
7272

7373

74+
def snake_to_camel_case(s):
75+
"""Converts a snake_case string s to CamelCase. Assume a class name is in question so result to start with an upper case.
76+
77+
Not always the reciprocal of the above camel_to_snake_case. e.g: camel_to_snake_case(LoadHF) = load_hf,
78+
whereas snake_to_camel_case(load_hf) = LoadHf
79+
"""
80+
s = s.strip()
81+
words = s.split("_")
82+
# Capitalize all words and join them
83+
camel_case_parts = [word.capitalize() for word in words]
84+
return "".join(camel_case_parts)
85+
86+
7487
def to_pretty_string(
7588
value,
7689
indent=0,

0 commit comments

Comments
 (0)