1616from one .api import ONE
1717from one .webclient import AlyxClient
1818from one .util import filter_datasets
19- from one .alf .path import add_uuid_string , session_path_parts , get_alf_path
19+ from one .alf .path import add_uuid_string , get_alf_path , ensure_alf_path
2020from one .alf .cache import _make_datasets_df
2121from iblutil .util import flatten , ensure_list
2222
@@ -461,8 +461,8 @@ def dataset_from_name(name, datasets):
461461
462462 Parameters
463463 ----------
464- name : str
465- The name of the dataset.
464+ name : str, function
465+ The name of the dataset or a function to match the dataset name .
466466 datasets : list of ExpectedDataset
467467 A list of ExpectedDataset instances.
468468
@@ -475,14 +475,18 @@ def dataset_from_name(name, datasets):
475475 matches = []
476476 for dataset in datasets :
477477 if dataset .operator is None :
478- if dataset ._identifiers [2 ] == name :
479- matches .append (dataset )
478+ if isinstance (name , str ):
479+ if dataset ._identifiers [2 ] == name :
480+ matches .append (dataset )
481+ else :
482+ if name (dataset ._identifiers [2 ]):
483+ matches .append (dataset )
480484 else :
481485 matches .extend (dataset_from_name (name , dataset ._identifiers ))
482486 return matches
483487
484488
485- def update_collections (dataset , new_collection , substring = None , unique = None ):
489+ def update_collections (dataset , new_collection , substring = None , unique = None , exact_match = False ):
486490 """
487491 Update the collection of a dataset.
488492
@@ -497,6 +501,12 @@ def update_collections(dataset, new_collection, substring=None, unique=None):
497501 substring : str, optional
498502 An optional substring in the collection to replace with new collection(s). If None, the
499503 entire collection will be replaced.
504+ unique : bool, optional
505+ When provided, this will be used to set the `unique` attribute of the new dataset(s). If
506+ None, the `unique` attribute will be set to True if the collection does not contain
507+ wildcards.
508+ exact_match : bool
509+ If True, the collection will be replaced only if it contains `substring`.
500510
501511 Returns
502512 -------
@@ -511,7 +521,10 @@ def update_collections(dataset, new_collection, substring=None, unique=None):
511521 if revision is not None :
512522 raise NotImplementedError
513523 if substring :
514- after = [(collection or '' ).replace (substring , x ) or None for x in after ]
524+ if exact_match and substring not in collection :
525+ after = [collection ]
526+ else :
527+ after = [(collection or '' ).replace (substring , x ) or None for x in after ]
515528 if unique is None :
516529 unique = [not set (name + (x or '' )).intersection ('*[?' ) for x in after ]
517530 else :
@@ -523,7 +536,7 @@ def update_collections(dataset, new_collection, substring=None, unique=None):
523536 updated &= D (name , folder , not isinstance (dataset , OptionalDataset ), register , unique = unq )
524537 else :
525538 updated = copy (dataset )
526- updated ._identifiers = [update_collections (dd , new_collection , substring , unique )
539+ updated ._identifiers = [update_collections (dd , new_collection , substring , unique , exact_match )
527540 for dd in updated ._identifiers ]
528541 return updated
529542
@@ -536,7 +549,7 @@ def __init__(self, session_path, signature, one=None):
536549 :param signature: input and output file signatures
537550 :param one: ONE instance
538551 """
539- self .session_path = session_path
552+ self .session_path = ensure_alf_path ( session_path )
540553 self .signature = _parse_signature (signature )
541554 self .one = one
542555 self .processed = {} # Map of filepaths and their processed records (e.g. upload receipts or Alyx records)
@@ -566,7 +579,7 @@ def getData(self, one=None):
566579 dfs = [file .filter (session_datasets )[1 ] for file in self .signature ['input_files' ]]
567580 return one ._cache .datasets .iloc [0 :0 ] if len (dfs ) == 0 else pd .concat (dfs ).drop_duplicates ()
568581
569- def getOutputFiles (self ):
582+ def getOutputFiles (self , session_path = None ):
570583 """
571584 Return a data frame of output datasets found on disk.
572585
@@ -575,10 +588,11 @@ def getOutputFiles(self):
575588 pandas.DataFrame
576589 A dataset data frame of datasets on disk that were specified in signature['output_files'].
577590 """
578- assert self .session_path
591+ session_path = self .session_path if session_path is None else session_path
592+ assert session_path
579593 # Next convert datasets to frame
580594 # Create dataframe of all ALF datasets
581- df = _make_datasets_df (self . session_path , hash_files = False ).set_index (['eid' , 'id' ])
595+ df = _make_datasets_df (session_path , hash_files = False ).set_index (['eid' , 'id' ])
582596 # Filter outputs
583597 if len (self .signature ['output_files' ]) == 0 :
584598 return pd .DataFrame ()
@@ -714,7 +728,7 @@ def setUp(self, **_):
714728 _logger .warning ('Space left on server is < 500GB, won\' t re-download new data' )
715729 return
716730
717- rel_sess_path = '/' . join ( self .session_path .parts [ - 3 :] )
731+ rel_sess_path = self .session_path .session_path_short ( )
718732 target_paths = []
719733 source_paths = []
720734 for i , d in df .iterrows ():
@@ -761,13 +775,13 @@ def __init__(self, session_path, signature, one=None):
761775 """
762776 super ().__init__ (session_path , signature , one = one )
763777
764- def setUp (self , ** _ ):
778+ def setUp (self , check_hash = True , ** _ ):
765779 """
766780 Function to download necessary data to run tasks using ONE
767781 :return:
768782 """
769783 df = super ().getData ()
770- self .one ._check_filesystem (df , check_hash = False )
784+ self .one ._check_filesystem (df , check_hash = check_hash )
771785
772786 def uploadData (self , outputs , version , ** kwargs ):
773787 """
@@ -843,8 +857,8 @@ def uploadData(self, outputs, version, **kwargs):
843857 """
844858 # Set up Globus
845859 from one .remote .globus import Globus # noqa
846- self .globus = Globus (client_name = ' server' , headless = True )
847- self .lab = session_path_parts ( self .session_path , as_dict = True )[ ' lab' ]
860+ self .globus = Globus (client_name = kwargs . pop ( 'client_name' , ' server') , headless = True )
861+ self .lab = self .session_path . lab
848862 if self .lab == 'cortexlab' and 'cortexlab' in self .one .alyx .base_url :
849863 base_url = 'https://alyx.internationalbrainlab.org'
850864 _logger .warning ('Changing Alyx client to %s' , base_url )
@@ -957,25 +971,30 @@ def __init__(self, session_path, signatures, one=None):
957971 super ().__init__ (session_path , signatures , one = one )
958972 self .patch_path = os .getenv ('SDSC_PATCH_PATH' , SDSC_PATCH_PATH )
959973 self .root_path = SDSC_ROOT_PATH
974+ self .linked_files = [] # List of symlinks created to run tasks
960975
961- def setUp (self , task ):
976+ def setUp (self , task , ** _ ):
962977 """Function to create symlinks to necessary data to run tasks."""
963978 df = super ().getData ()
964979
965- SDSC_TMP = Path (self .patch_path .joinpath (task .__class__ .__name__ ))
980+ SDSC_TMP = ensure_alf_path (self .patch_path .joinpath (task .__class__ .__name__ ))
966981 session_path = Path (get_alf_path (self .session_path ))
967982 for uuid , d in df .iterrows ():
968983 file_path = session_path / d ['rel_path' ]
969984 file_uuid = add_uuid_string (file_path , uuid )
970985 file_link = SDSC_TMP .joinpath (file_path )
971986 file_link .parent .mkdir (exist_ok = True , parents = True )
972- try :
987+ try : # TODO append link to task attribute
973988 file_link .symlink_to (
974989 Path (self .root_path .joinpath (file_uuid )))
990+ self .linked_files .append (file_link )
975991 except FileExistsError :
976992 pass
977-
978993 task .session_path = SDSC_TMP .joinpath (session_path )
994+ # If one of the symlinked input files is also an expected output, raise here to avoid overwriting
995+ # In the future we may instead copy the data under this condition
996+ assert self .getOutputFiles (session_path = task .session_path ).shape [0 ] == 0 , (
997+ "On SDSC patcher, output files should be distinct from input files to avoid overwriting" )
979998
980999 def uploadData (self , outputs , version , ** kwargs ):
9811000 """
0 commit comments