Skip to content

Commit 64738fa

Browse files
committed
add img2img
1 parent ecf5bc7 commit 64738fa

File tree

2 files changed

+74
-27
lines changed

2 files changed

+74
-27
lines changed

astronaut.png

991 KB
Loading

sdxl-sagemaker.ipynb

+74-27
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
},
4848
{
4949
"cell_type": "code",
50-
"execution_count": 1,
50+
"execution_count": 6,
5151
"metadata": {},
5252
"outputs": [],
5353
"source": [
@@ -60,18 +60,9 @@
6060
},
6161
{
6262
"cell_type": "code",
63-
"execution_count": 2,
64-
"metadata": {},
65-
"outputs": [
66-
{
67-
"name": "stdout",
68-
"output_type": "stream",
69-
"text": [
70-
"sagemaker.config INFO - Not applying SDK defaults from location: /opt/homebrew/share/sagemaker/config.yaml\n",
71-
"sagemaker.config INFO - Not applying SDK defaults from location: /Users/itay/Library/Application Support/sagemaker/config.yaml\n"
72-
]
73-
}
74-
],
63+
"execution_count": 7,
64+
"metadata": {},
65+
"outputs": [],
7566
"source": [
7667
"import json\n",
7768
"from sagemaker import ModelPackage\n",
@@ -83,16 +74,14 @@
8374
},
8475
{
8576
"cell_type": "code",
86-
"execution_count": 77,
77+
"execution_count": 37,
8778
"metadata": {},
8879
"outputs": [],
8980
"source": [
9081
"# TODO delete\n",
9182
"import os\n",
92-
"os.environ['AWS_PROFILE'] = 'marketplace-eng'\n",
93-
"# os.environ['AWS_PROFILE'] = 'octoml-sandbox-admin'\n",
94-
"#role = \"arn:aws:iam::006055839469:role/MarketplaceSageMakerValidationRole\" \n",
95-
"#role = \"arn:aws:iam::186900524924:role/SagemakerAdmin\"\n",
83+
"os.environ['AWS_PROFILE'] = 'octoml-sandbox-admin'\n",
84+
"role = \"arn:aws:iam::186900524924:role/SagemakerAdmin\"\n",
9685
"\n",
9786
"boto_session = boto3.Session()\n",
9887
"region = boto_session.region_name\n",
@@ -104,7 +93,7 @@
10493
},
10594
{
10695
"cell_type": "code",
107-
"execution_count": 87,
96+
"execution_count": 38,
10897
"metadata": {},
10998
"outputs": [],
11099
"source": [
@@ -131,7 +120,7 @@
131120
},
132121
{
133122
"cell_type": "code",
134-
"execution_count": 88,
123+
"execution_count": 39,
135124
"metadata": {},
136125
"outputs": [],
137126
"source": [
@@ -198,7 +187,7 @@
198187
},
199188
{
200189
"cell_type": "code",
201-
"execution_count": 73,
190+
"execution_count": 45,
202191
"metadata": {},
203192
"outputs": [],
204193
"source": [
@@ -231,7 +220,7 @@
231220
},
232221
{
233222
"cell_type": "code",
234-
"execution_count": 84,
223+
"execution_count": 46,
235224
"metadata": {},
236225
"outputs": [],
237226
"source": [
@@ -280,7 +269,7 @@
280269
},
281270
{
282271
"cell_type": "code",
283-
"execution_count": null,
272+
"execution_count": 43,
284273
"metadata": {},
285274
"outputs": [],
286275
"source": [
@@ -321,7 +310,7 @@
321310
},
322311
{
323312
"cell_type": "code",
324-
"execution_count": 5,
313+
"execution_count": null,
325314
"metadata": {},
326315
"outputs": [],
327316
"source": [
@@ -346,13 +335,64 @@
346335
"response = runtime_sm_client.invoke_endpoint(\n",
347336
" EndpointName=model_name,\n",
348337
" ContentType=content_type,\n",
349-
" Body=json.dumps(payload),\n",
338+
" Body=json.dumps(lightning_payload),\n",
350339
")\n",
351340
"\n",
352341
"output = json.loads(response[\"Body\"].read().decode(\"utf8\"))\n",
353342
"display_output(output)"
354343
]
355344
},
345+
{
346+
"cell_type": "markdown",
347+
"metadata": {},
348+
"source": [
349+
"Similar to generating images from text, we can also generate images from other images, utilizing the same performance and functionality (this also works with SDXL Lightning)"
350+
]
351+
},
352+
{
353+
"cell_type": "code",
354+
"execution_count": null,
355+
"metadata": {},
356+
"outputs": [],
357+
"source": [
358+
"import os\n",
359+
"from base64 import b64encode, b64decode\n",
360+
"\n",
361+
"# Read the image and encode it as base64\n",
362+
"init_image = b64encode(open('./astronaut.png', 'rb').read()).decode(\"utf-8\")\n",
363+
"\n",
364+
"img2img_payload = {\n",
365+
" \"prompt\": \"breathtaking, american woman, award winning photography, best quality, 8K HDR\",\n",
366+
" \"negative_prompt\": \"worst quality, low quality, bad quality, lazy eye\",\n",
367+
" \"width\": 1344,\n",
368+
" \"height\": 768,\n",
369+
" \"num_images\": 1,\n",
370+
" \"sampler\": \"DDIM\",\n",
371+
" \"steps\": 30,\n",
372+
" \"cfg_scale\": 12,\n",
373+
" \"use_refiner\": False,\n",
374+
" \"style_preset\": \"neon-punk\",\n",
375+
" \"strength\": 0.8,\n",
376+
" \"init_image\": init_image,\n",
377+
"\n",
378+
" # We use a specific seed to get a specific image out, but you can\n",
379+
" # change this or omit it\n",
380+
" \"seed\": 2701628909,\n",
381+
"}\n",
382+
"\n",
383+
"response = runtime_sm_client.invoke_endpoint(\n",
384+
" EndpointName=model_name,\n",
385+
" ContentType=content_type,\n",
386+
" Body=json.dumps(img2img_payload),\n",
387+
")\n",
388+
"\n",
389+
"output = json.loads(response[\"Body\"].read().decode(\"utf8\"))\n",
390+
"display_output(output)\n",
391+
"\n",
392+
"# display original image too for comparison\n",
393+
"display.Image(b64decode(init_image))"
394+
]
395+
},
356396
{
357397
"cell_type": "markdown",
358398
"metadata": {},
@@ -369,7 +409,7 @@
369409
},
370410
{
371411
"cell_type": "code",
372-
"execution_count": 90,
412+
"execution_count": 35,
373413
"metadata": {},
374414
"outputs": [],
375415
"source": [
@@ -393,12 +433,19 @@
393433
},
394434
{
395435
"cell_type": "code",
396-
"execution_count": null,
436+
"execution_count": 36,
397437
"metadata": {},
398438
"outputs": [],
399439
"source": [
400440
"model.delete_model()"
401441
]
442+
},
443+
{
444+
"cell_type": "code",
445+
"execution_count": null,
446+
"metadata": {},
447+
"outputs": [],
448+
"source": []
402449
}
403450
],
404451
"metadata": {

0 commit comments

Comments
 (0)