Skip to content

Commit

Permalink
ArraySchema -> CWLArraySchema and RecordSchema -> CWLRecordSchema
Browse files Browse the repository at this point in the history
  • Loading branch information
GlassOfWhiskey committed Jun 18, 2023
1 parent 8f521df commit 1b9dc32
Show file tree
Hide file tree
Showing 10 changed files with 229 additions and 439 deletions.
14 changes: 7 additions & 7 deletions cwl_utils/cwl_v1_0_expression_refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def escape_expression_field(contents: str) -> str:


def clean_type_ids(
cwltype: Union[cwl.ArraySchema, cwl.InputRecordSchema]
) -> Union[cwl.ArraySchema, cwl.InputRecordSchema]:
cwltype: Union[cwl.CWLArraySchema, cwl.InputRecordSchema]
) -> Union[cwl.CWLArraySchema, cwl.InputRecordSchema]:
"""Simplify type identifiers."""
result = copy.deepcopy(cwltype)
if isinstance(result, cwl.ArraySchema):
if isinstance(result, cwl.CWLArraySchema):
if isinstance(result.items, MutableSequence):
for item in result.items:
if hasattr(item, "id"):
Expand Down Expand Up @@ -344,8 +344,8 @@ def generate_etool_from_expr(
self_type = target
if isinstance(self_type, list):
new_type: Union[
List[Union[cwl.ArraySchema, cwl.InputRecordSchema]],
Union[cwl.ArraySchema, cwl.InputRecordSchema],
List[Union[cwl.CWLArraySchema, cwl.InputRecordSchema]],
Union[cwl.CWLArraySchema, cwl.InputRecordSchema],
] = [clean_type_ids(t.type) for t in self_type if t.type]
elif self_type.type:
new_type = clean_type_ids(self_type.type)
Expand Down Expand Up @@ -1345,7 +1345,7 @@ def traverse_CommandLineTool(
modified = True
inp_id = "_{}_glob".format(outp.id.split("#")[-1])
etool_id = f"_expression_{step_id}{inp_id}"
glob_target_type = ["string", cwl.ArraySchema("string", "array")]
glob_target_type = ["string", cwl.CWLArraySchema("string", "array")]
target = cwl.InputParameter(id=None, type=glob_target_type)
replace_step_clt_expr_with_etool(
expression, etool_id, parent, target, step, replace_etool
Expand Down Expand Up @@ -1835,7 +1835,7 @@ def traverse_step(
source_types.append(temp_type)
source_type = cwl.InputParameter(
id=None,
type=cwl.ArraySchema(source_types, "array"),
type=cwl.CWLArraySchema(source_types, "array"),
)
else:
input_source_id = inp.source.split("#")[-1]
Expand Down
14 changes: 7 additions & 7 deletions cwl_utils/cwl_v1_1_expression_refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def escape_expression_field(contents: str) -> str:


def clean_type_ids(
cwltype: Union[cwl.ArraySchema, cwl.InputRecordSchema]
) -> Union[cwl.ArraySchema, cwl.InputRecordSchema]:
cwltype: Union[cwl.CWLArraySchema, cwl.InputRecordSchema]
) -> Union[cwl.CWLArraySchema, cwl.InputRecordSchema]:
"""Simplify type identifiers."""
result = copy.deepcopy(cwltype)
if isinstance(result, cwl.ArraySchema):
if isinstance(result, cwl.CWLArraySchema):
if isinstance(result.items, MutableSequence):
for item in result.items:
if hasattr(item, "id"):
Expand Down Expand Up @@ -344,8 +344,8 @@ def generate_etool_from_expr(
self_type = target
if isinstance(self_type, list):
new_type: Union[
List[Union[cwl.ArraySchema, cwl.InputRecordSchema]],
Union[cwl.ArraySchema, cwl.InputRecordSchema],
List[Union[cwl.CWLArraySchema, cwl.InputRecordSchema]],
Union[cwl.CWLArraySchema, cwl.InputRecordSchema],
] = [clean_type_ids(t.type) for t in self_type]
else:
new_type = clean_type_ids(self_type.type)
Expand Down Expand Up @@ -1345,7 +1345,7 @@ def traverse_CommandLineTool(
modified = True
inp_id = "_{}_glob".format(outp.id.split("#")[-1])
etool_id = f"_expression_{step_id}{inp_id}"
glob_target_type = ["string", cwl.ArraySchema("string", "array")]
glob_target_type = ["string", cwl.CWLArraySchema("string", "array")]
target = cwl.WorkflowInputParameter(id=None, type=glob_target_type)
replace_step_clt_expr_with_etool(
expression, etool_id, parent, target, step, replace_etool
Expand Down Expand Up @@ -1835,7 +1835,7 @@ def traverse_step(
source_types.append(temp_type)
source_type = cwl.WorkflowInputParameter(
id=None,
type=cwl.ArraySchema(source_types, "array"),
type=cwl.CWLArraySchema(source_types, "array"),
)
else:
input_source_id = inp.source.split("#")[-1]
Expand Down
16 changes: 8 additions & 8 deletions cwl_utils/cwl_v1_2_expression_refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def escape_expression_field(contents: str) -> str:


def clean_type_ids(
cwltype: Union[cwl.ArraySchema, cwl.InputRecordSchema]
) -> Union[cwl.ArraySchema, cwl.InputRecordSchema]:
cwltype: Union[cwl.CWLArraySchema, cwl.InputRecordSchema]
) -> Union[cwl.CWLArraySchema, cwl.InputRecordSchema]:
"""Simplify type identifiers."""
result = copy.deepcopy(cwltype)
if isinstance(result, cwl.ArraySchema):
if isinstance(result, cwl.CWLArraySchema):
if isinstance(result.items, MutableSequence):
for item in result.items:
if hasattr(item, "id"):
Expand Down Expand Up @@ -344,8 +344,8 @@ def generate_etool_from_expr(
self_type = target
if isinstance(self_type, list):
new_type: Union[
List[Union[cwl.ArraySchema, cwl.InputRecordSchema]],
Union[cwl.ArraySchema, cwl.InputRecordSchema],
List[Union[cwl.CWLArraySchema, cwl.InputRecordSchema]],
Union[cwl.CWLArraySchema, cwl.InputRecordSchema],
] = [clean_type_ids(t.type) for t in self_type]
else:
new_type = clean_type_ids(self_type.type)
Expand Down Expand Up @@ -706,7 +706,7 @@ def process_workflow_inputs_and_outputs(
else:
sources = [s.split("#")[-1] for s in param2.outputSource]
source_type_items = utils.type_for_source(workflow, sources)
if isinstance(source_type_items, cwl.ArraySchema):
if isinstance(source_type_items, cwl.CWLArraySchema):
if isinstance(source_type_items.items, list):
if "null" not in source_type_items.items:
source_type_items.items.append("null")
Expand Down Expand Up @@ -1440,7 +1440,7 @@ def traverse_CommandLineTool(
modified = True
inp_id = "_{}_glob".format(outp.id.split("#")[-1])
etool_id = f"_expression_{step_id}{inp_id}"
glob_target_type = ["string", cwl.ArraySchema("string", "array")]
glob_target_type = ["string", cwl.CWLArraySchema("string", "array")]
target = cwl.WorkflowInputParameter(id=None, type=glob_target_type)
replace_step_clt_expr_with_etool(
expression, etool_id, parent, target, step, replace_etool
Expand Down Expand Up @@ -1930,7 +1930,7 @@ def traverse_step(
source_types.append(temp_type)
source_type = cwl.WorkflowInputParameter(
id=None,
type=cwl.ArraySchema(source_types, "array"),
type=cwl.CWLArraySchema(source_types, "array"),
)
else:
input_source_id = inp.source.split("#")[-1]
Expand Down
6 changes: 4 additions & 2 deletions cwl_utils/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,14 @@
cwl_v1_2.SoftwareRequirement,
)
"""Type union for a CWL v1.x SoftwareRequirement object."""
ArraySchema = Union[cwl_v1_0.ArraySchema, cwl_v1_1.ArraySchema, cwl_v1_2.ArraySchema]
ArraySchema = Union[
cwl_v1_0.CWLArraySchema, cwl_v1_1.CWLArraySchema, cwl_v1_2.CWLArraySchema
]
"""Type Union for a CWL v1.x ArraySchema object."""
EnumSchema = Union[cwl_v1_0.EnumSchema, cwl_v1_1.EnumSchema, cwl_v1_2.EnumSchema]
"""Type Union for a CWL v1.x EnumSchema object."""
RecordSchema = Union[
cwl_v1_0.RecordSchema, cwl_v1_1.RecordSchema, cwl_v1_2.RecordSchema
cwl_v1_0.CWLRecordSchema, cwl_v1_1.CWLRecordSchema, cwl_v1_2.CWLRecordSchema
]
"""Type Union for a CWL v1.x RecordSchema object."""
File = Union[cwl_v1_0.File, cwl_v1_1.File, cwl_v1_2.File]
Expand Down
40 changes: 22 additions & 18 deletions cwl_utils/parser/cwl_v1_0_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@


def _compare_records(
src: cwl.RecordSchema, sink: cwl.RecordSchema, strict: bool = False
src: cwl.CWLRecordSchema, sink: cwl.CWLRecordSchema, strict: bool = False
) -> bool:
"""
Compare two records, ensuring they have compatible fields.
Expand Down Expand Up @@ -73,9 +73,11 @@ def _compare_records(


def _compare_type(type1: Any, type2: Any) -> bool:
if isinstance(type1, cwl.ArraySchema) and isinstance(type2, cwl.ArraySchema):
if isinstance(type1, cwl.CWLArraySchema) and isinstance(type2, cwl.CWLArraySchema):
return _compare_type(type1.items, type2.items)
elif isinstance(type1, cwl.RecordSchema) and isinstance(type2, cwl.RecordSchema):
elif isinstance(type1, cwl.CWLRecordSchema) and isinstance(
type2, cwl.CWLRecordSchema
):
fields1 = {
cwl.shortname(field.name): field.type for field in (type1.fields or {})
}
Expand Down Expand Up @@ -167,9 +169,9 @@ def can_assign_src_to_sink(src: Any, sink: Any, strict: bool = False) -> bool:
"""
if src == "Any" or sink == "Any":
return True
if isinstance(src, cwl.ArraySchema) and isinstance(sink, cwl.ArraySchema):
if isinstance(src, cwl.CWLArraySchema) and isinstance(sink, cwl.CWLArraySchema):
return can_assign_src_to_sink(src.items, sink.items, strict)
if isinstance(src, cwl.RecordSchema) and isinstance(sink, cwl.RecordSchema):
if isinstance(src, cwl.CWLRecordSchema) and isinstance(sink, cwl.CWLRecordSchema):
return _compare_records(src, sink, strict)
if isinstance(src, MutableSequence):
if strict:
Expand Down Expand Up @@ -256,7 +258,7 @@ def check_types(
return "exception"
if linkMerge == "merge_nested":
return check_types(
cwl.ArraySchema(items=srctype, type="array"), sinktype, None, None
cwl.CWLArraySchema(items=srctype, type="array"), sinktype, None, None
)
if linkMerge == "merge_flattened":
return check_types(merge_flatten_type(srctype), sinktype, None, None)
Expand Down Expand Up @@ -374,9 +376,9 @@ def merge_flatten_type(src: Any) -> Any:
"""Return the merge flattened type of the source type."""
if isinstance(src, MutableSequence):
return [merge_flatten_type(t) for t in src]
if isinstance(src, cwl.ArraySchema):
if isinstance(src, cwl.CWLArraySchema):
return src
return cwl.ArraySchema(type="array", items=src)
return cwl.CWLArraySchema(type="array", items=src)


def type_for_step_input(
Expand All @@ -396,7 +398,7 @@ def type_for_step_input(
):
input_type = step_input.type
if step.scatter is not None and in_.id in aslist(step.scatter):
input_type = cwl.ArraySchema(items=input_type, type="array")
input_type = cwl.CWLArraySchema(items=input_type, type="array")
return input_type
return "Any"

Expand All @@ -418,11 +420,13 @@ def type_for_step_output(
if step.scatter is not None:
if step.scatterMethod == "nested_crossproduct":
for _ in range(len(aslist(step.scatter))):
output_type = cwl.ArraySchema(
output_type = cwl.CWLArraySchema(
items=output_type, type="array"
)
else:
output_type = cwl.ArraySchema(items=output_type, type="array")
output_type = cwl.CWLArraySchema(
items=output_type, type="array"
)
return output_type
raise ValidationException(
"param {} not found in {}.".format(
Expand All @@ -446,11 +450,11 @@ def type_for_source(
if scatter_context[0] is not None:
if scatter_context[0][1] == "nested_crossproduct":
for _ in range(scatter_context[0][0]):
new_type = cwl.ArraySchema(items=new_type, type="array")
new_type = cwl.CWLArraySchema(items=new_type, type="array")
else:
new_type = cwl.ArraySchema(items=new_type, type="array")
new_type = cwl.CWLArraySchema(items=new_type, type="array")
if linkMerge == "merge_nested":
new_type = cwl.ArraySchema(items=new_type, type="array")
new_type = cwl.CWLArraySchema(items=new_type, type="array")
elif linkMerge == "merge_flattened":
new_type = merge_flatten_type(new_type)
return new_type
Expand All @@ -466,18 +470,18 @@ def type_for_source(
if sc is not None:
if sc[1] == "nested_crossproduct":
for _ in range(sc[0]):
cur_type = cwl.ArraySchema(items=cur_type, type="array")
cur_type = cwl.CWLArraySchema(items=cur_type, type="array")
else:
cur_type = cwl.ArraySchema(items=cur_type, type="array")
cur_type = cwl.CWLArraySchema(items=cur_type, type="array")
new_type.append(cur_type)
if len(new_type) == 1:
new_type = new_type[0]
if linkMerge == "merge_nested":
return cwl.ArraySchema(items=new_type, type="array")
return cwl.CWLArraySchema(items=new_type, type="array")
elif linkMerge == "merge_flattened":
return merge_flatten_type(new_type)
elif isinstance(sourcenames, List) and len(sourcenames) > 1:
return cwl.ArraySchema(items=new_type, type="array")
return cwl.CWLArraySchema(items=new_type, type="array")
else:
return new_type

Expand Down
Loading

0 comments on commit 1b9dc32

Please sign in to comment.