Skip to content

Commit 08daf4f

Browse files
committed
feat(stream_events): stream send() calls to the client too
1 parent 46ea977 commit 08daf4f

File tree

4 files changed

+51
-19
lines changed

4 files changed

+51
-19
lines changed

api/app.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def truncateInputs(inputs: dict):
129129

130130
# Inference is ran for every server call
131131
# Reference your preloaded global model variable here.
132-
def inference(all_inputs: dict) -> dict:
132+
async def inference(all_inputs: dict, response) -> dict:
133133
global model
134134
global pipelines
135135
global last_model_id
@@ -151,6 +151,8 @@ def inference(all_inputs: dict) -> dict:
151151
send_opts.update({"SEND_URL": call_inputs.get("SEND_URL")})
152152
if call_inputs.get("SIGN_KEY", None):
153153
send_opts.update({"SIGN_KEY": call_inputs.get("SIGN_KEY")})
154+
if response:
155+
send_opts.update({"response": response})
154156

155157
if model_inputs == None or call_inputs == None:
156158
return {
@@ -356,7 +358,7 @@ def inference(all_inputs: dict) -> dict:
356358
)
357359
)
358360

359-
send("inference", "start", {"startRequestId": startRequestId}, send_opts)
361+
await send("inference", "start", {"startRequestId": startRequestId}, send_opts)
360362

361363
# Run patchmatch for inpainting
362364
if call_inputs.get("FILL_MODE", None) == "patchmatch":
@@ -417,7 +419,7 @@ def inference(all_inputs: dict) -> dict:
417419
send_opts=send_opts,
418420
)
419421
torch.set_grad_enabled(False)
420-
send("inference", "done", {"startRequestId": startRequestId}, send_opts)
422+
await send("inference", "done", {"startRequestId": startRequestId}, send_opts)
421423
result.update({"$timings": getTimings()})
422424
return result
423425

@@ -435,8 +437,8 @@ def inference(all_inputs: dict) -> dict:
435437
callback = None
436438
if model_inputs.get("callback_steps", None):
437439

438-
def callback(step: int, timestep: int, latents: torch.FloatTensor):
439-
send(
440+
async def callback(step: int, timestep: int, latents: torch.FloatTensor):
441+
await send(
440442
"inference",
441443
"progress",
442444
{"startRequestId": startRequestId, "step": step},
@@ -473,7 +475,7 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
473475
image.save(buffered, format="PNG")
474476
images_base64.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))
475477

476-
send("inference", "done", {"startRequestId": startRequestId}, send_opts)
478+
await send("inference", "done", {"startRequestId": startRequestId}, send_opts)
477479

478480
# Return the results as a dictionary
479481
if len(images_base64) > 1:

api/send.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def getTimings():
7070
return timings
7171

7272

73-
def send(type: str, status: str, payload: dict = {}, opts: dict = {}):
73+
async def send(type: str, status: str, payload: dict = {}, opts: dict = {}):
7474
now = get_now()
7575
send_url = opts.get("SEND_URL", SEND_URL)
7676
sign_key = opts.get("SIGN_KEY", SIGN_KEY)
@@ -102,6 +102,11 @@ def send(type: str, status: str, payload: dict = {}, opts: dict = {}):
102102
if send_url:
103103
futureSession.post(send_url, json=data)
104104

105+
response = opts.get("response")
106+
if response:
107+
print("streaming above")
108+
await response.send(json.dumps(data) + "\n")
109+
105110
# try:
106111
# requests.post(send_url, json=data) # , timeout=0.0000000001)
107112
# except requests.exceptions.ReadTimeout:

api/server.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import app as user_src
1010
import traceback
1111
import os
12+
import json
1213

1314
# We do the model load-to-GPU step on server startup
1415
# so the model object is available globally for reuse
@@ -34,14 +35,21 @@ def healthcheck(request):
3435

3536
# Inference POST handler at '/' is called for every http call from Banana
3637
@server.route("/", methods=["POST"])
37-
def inference(request):
38+
async def inference(request):
3839
try:
39-
model_inputs = response.json.loads(request.json)
40+
all_inputs = response.json.loads(request.json)
4041
except:
41-
model_inputs = request.json
42+
all_inputs = request.json
43+
44+
call_inputs = all_inputs.get("callInputs", None)
45+
stream_events = call_inputs and call_inputs.get("streamEvents", 0) != 0
46+
47+
streaming_response = None
48+
if stream_events:
49+
streaming_response = await request.respond(content_type="application/x-ndjson")
4250

4351
try:
44-
output = user_src.inference(model_inputs)
52+
output = await user_src.inference(all_inputs, streaming_response)
4553
except Exception as err:
4654
output = {
4755
"$error": {
@@ -52,7 +60,10 @@ def inference(request):
5260
}
5361
}
5462

55-
return response.json(output)
63+
if stream_events:
64+
await streaming_response.send(json.dumps(output) + "\n")
65+
else:
66+
return response.json(output)
5667

5768

5869
if __name__ == "__main__":

test.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ def runTest(name, args, extraCallInputs, extraModelInputs):
114114
"modelInputs": inputs,
115115
"startOnly": False,
116116
}
117+
117118
response = requests.post(f"{BANANA_API_URL}/start/v4/", json=payload)
119+
118120
result = response.json()
119121
callID = result.get("callID")
120122

@@ -185,13 +187,25 @@ def runTest(name, args, extraCallInputs, extraModelInputs):
185187

186188
else:
187189
test_url = args.get("test_url", None) or TEST_URL
188-
response = requests.post(test_url, json=inputs)
189-
try:
190-
result = response.json()
191-
except requests.exceptions.JSONDecodeError as error:
192-
print(error)
193-
print(response.text)
194-
sys.exit(1)
190+
call_inputs = inputs["callInputs"]
191+
stream_events = call_inputs and call_inputs.get("streamEvents", 0) != 0
192+
print({"stream_events": stream_events})
193+
if stream_events:
194+
result = None
195+
response = requests.post(test_url, json=inputs, stream=True)
196+
for line in response.iter_lines():
197+
if line:
198+
result = json.loads(line)
199+
if not result.get("$timings", None):
200+
print(result)
201+
else:
202+
response = requests.post(test_url, json=inputs)
203+
try:
204+
result = response.json()
205+
except requests.exceptions.JSONDecodeError as error:
206+
print(error)
207+
print(response.text)
208+
sys.exit(1)
195209

196210
finish = time.time() - start
197211
timings = result.get("$timings")

0 commit comments

Comments (0)