Skip to content

Commit

Permalink
Added flatten directive to ArraySchema
Browse files Browse the repository at this point in the history
  • Loading branch information
GlassOfWhiskey committed Feb 12, 2023
1 parent 4a55cbf commit bf96061
Show file tree
Hide file tree
Showing 12 changed files with 125 additions and 39 deletions.
13 changes: 11 additions & 2 deletions schema_salad/avro/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,11 @@ def symbols(self) -> List[str]:

class ArraySchema(Schema):
def __init__(
self, items: JsonDataType, names: Names, other_props: Optional[PropsType] = None
self,
items: JsonDataType,
names: Names,
flatten: Optional[bool] = True,
other_props: Optional[PropsType] = None,
) -> None:
# Call parent ctor
Schema.__init__(self, "array", other_props)
Expand All @@ -415,12 +419,16 @@ def __init__(
) from err

self.set_prop("items", items_schema)
self.set_prop("flatten", flatten)

# read-only properties
@property
def items(self) -> Schema:
return cast(Schema, self.get_prop("items"))

def flatten(self) -> bool:
return cast(bool, self.get_prop("flatten"))


class MapSchema(Schema):
def __init__(
Expand Down Expand Up @@ -681,7 +689,8 @@ def make_avsc_object(json_data: JsonDataType, names: Optional[Names] = None) ->
if atype in VALID_TYPES:
if atype == "array":
items = json_data.get("items")
return ArraySchema(items, names, other_props)
flatten = json_data.get("flatten")
return ArraySchema(items, names, flatten, other_props)
elif atype == "map":
name = json_data.get("name")
namespace = json_data.get("namespace", names.default_namespace)
Expand Down
71 changes: 52 additions & 19 deletions schema_salad/metaschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,9 @@ def __repr__(self): # type: () -> str


class _ArrayLoader(_Loader):
def __init__(self, items):
# type: (_Loader) -> None
def __init__(self, items: _Loader, flatten: bool = True) -> None:
self.items = items
self.flatten = flatten

def load(self, doc, baseuri, loadingOptions, docRoot=None):
# type: (Any, str, LoadingOptions, Optional[str]) -> Any
Expand All @@ -395,7 +395,7 @@ def load(self, doc, baseuri, loadingOptions, docRoot=None):
for i in range(0, len(doc)):
try:
lf = load_field(doc[i], _UnionLoader((self, self.items)), baseuri, loadingOptions)
if isinstance(lf, MutableSequence):
if self.flatten and isinstance(lf, MutableSequence):
r.extend(lf)
else:
r.append(lf)
Expand Down Expand Up @@ -1365,6 +1365,7 @@ def __init__(
self,
items: Any,
type: Any,
flatten: Optional[Any] = None,
extension_fields: Optional[Dict[str, Any]] = None,
loadingOptions: Optional[LoadingOptions] = None,
) -> None:
Expand All @@ -1377,16 +1378,21 @@ def __init__(
self.loadingOptions = loadingOptions
else:
self.loadingOptions = LoadingOptions()
self.flatten = flatten
self.items = items
self.type = type

def __eq__(self, other: Any) -> bool:
if isinstance(other, ArraySchema):
return bool(self.items == other.items and self.type == other.type)
return bool(
self.flatten == other.flatten
and self.items == other.items
and self.type == other.type
)
return False

def __hash__(self) -> int:
return hash((self.items, self.type))
return hash((self.flatten, self.items, self.type))

@classmethod
def fromDoc(
Expand All @@ -1401,6 +1407,24 @@ def fromDoc(
_doc.lc.data = doc.lc.data
_doc.lc.filename = doc.lc.filename
_errors__ = []
if "flatten" in _doc:
try:
flatten = load_field(
_doc.get("flatten"),
uri_union_of_None_type_or_booltype_False_True_2,
baseuri,
loadingOptions,
)
except ValidationException as e:
_errors__.append(
ValidationException(
"the `flatten` field is not valid because:",
SourceLine(_doc, "flatten", str),
[e],
)
)
else:
flatten = None
try:
items = load_field(
_doc.get("items"),
Expand Down Expand Up @@ -1442,7 +1466,7 @@ def fromDoc(
else:
_errors__.append(
ValidationException(
"invalid field `{}`, expected one of: `items`, `type`".format(
"invalid field `{}`, expected one of: `flatten`, `items`, `type`".format(
k
),
SourceLine(_doc, k, str),
Expand All @@ -1453,6 +1477,7 @@ def fromDoc(
if _errors__:
raise ValidationException("Trying 'ArraySchema'", None, _errors__)
_constructed = cls(
flatten=flatten,
items=items,
type=type,
extension_fields=extension_fields,
Expand All @@ -1471,6 +1496,9 @@ def save(
else:
for ef in self.extension_fields:
r[ef] = self.extension_fields[ef]
if self.flatten is not None:
u = save_relative_uri(self.flatten, base_url, False, 2, relative_uris)
r["flatten"] = u
if self.items is not None:
u = save_relative_uri(self.items, base_url, False, 2, relative_uris)
r["items"] = u
Expand All @@ -1487,7 +1515,7 @@ def save(
r["$schemas"] = self.loadingOptions.schemas
return r

attrs = frozenset(["items", "type"])
attrs = frozenset(["flatten", "items", "type"])


class MapSchema(Saveable):
Expand Down Expand Up @@ -4619,7 +4647,7 @@ def save(
SaladMapSchemaLoader = _RecordLoader(SaladMapSchema)
SaladUnionSchemaLoader = _RecordLoader(SaladUnionSchema)
DocumentationLoader = _RecordLoader(Documentation)
array_of_strtype = _ArrayLoader(strtype)
array_of_strtype = _ArrayLoader(strtype, True)
union_of_None_type_or_strtype_or_array_of_strtype = _UnionLoader(
(
None_type,
Expand All @@ -4640,7 +4668,8 @@ def save(
)
)
array_of_union_of_PrimitiveTypeLoader_or_RecordSchemaLoader_or_EnumSchemaLoader_or_ArraySchemaLoader_or_MapSchemaLoader_or_UnionSchemaLoader_or_strtype = _ArrayLoader(
union_of_PrimitiveTypeLoader_or_RecordSchemaLoader_or_EnumSchemaLoader_or_ArraySchemaLoader_or_MapSchemaLoader_or_UnionSchemaLoader_or_strtype
union_of_PrimitiveTypeLoader_or_RecordSchemaLoader_or_EnumSchemaLoader_or_ArraySchemaLoader_or_MapSchemaLoader_or_UnionSchemaLoader_or_strtype,
True,
)
union_of_PrimitiveTypeLoader_or_RecordSchemaLoader_or_EnumSchemaLoader_or_ArraySchemaLoader_or_MapSchemaLoader_or_UnionSchemaLoader_or_strtype_or_array_of_union_of_PrimitiveTypeLoader_or_RecordSchemaLoader_or_EnumSchemaLoader_or_ArraySchemaLoader_or_MapSchemaLoader_or_UnionSchemaLoader_or_strtype = _UnionLoader(
(
Expand All @@ -4658,7 +4687,7 @@ def save(
union_of_PrimitiveTypeLoader_or_RecordSchemaLoader_or_EnumSchemaLoader_or_ArraySchemaLoader_or_MapSchemaLoader_or_UnionSchemaLoader_or_strtype_or_array_of_union_of_PrimitiveTypeLoader_or_RecordSchemaLoader_or_EnumSchemaLoader_or_ArraySchemaLoader_or_MapSchemaLoader_or_UnionSchemaLoader_or_strtype,
2,
)
array_of_RecordFieldLoader = _ArrayLoader(RecordFieldLoader)
array_of_RecordFieldLoader = _ArrayLoader(RecordFieldLoader, True)
union_of_None_type_or_array_of_RecordFieldLoader = _UnionLoader(
(
None_type,
Expand All @@ -4682,6 +4711,15 @@ def save(
uri_array_of_strtype_True_False_None = _URILoader(array_of_strtype, True, False, None)
Enum_nameLoader = _EnumLoader(("enum",), "Enum_name")
typedsl_Enum_nameLoader_2 = _TypeDSLLoader(Enum_nameLoader, 2)
union_of_None_type_or_booltype = _UnionLoader(
(
None_type,
booltype,
)
)
uri_union_of_None_type_or_booltype_False_True_2 = _URILoader(
union_of_None_type_or_booltype, False, True, 2
)
uri_union_of_PrimitiveTypeLoader_or_RecordSchemaLoader_or_EnumSchemaLoader_or_ArraySchemaLoader_or_MapSchemaLoader_or_UnionSchemaLoader_or_strtype_or_array_of_union_of_PrimitiveTypeLoader_or_RecordSchemaLoader_or_EnumSchemaLoader_or_ArraySchemaLoader_or_MapSchemaLoader_or_UnionSchemaLoader_or_strtype_False_True_2 = _URILoader(
union_of_PrimitiveTypeLoader_or_RecordSchemaLoader_or_EnumSchemaLoader_or_ArraySchemaLoader_or_MapSchemaLoader_or_UnionSchemaLoader_or_strtype_or_array_of_union_of_PrimitiveTypeLoader_or_RecordSchemaLoader_or_EnumSchemaLoader_or_ArraySchemaLoader_or_MapSchemaLoader_or_UnionSchemaLoader_or_strtype,
False,
Expand All @@ -4694,12 +4732,6 @@ def save(
typedsl_Map_nameLoader_2 = _TypeDSLLoader(Map_nameLoader, 2)
Union_nameLoader = _EnumLoader(("union",), "Union_name")
typedsl_Union_nameLoader_2 = _TypeDSLLoader(Union_nameLoader, 2)
union_of_None_type_or_booltype = _UnionLoader(
(
None_type,
booltype,
)
)
union_of_None_type_or_inttype = _UnionLoader(
(
None_type,
Expand All @@ -4726,7 +4758,7 @@ def save(
Any_type,
)
)
array_of_SaladRecordFieldLoader = _ArrayLoader(SaladRecordFieldLoader)
array_of_SaladRecordFieldLoader = _ArrayLoader(SaladRecordFieldLoader, True)
union_of_None_type_or_array_of_SaladRecordFieldLoader = _UnionLoader(
(
None_type,
Expand All @@ -4739,7 +4771,7 @@ def save(
uri_union_of_None_type_or_strtype_or_array_of_strtype_False_False_1 = _URILoader(
union_of_None_type_or_strtype_or_array_of_strtype, False, False, 1
)
array_of_SpecializeDefLoader = _ArrayLoader(SpecializeDefLoader)
array_of_SpecializeDefLoader = _ArrayLoader(SpecializeDefLoader, True)
union_of_None_type_or_array_of_SpecializeDefLoader = _UnionLoader(
(
None_type,
Expand All @@ -4761,7 +4793,8 @@ def save(
)
)
array_of_union_of_SaladRecordSchemaLoader_or_SaladEnumSchemaLoader_or_SaladMapSchemaLoader_or_SaladUnionSchemaLoader_or_DocumentationLoader = _ArrayLoader(
union_of_SaladRecordSchemaLoader_or_SaladEnumSchemaLoader_or_SaladMapSchemaLoader_or_SaladUnionSchemaLoader_or_DocumentationLoader
union_of_SaladRecordSchemaLoader_or_SaladEnumSchemaLoader_or_SaladMapSchemaLoader_or_SaladUnionSchemaLoader_or_DocumentationLoader,
True,
)
union_of_SaladRecordSchemaLoader_or_SaladEnumSchemaLoader_or_SaladMapSchemaLoader_or_SaladUnionSchemaLoader_or_DocumentationLoader_or_array_of_union_of_SaladRecordSchemaLoader_or_SaladEnumSchemaLoader_or_SaladMapSchemaLoader_or_SaladUnionSchemaLoader_or_DocumentationLoader = _UnionLoader(
(
Expand Down
7 changes: 7 additions & 0 deletions schema_salad/metaschema/metaschema_base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,13 @@ $graph:
_type: "@vocab"
refScope: 2
doc: "Defines the type of the array elements."
flatten:
type: boolean?
jsonldPredicate:
_id: "sld:flatten"
_type: "@vocab"
refScope: 2
doc: "Flatten inner array objects into a single sequence (default: true)."


- name: MapSchema
Expand Down
7 changes: 6 additions & 1 deletion schema_salad/python_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,12 @@ def type_loader(self, type_declaration: Union[List[Any], Dict[str, Any], str]) -
"https://w3id.org/cwl/salad#array",
):
i = self.type_loader(type_declaration["items"])
return self.declare_type(TypeDef(f"array_of_{i.name}", f"_ArrayLoader({i.name})"))
return self.declare_type(
TypeDef(
f"array_of_{i.name}",
f"_ArrayLoader({i.name}, {type_declaration.get('flatten', True)})",
)
)
if type_declaration["type"] in (
"map",
"https://w3id.org/cwl/salad#map",
Expand Down
6 changes: 3 additions & 3 deletions schema_salad/python_codegen_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,9 @@ def __repr__(self): # type: () -> str


class _ArrayLoader(_Loader):
def __init__(self, items):
# type: (_Loader) -> None
def __init__(self, items: _Loader, flatten: bool = True) -> None:
self.items = items
self.flatten = flatten

def load(self, doc, baseuri, loadingOptions, docRoot=None):
# type: (Any, str, LoadingOptions, Optional[str]) -> Any
Expand All @@ -392,7 +392,7 @@ def load(self, doc, baseuri, loadingOptions, docRoot=None):
for i in range(0, len(doc)):
try:
lf = load_field(doc[i], _UnionLoader((self, self.items)), baseuri, loadingOptions)
if isinstance(lf, MutableSequence):
if self.flatten and isinstance(lf, MutableSequence):
r.extend(lf)
else:
r.append(lf)
Expand Down
1 change: 1 addition & 0 deletions schema_salad/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def get_metaschema() -> Tuple[Names, List[Dict[str, str]], Loader]:
"mapPredicate": "type",
"mapSubject": "name",
},
"flatten": saladp + "flatten",
"float": "http://www.w3.org/2001/XMLSchema#float",
"identity": saladp + "JsonldPredicate/identity",
"inVocab": saladp + "NamedType/inVocab",
Expand Down
23 changes: 18 additions & 5 deletions schema_salad/tests/cwl-pre.yml
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,6 @@
"name": "https://w3id.org/cwl/cwl#CWLObjectType",
"type": "union",
"names": [
"null",
"boolean",
"int",
"long",
Expand All @@ -583,12 +582,19 @@
"https://w3id.org/cwl/cwl#File",
"https://w3id.org/cwl/cwl#Directory",
{
"items": "https://w3id.org/cwl/cwl#CWLObjectType",
"flatten": false,
"items": [
"null",
"https://w3id.org/cwl/cwl#CWLObjectType"
],
"type": "array"
},
{
"type": "map",
"values": "https://w3id.org/cwl/cwl#CWLObjectType"
"values": [
"null",
"https://w3id.org/cwl/cwl#CWLObjectType"
]
}
],
"doc": "Generic type representing a valid CWL object. It is used to represent\n`default` values passed to CWL `InputParameter` and `WorkflowStepInput`\nrecord fields.\n"
Expand Down Expand Up @@ -997,7 +1003,10 @@
},
{
"name": "https://w3id.org/cwl/cwl#InputParameter/default",
"type": "https://w3id.org/cwl/cwl#CWLObjectType",
"type": [
"null",
"https://w3id.org/cwl/cwl#CWLObjectType"
],
"jsonldPredicate": "cwl:default",
"doc": "The default value for this parameter if not provided in the input\nobject.\n"
},
Expand Down Expand Up @@ -2263,7 +2272,10 @@
},
{
"name": "https://w3id.org/cwl/cwl#WorkflowStepInput/default",
"type": "https://w3id.org/cwl/cwl#CWLObjectType",
"type": [
"null",
"https://w3id.org/cwl/cwl#CWLObjectType"
],
"doc": "The default value for this parameter if there is no `source`\nfield.\n",
"jsonldPredicate": "cwl:default"
},
Expand Down Expand Up @@ -2548,6 +2560,7 @@
"name": "https://w3id.org/cwl/cwl#CWLInputFile",
"doc": "Type representing a valid CWL input file as a `map<string, union<array<ProcessRequirement>, CWLObjectType>>`.",
"values": [
"null",
{
"items": "https://w3id.org/cwl/cwl#ProcessRequirement",
"type": "array"
Expand Down
13 changes: 13 additions & 0 deletions schema_salad/tests/metaschema-pre.yml
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,19 @@
"name": "https://w3id.org/cwl/salad#ArraySchema",
"type": "record",
"fields": [
{
"type": [
"null",
"boolean"
],
"jsonldPredicate": {
"_id": "https://w3id.org/cwl/salad#flatten",
"_type": "@vocab",
"refScope": 2
},
"doc": "Flatten inner array objects into a single sequence (default: true).",
"name": "https://w3id.org/cwl/salad#ArraySchema/flatten"
},
{
"type": [
"PrimitiveType",
Expand Down
2 changes: 1 addition & 1 deletion schema_salad/tests/test_makedoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,5 +239,5 @@ def test_detect_changes_in_html(metaschema_doc: str, tmp_path: Path) -> None:
with open(result, "w") as h:
h.write(metaschema_doc)
assert (
hasher.hexdigest() == "108722da130cb85c6dc76e9925789f698e26cd42ab0056975b524449d8e469f7"
hasher.hexdigest() == "9f42a2951050100c81028d3e082899951ef914d942e6576c512c77530fe41fc4"
), result
Loading

0 comments on commit bf96061

Please sign in to comment.