77import inspect
88import itertools
99from types import MethodType
10+ from typing import List
1011
1112import numpy as np
1213
@@ -96,6 +97,31 @@ def filter_attributes(ctx, f, **kwargs):
9697 _process_exclusion (ctx , cls_attrs , kwargs ["exclude" ], f )
9798
9899
100+ def validate_data_types (
101+ prohibited_data_types : List [str ], reserved_words = ["collaborators" ], ** kwargs
102+ ):
103+ """Validates that the types of attributes in kwargs are not among the prohibited data types.
104+ Raises a TypeError if any prohibited data type is found.
105+
106+ Args:
107+ prohibited_data_types (List[str]): A list of prohibited data type names
108+ (e.g., ['int', 'float']).
109+ kwargs (dict): Arbitrary keyword arguments representing attribute names and their values.
110+
111+ Raises:
112+ TypeError: If any prohibited data types are found in kwargs.
113+ ValueError: If prohibited_data_types is empty.
114+ """
115+ if not prohibited_data_types :
116+ raise ValueError ("prohibited_data_types must not be empty." )
117+ for attr_name , attr_value in kwargs .items ():
118+ if type (attr_value ).__name__ in prohibited_data_types and attr_value not in reserved_words :
119+ raise TypeError (
120+ f"The attribute '{ attr_name } ' = '{ attr_value } ' has a prohibited value type: "
121+ f"{ type (attr_value ).__name__ } "
122+ )
123+
124+
99125def _validate_include_exclude (kwargs , cls_attrs ):
100126 """Validates that 'include' and 'exclude' are not both present, and that
101127 attributes in 'include' or 'exclude' exist in the context.
@@ -152,13 +178,13 @@ def _process_exclusion(ctx, cls_attrs, exclude_list, f):
152178 delattr (ctx , attr )
153179
154180
155- def checkpoint (ctx , parent_func , chkpnt_reserved_words = ["next" , "runtime" ]):
181+ def checkpoint (ctx , parent_func , checkpoint_reserved_words = ["next" , "runtime" ]):
156182 """Optionally saves the current state for the task just executed.
157183
158184 Args:
159185 ctx (any): The context to checkpoint.
160186 parent_func (function): The function that was just executed.
161- chkpnt_reserved_words (list, optional): A list of reserved words to
187+ checkpoint_reserved_words (list, optional): A list of reserved words to
162188 exclude from checkpointing. Defaults to ["next", "runtime"].
163189
164190 Returns:
@@ -173,7 +199,7 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
173199 if ctx ._checkpoint :
174200 # all objects will be serialized using Metaflow interface
175201 print (f"Saving data artifacts for { parent_func .__name__ } " )
176- artifacts_iter , _ = generate_artifacts (ctx = ctx , reserved_words = chkpnt_reserved_words )
202+ artifacts_iter , _ = generate_artifacts (ctx = ctx , reserved_words = checkpoint_reserved_words )
177203 task_id = ctx ._metaflow_interface .create_task (parent_func .__name__ )
178204 ctx ._metaflow_interface .save_artifacts (
179205 artifacts_iter (),
@@ -188,15 +214,15 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
188214
189215def old_check_resource_allocation (num_gpus , each_participant_gpu_usage ):
190216 remaining_gpu_memory = {}
191- # TODO for each GPU the funtion tries see if all participant usages fit
217+ # TODO for each GPU the function tries see if all participant usages fit
192218 # into a GPU, it it doesn't it removes that participant from the
193219 # participant list, and adds it to the remaining_gpu_memory dict. So any
194220 # sum of GPU requirements above 1 triggers this.
195- # But at this point the funtion will raise an error because
221+ # But at this point the function will raise an error because
196222 # remaining_gpu_memory is never cleared.
197223 # The participant list should remove the participant if it fits in the gpu
198- # and save the partipant if it doesn't and continue to the next GPU to see
199- # if it fits in that one, only if we run out of GPUs should this funtion
224+ # and save the participant if it doesn't and continue to the next GPU to see
225+ # if it fits in that one, only if we run out of GPUs should this function
200226 # raise an error.
201227 for gpu in np .ones (num_gpus , dtype = int ):
202228 for i , (participant_name , participant_gpu_usage ) in enumerate (
@@ -230,7 +256,7 @@ def check_resource_allocation(num_gpus, each_participant_gpu_usage):
230256 if gpu == 0 :
231257 break
232258 if gpu < participant_gpu_usage :
233- # participant doesn't fitm break to next GPU
259+ # participant doesn't fit, break to next GPU
234260 break
235261 else :
236262 # if participant fits remove from need_assigned
0 commit comments