diff --git a/graphql/__init__.py b/graphql/__init__.py index 2365383f..035a5435 100644 --- a/graphql/__init__.py +++ b/graphql/__init__.py @@ -89,6 +89,7 @@ # Parse and operate on GraphQL language source files. from .language.base import ( # no import order Source, + FileSource, get_location, # Parse parse, @@ -223,6 +224,7 @@ "BREAK", "ParallelVisitor", "Source", + "FileSource", "TypeInfoVisitor", "get_location", "parse", diff --git a/graphql/language/ast.py b/graphql/language/ast.py index f7f407ea..5109eb59 100644 --- a/graphql/language/ast.py +++ b/graphql/language/ast.py @@ -1079,6 +1079,7 @@ def __repr__(self): "name={self.name!r}" ", arguments={self.arguments!r}" ", type={self.type!r}" + ", directives={self.directives!r}" ")" ).format(self=self) diff --git a/graphql/language/base.py b/graphql/language/base.py index fca28dc6..6c0bac60 100644 --- a/graphql/language/base.py +++ b/graphql/language/base.py @@ -2,7 +2,7 @@ from .location import get_location from .parser import parse, parse_value from .printer import print_ast -from .source import Source +from .source import FileSource, Source from .visitor import BREAK, ParallelVisitor, TypeInfoVisitor, visit __all__ = [ @@ -12,6 +12,7 @@ "parse_value", "print_ast", "Source", + "FileSource", "BREAK", "ParallelVisitor", "TypeInfoVisitor", diff --git a/graphql/language/source.py b/graphql/language/source.py index 0f737774..45a2f452 100644 --- a/graphql/language/source.py +++ b/graphql/language/source.py @@ -1,5 +1,6 @@ -__all__ = ["Source"] +import os +__all__ = ["Source", "FileSource"] class Source(object): __slots__ = "body", "name" @@ -15,3 +16,45 @@ def __eq__(self, other): and self.body == other.body and self.name == other.name ) + +class FileSource(Source): + __slots__ = "body", "name" + + def __init__(self, *args, **kwargs): + """Create a Source using the specified GraphQL files' contents.""" + name = kwargs.get("name", "GraphQL") + + # From the specified list of paths, first identify all files. Then, load + # their contents into a single, newline delimited string. + file_contents = [] + file_paths = self.__get_file_paths__(args) + for fp in file_paths: + with open(fp) as f: + file_contents.append(f.read()) + body = '\n'.join(file_contents) + + super(FileSource, self).__init__(body, name) + + def __get_file_paths__(self, paths): + """Get the paths to all files in the given list of paths. This means + filtering out invalid paths and recursively walking a given directory + path to gather the paths of all files that it contains.""" + all_file_paths = [] + + # Filter out invalid paths. + valid_paths = [p for p in paths if os.path.exists(p)] + + # Add all paths pointing to a file to all_file_paths. + all_file_paths += [p for p in valid_paths if os.path.isfile(p)] + + # For each path referring to a directory, walk that directory's structure + # recursively, and add its constituent files' paths to all_file_paths. + all_file_paths += [ + os.path.join(dir_name, file_name) + for p in valid_paths + if os.path.isdir(p) + for dir_name, _, files_in_dir in os.walk(p) + for file_name in files_in_dir + ] + + return all_file_paths diff --git a/graphql/language/tests/graphql_schemas/models/Person.graphql b/graphql/language/tests/graphql_schemas/models/Person.graphql new file mode 100644 index 00000000..254a8a84 --- /dev/null +++ b/graphql/language/tests/graphql_schemas/models/Person.graphql @@ -0,0 +1,4 @@ +type Person { + name: ID! + age: Int +} \ No newline at end of file diff --git a/graphql/language/tests/graphql_schemas/models/Skill.graphql b/graphql/language/tests/graphql_schemas/models/Skill.graphql new file mode 100644 index 00000000..4f7b1299 --- /dev/null +++ b/graphql/language/tests/graphql_schemas/models/Skill.graphql @@ -0,0 +1,9 @@ +type Skill { + name: ID! + level: Int + possessors: [Person!] +} + +extend type Person { + skills: [Skill!] +} diff --git a/graphql/language/tests/graphql_schemas/schema.graphql b/graphql/language/tests/graphql_schemas/schema.graphql new file mode 100644 index 00000000..5b44ad4b --- /dev/null +++ b/graphql/language/tests/graphql_schemas/schema.graphql @@ -0,0 +1,8 @@ +type Query { + person(name: ID!): Person + skill(name: ID!): Skill +} + +schema { + query: Query +} diff --git a/graphql/language/tests/test_schema_parser.py b/graphql/language/tests/test_schema_parser.py index 6ec58d4d..4b4c744d 100644 --- a/graphql/language/tests/test_schema_parser.py +++ b/graphql/language/tests/test_schema_parser.py @@ -1,9 +1,12 @@ -from pytest import raises - -from graphql import Source, parse +from graphql import FileSource, Source, parse from graphql.error import GraphQLSyntaxError from graphql.language import ast from graphql.language.parser import Loc + +import os + +from pytest import raises + from typing import Callable @@ -567,6 +570,133 @@ def test_parses_simple_input_object(): assert doc == expected +def test_parses_schema_files(): + test_graphql_schemas_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "graphql_schemas") + doc = parse(FileSource(test_graphql_schemas_dir)) + expected = ast.Document( + definitions=[ + ast.ObjectTypeDefinition( + name=ast.Name(value="Query"), + interfaces=[], + fields=[ + ast.FieldDefinition( + name=ast.Name(value="person"), + arguments=[ + ast.InputValueDefinition( + name=ast.Name(value="name"), + type=ast.NonNullType( + type=ast.NamedType(name=ast.Name(value="ID")) + ), + default_value=None, + directives=[] + ) + ], + type=ast.NamedType(name=ast.Name(value="Person")), + directives=[] + ), + ast.FieldDefinition( + name=ast.Name(value="skill"), + arguments=[ + ast.InputValueDefinition( + name=ast.Name(value="name"), + type=ast.NonNullType( + type=ast.NamedType(name=ast.Name(value="ID")) + ), + default_value=None, + directives=[] + ) + ], + type=ast.NamedType(name=ast.Name(value="Skill")), + directives=[] + ) + ], + directives=[] + ), + ast.SchemaDefinition( + operation_types=[ + ast.OperationTypeDefinition( + operation="query", + type=ast.NamedType(name=ast.Name(value="Query")) + ) + ], + directives=[] + ), + ast.ObjectTypeDefinition( + name=ast.Name(value="Person"), + interfaces=[], + fields=[ + ast.FieldDefinition( + name=ast.Name(value="name"), + arguments=[], + type=ast.NonNullType( + type=ast.NamedType(name=ast.Name(value="ID")) + ), + directives=[] + ), + ast.FieldDefinition( + name=ast.Name(value="age"), + arguments=[], + type=ast.NamedType(name=ast.Name(value="Int")), + directives=[] + ) + ], + directives=[] + ), + ast.ObjectTypeDefinition( + name=ast.Name(value="Skill"), + interfaces=[], + fields=[ + ast.FieldDefinition( + name=ast.Name(value="name"), + arguments=[], + type=ast.NonNullType( + type=ast.NamedType(name=ast.Name(value="ID")) + ), + directives=[] + ), + ast.FieldDefinition( + name=ast.Name(value="level"), + arguments=[], + type=ast.NamedType(name=ast.Name(value="Int")), + directives=[] + ), + ast.FieldDefinition( + name=ast.Name(value="possessors"), + arguments=[], + type=ast.ListType( + type=ast.NonNullType( + type=ast.NamedType(name=ast.Name(value="Person")) + ) + ), + directives=[] + ) + ], + directives=[] + ), + ast.TypeExtensionDefinition( + definition=ast.ObjectTypeDefinition( + name=ast.Name(value="Person"), + interfaces=[], + fields=[ + ast.FieldDefinition( + name=ast.Name(value="skills"), + arguments=[], + type=ast.ListType( + type=ast.NonNullType( + type=ast.NamedType(name=ast.Name(value="Skill")) + ) + ), + directives=[] + ) + ], + directives=[] + ) + ) + ] + ) + assert doc == expected + + def test_parsing_simple_input_object_with_args_should_fail(): # type: () -> None body = """