3535import torch
3636from io import BytesIO
3737from PIL import UnidentifiedImageError
38+ import aiohttp
3839
3940
4041async def handle_recraft_file_request (
@@ -82,10 +83,16 @@ async def handle_recraft_file_request(
8283 return all_bytesio
8384
8485
85- def recraft_multipart_parser (data , parent_key = None , formatter : callable = None , converted_to_check : list [list ]= None , is_list = False ) -> dict :
86+ def recraft_multipart_parser (
87+ data ,
88+ parent_key = None ,
89+ formatter : callable = None ,
90+ converted_to_check : list [list ] = None ,
91+ is_list : bool = False ,
92+ return_mode : str = "formdata" # "dict" | "formdata"
93+ ) -> dict | aiohttp .FormData :
8694 """
87- Formats data such that multipart/form-data will work with requests library
88- when both files and data are present.
95+ Formats data such that multipart/form-data will work with aiohttp library when both files and data are present.
8996
9097 The OpenAI client that Recraft uses has a bizarre way of serializing lists:
9198
@@ -103,19 +110,19 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co
103110 # Modification of a function that handled a different type of multipart parsing, big ups:
104111 # https://gist.github.com/kazqvaizer/4cebebe5db654a414132809f9f88067b
105112
106- def handle_converted_lists (data , parent_key , lists_to_check = tuple [list ]):
113+ def handle_converted_lists (item , parent_key , lists_to_check = tuple [list ]):
107114 # if list already exists exists, just extend list with data
108115 for check_list in lists_to_check :
109116 for conv_tuple in check_list :
110117 if conv_tuple [0 ] == parent_key and isinstance (conv_tuple [1 ], list ):
111- conv_tuple [1 ].append (formatter (data ))
118+ conv_tuple [1 ].append (formatter (item ))
112119 return True
113120 return False
114121
115122 if converted_to_check is None :
116123 converted_to_check = []
117124
118-
125+ effective_mode = return_mode if parent_key is None else "dict"
119126 if formatter is None :
120127 formatter = lambda v : v # Multipart representation of value
121128
@@ -145,6 +152,15 @@ def handle_converted_lists(data, parent_key, lists_to_check=tuple[list]):
145152 else :
146153 converted .append ((current_key , formatter (value )))
147154
155+ if effective_mode == "formdata" :
156+ fd = aiohttp .FormData ()
157+ for k , v in dict (converted ).items ():
158+ if isinstance (v , list ):
159+ for item in v :
160+ fd .add_field (k , str (item ))
161+ else :
162+ fd .add_field (k , str (v ))
163+ return fd
148164 return dict (converted )
149165
150166
0 commit comments