-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
109 lines (95 loc) · 3.95 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Importing dependencies
import streamlit as st
import requests
from io import BytesIO
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from transformers import GPT2TokenizerFast
from PIL import Image
import os
import json
import argparse
from model import SceneScript
# Set page title, icon and layout
st.set_page_config(page_title="Scene Script", page_icon=":memo:", layout="wide")
# Page title and description
st.title("Scene Script")
st.markdown("This is a demo of the Scene Script model. The model generates a caption for a given image.")

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model configuration: config/<model_param>.json must contain an 'encoder'
# and a 'decoder' section; both are popped out and used separately below.
model_param = '97M'
# Plain f-string (the former trailing `.format(...)` was redundant and would
# break if the path ever contained braces); JSON is UTF-8 by specification.
with open(f'config/{model_param}.json', encoding='utf-8') as f:
    params = json.load(f)
# Encoder parameters (read here for 'image_size')
encoder_params = params.pop('encoder')
# Decoder parameters (generation reads 'max_seq_len' from these)
decoder_params = params.pop('decoder')

# Preprocessing: resize to a square of the encoder's configured image size,
# then convert the PIL image to a tensor.
transform = transforms.Compose([
    transforms.Resize((encoder_params['image_size'], encoder_params['image_size'])),
    transforms.ToTensor(),
])

# GPT-2 tokenizer: reuse the EOS id for padding and register a dedicated BOS
# token, which generate_caption uses to seed decoding.
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.add_special_tokens({'bos_token': '<|startoftext|>'})
# Define function for generating caption
def generate_caption(model, image_path, tokenizer):
    """Generate a caption for one image using greedy decoding.

    Args:
        model: SceneScript instance exposing ``encoder`` and ``decoder``
            sub-modules.
        image_path: A path or binary file-like object accepted by
            ``PIL.Image.open``.
        tokenizer: Tokenizer providing ``bos_token_id`` (seeds decoding),
            ``eos_token_id`` (ends decoding) and ``decode``.

    Returns:
        The decoded caption string with special tokens removed.

    Decoding stops at the first EOS token or after
    ``decoder_params['max_seq_len']`` steps, whichever comes first.
    """
    # Force 3-channel RGB so grayscale / RGBA / palette images all go through
    # the same preprocessing (presumably what the encoder was trained on).
    image = Image.open(image_path).convert('RGB')
    # unsqueeze(0) adds a batch dimension of size 1.
    image = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        encoded = model.encoder(image, return_embeddings=True)
        caption_tokens = [tokenizer.bos_token_id]
        for _ in range(decoder_params['max_seq_len']):
            caption_tensor = torch.tensor(caption_tokens).unsqueeze(0).to(device)
            output = model.decoder(caption_tensor, context=encoded)
            # Greedy choice: highest-scoring token at the last position.
            last_token_logits = output[0, -1, :]
            next_token_id = torch.argmax(last_token_logits).item()
            caption_tokens.append(next_token_id)
            # End-of-text reached: anything generated after it is not part of
            # the caption, so stop instead of running to max_seq_len.
            if next_token_id == tokenizer.eos_token_id:
                break
    # Decode caption
    return tokenizer.decode(caption_tokens, skip_special_tokens=True)
# Load the model.
# Streamlit re-runs this script from the top on widget interactions (see the
# Streamlit caching docs). st.cache_resource keeps one loaded model instead of
# rebuilding it and re-reading the checkpoint on every rerun. Streamlit versions
# without cache_resource fall back to an identity decorator (load every run).
_cache_resource = getattr(st, 'cache_resource', lambda func: func)


@_cache_resource
def _load_model():
    """Build SceneScript from the loaded config, load its weights, set eval mode."""
    loaded = SceneScript(encoder_params, decoder_params).to(device)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    loaded.load_state_dict(torch.load('weights/scene-script.pth', map_location=device))
    # Evaluation mode (affects dropout / batch-norm layers, if the model has any).
    loaded.eval()
    return loaded


model = _load_model()
# Create sidebar for file upload or image URL.
# Defines `option`, and either `uploaded_file` (upload mode) or `image_url`
# (URL mode); `response` is bound only if the URL request itself returned.
URL_FETCH_TIMEOUT = 10  # seconds; upper bound for connecting to / reading from the image host

st.sidebar.title("Upload Image")
option = st.sidebar.selectbox('Select Input Option', ['Upload Image', 'Input URL'])
if option == 'Upload Image':
    uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.sidebar.image(image, caption='Uploaded Image.', use_column_width=True)
else:
    image_url = st.sidebar.text_input("Enter Image URL")
    if image_url != "":
        try:
            # NOTE(review): the server fetches an arbitrary user-supplied URL
            # (SSRF / unbounded download if deployed publicly); consider
            # restricting schemes/hosts and response size.
            # A timeout keeps an unresponsive host from hanging the app forever.
            response = requests.get(image_url, timeout=URL_FETCH_TIMEOUT)
            # Reject HTTP error statuses instead of decoding an error page as an image.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
            st.sidebar.image(image, caption='Input Image.', use_column_width=True)
        except Exception:
            # Best-effort UI boundary: network, HTTP and image-decoding failures
            # all become a warning; BaseException (e.g. KeyboardInterrupt) still
            # propagates, unlike the previous bare `except:`.
            st.sidebar.warning("Invalid URL entered. Please try again.")
# Generate caption when 'Generate Caption' button is clicked.
# Relies on `option`, `uploaded_file` / `image_url` and (URL mode) `response`
# from the sidebar section; `response` is unbound if the URL fetch raised.
if st.button("Generate Caption"):
    if option == 'Upload Image':
        if uploaded_file is None:
            st.warning("Please upload an image first.")
        else:
            # The spinner must wrap the inference call itself; previously it
            # only wrapped st.success, so it was never visible while generating.
            with st.spinner('Generating caption...'):
                caption = generate_caption(model, uploaded_file, tokenizer)
            st.success('\nCaption: ' + caption)
    else:
        if image_url == "":
            st.warning("Please enter an image URL first.")
        else:
            try:
                with st.spinner('Generating caption...'):
                    caption = generate_caption(model, BytesIO(response.content), tokenizer)
            except Exception:
                # Covers an unbound `response` (NameError after a failed fetch)
                # and undecodable content; unlike a bare `except:`, BaseException
                # subclasses such as KeyboardInterrupt/SystemExit propagate.
                st.warning("Invalid URL entered. Please try again.")
            else:
                st.success("Generated Caption: " + caption)