diff --git a/flask_graphql/graphqlview.py b/flask_graphql/graphqlview.py index ff257b3..f846b8e 100644 --- a/flask_graphql/graphqlview.py +++ b/flask_graphql/graphqlview.py @@ -9,6 +9,7 @@ load_json_body, run_http_query) from .render_graphiql import render_graphiql +from .utils import place_files_in_operations class GraphQLView(View): @@ -135,9 +136,17 @@ def parse_body(self): elif content_type == 'application/json': return load_json_body(request.data.decode('utf8')) - elif content_type in ('application/x-www-form-urlencoded', 'multipart/form-data'): + elif content_type == 'application/x-www-form-urlencoded': return request.form + elif content_type == 'multipart/form-data': + operations = load_json_body(request.form['operations']) + files_map = load_json_body(request.form['map']) + return place_files_in_operations( + operations, + files_map, + request.files + ) return {} def should_display_graphiql(self): diff --git a/flask_graphql/utils.py b/flask_graphql/utils.py new file mode 100644 index 0000000..f26ec02 --- /dev/null +++ b/flask_graphql/utils.py @@ -0,0 +1,42 @@ +def place_files_in_operations(operations, files_map, files): + path_to_key_iter = ( + (value.split('.'), key) + for key, values in files_map.items() + for value in values + ) + # Since add_files_to_operations returns a new dict/list, first define + # output to be operations itself + output = operations + for path, key in path_to_key_iter: + file_obj = files[key] + output = add_file_to_operations(output, file_obj, path) + return output + + +def add_file_to_operations(operations, file_obj, path): + if not path: + return file_obj + if isinstance(operations, dict): + key = path[0] + sub_dict = add_file_to_operations(operations[key], file_obj, path[1:]) + return new_merged_dict(operations, {key: sub_dict}) + if isinstance(operations, list): + index = int(path[0]) + sub_item = add_file_to_operations(operations[index], file_obj, path[1:]) + return new_list_with_replaced_item(operations, index, sub_item) + return TypeError('Operations must be a JSON data structure') + + +def new_merged_dict(*dicts): + # Necessary for python2 support + output = {} + for d in dicts: + output.update(d) + return output + + +def new_list_with_replaced_item(input_list, index, new_value): + # Necessary for python2 support + output = [i for i in input_list] + output[index] = new_value + return output diff --git a/tests/schema.py b/tests/schema.py index f841672..7abde18 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -1,12 +1,34 @@ -from graphql.type.definition import GraphQLArgument, GraphQLField, GraphQLNonNull, GraphQLObjectType -from graphql.type.scalars import GraphQLString +from graphql.type.definition import GraphQLArgument, GraphQLField, GraphQLNonNull, GraphQLObjectType, GraphQLList +from graphql.type.scalars import GraphQLString, GraphQLScalarType from graphql.type.schema import GraphQLSchema +def resolve_test_file(obj, info, what): + output = what.readline().decode('utf-8') + what.seek(0) + return output + + +def resolve_test_files(obj, info, whats): + output = ''.join(what.readline().decode('utf-8') for what in whats) + for what in whats: + what.seek(0) + return output + + def resolve_raises(*_): raise Exception("Throws!") +# This scalar should be added to graphql-core at some point +GraphQLUpload = GraphQLScalarType( + name="Upload", + description="The `Upload` scalar type represents an uploaded file", + serialize=lambda x: None, + parse_value=lambda x: x, + parse_literal=lambda x: x, +) + QueryRootType = GraphQLObjectType( name='QueryRoot', fields={ @@ -21,6 +43,20 @@ def resolve_raises(*_): 'who': GraphQLArgument(GraphQLString) }, resolver=lambda obj, info, who='World': 'Hello %s' % who + ), + 'testFile': GraphQLField( + type=GraphQLString, + args={ + 'what': GraphQLArgument(GraphQLNonNull(GraphQLUpload)), + }, + resolver=resolve_test_file, + ), + 'testMultiFile': GraphQLField( + type=GraphQLString, + args={ + 'whats': GraphQLArgument(GraphQLNonNull(GraphQLList(GraphQLUpload))), + }, + resolver=resolve_test_files, ) } ) diff --git a/tests/test_graphqlview.py b/tests/test_graphqlview.py index 77626d4..9277898 100644 --- a/tests/test_graphqlview.py +++ b/tests/test_graphqlview.py @@ -1,5 +1,6 @@ import pytest import json +from tempfile import NamedTemporaryFile try: from StringIO import StringIO @@ -465,18 +466,63 @@ def test_supports_pretty_printing(client): def test_post_multipart_data(client): - query = 'mutation TestMutation { writeTest { test } }' - response = client.post( - url_string(), - data= { - 'query': query, - 'file': (StringIO(), 'text1.txt'), - }, - content_type='multipart/form-data' - ) - + query = 'mutation TestMutation($file: Upload!) { writeTest { testFile( what: $file ) } }' + with NamedTemporaryFile() as t_file: + t_file.write(b'Fake Data\nLine2\n') + t_file.seek(0) + response = client.post( + url_string(), + data={ + 'operations': j(query=query, variables={'file': None}), + 't_file': t_file, + 'map': j(t_file=["variables.file"]), + }, + content_type='multipart/form-data' + ) assert response.status_code == 200 - assert response_json(response) == {'data': {u'writeTest': {u'test': u'Hello World'}}} + assert response_json(response) == {'data': {u'writeTest': {u'testFile': u'Fake Data\n'}}} + + +@pytest.mark.parametrize('app', [create_app(batch=True)]) +def test_post_multipart_data_multi(client): + query1 = ''' + mutation TestMutation($file: Upload!) { + writeTest { testFile( what: $file ) } + }''' + query2 = ''' + mutation TestMutation($files: [Upload]!) { + writeTest { testMultiFile( whats: $files ) } + }''' + with NamedTemporaryFile() as tf1, NamedTemporaryFile() as tf2: + tf1.write(b'tf1\nNot This line!!\n') + tf1.seek(0) + tf2.write(b'tf2\nNot This line!!\n') + tf2.seek(0) + response = client.post( + url_string(), + data={ + 'operations': json.dumps([ + {'query': query1, 'variables': {'file': None}}, + {'query': query2, 'variables': {'files': [None, None]}}, + ]), + 'tf1': tf1, + 'tf2': tf2, + 'map': j( + tf1=['0.variables.file', '1.variables.files.0'], + tf2=['1.variables.files.1'], + ), + }, + content_type='multipart/form-data' + ) + assert response.status_code == 200 + assert response_json(response) == [ + {'data': { + u'writeTest': {u'testFile': u'tf1\n'} + }}, + {'data': { + u'writeTest': {u'testMultiFile': u'tf1\ntf2\n'} + }}, + ] @pytest.mark.parametrize('app', [create_app(batch=True)]) @@ -514,8 +560,8 @@ def test_batch_supports_post_json_query_with_json_variables(client): # 'id': 1, 'data': {'test': "Hello Dolly"} }] - - + + @pytest.mark.parametrize('app', [create_app(batch=True)]) def test_batch_allows_post_with_operation_name(client): response = client.post(