model.py
"""
The model is based on the ViT and Transformer architecture.
The x-transformer library is a python wrapper by lucidrains (Phil Wang)
(https://github.com/lucidrains/x-transformers)
"""
# Importing libraries
import torch
import torch.nn as nn
from x_transformers import ViTransformerWrapper, TransformerWrapper, Encoder, Decoder


# Define the model architecture
class SceneScript(nn.Module):
    def __init__(self, encoder_params, decoder_params):
        super().__init__()

        # Build the encoder attention layers from their parameters
        encoder_attn_layers_params = encoder_params.pop('attn_layers')
        encoder_params['attn_layers'] = Encoder(**encoder_attn_layers_params)

        # Build the decoder attention layers from their parameters
        decoder_attn_layers_params = decoder_params.pop('attn_layers')
        decoder_params['attn_layers'] = Decoder(**decoder_attn_layers_params)

        # Define the image encoder
        self.encoder = ViTransformerWrapper(**encoder_params)

        # Define the caption decoder
        self.decoder = TransformerWrapper(**decoder_params)

    def forward(self, img, caption):
        # Encode the image into patch embeddings
        encoded = self.encoder(img, return_embeddings=True)

        # Decode the caption tokens, cross-attending to the image embeddings
        output = self.decoder(caption, context=encoded)
        return output
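
A minimal usage sketch, not part of the original file: the parameter values below are assumptions chosen to illustrate the x-transformers API, where the decoder's attention layers need cross_attend=True so that the context passed in forward is actually attended to, and the encoder and decoder dims must match.

# Usage sketch (assumed parameter values, not from the original repo)
encoder_params = {
    'image_size': 256,
    'patch_size': 32,
    'attn_layers': {'dim': 512, 'depth': 6, 'heads': 8},
}
decoder_params = {
    'num_tokens': 20000,
    'max_seq_len': 1024,
    'attn_layers': {'dim': 512, 'depth': 6, 'heads': 8, 'cross_attend': True},
}

model = SceneScript(encoder_params, decoder_params)

img = torch.randn(1, 3, 256, 256)            # batch of 1 RGB image
caption = torch.randint(0, 20000, (1, 128))  # batch of 1 token sequence

logits = model(img, caption)                 # shape: (1, 128, 20000)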