
Commit 98d33e2

added code for deploying model via torchserve
1 parent d373597 commit 98d33e2

11 files changed: +620, -15 lines

.idea/Python Code Generation.iml

+1-1

.idea/csv-plugin.xml

+14

.idea/misc.xml

+1-1
+22
@@ -0,0 +1,22 @@
Requirements (install sketch below):

1) Java JDK 11
2) torch, torchserve, torch-model-archiver
3) torchtext to load the stoi and itos vocabs
4) spacy

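A minimal install sketch for the Python-side requirements, assuming the standard PyPI package names (Java JDK 11 has to be installed separately):

pip install torch torchserve torch-model-archiver torchtext spacy
python -m spacy download en_core_web_sm
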
torch-model-archiver --model-name py_code_generator --version 1.0 --serialized-file data/model_saved_by_jit.pt --handler custom_handler_for_deployment.py --extra-files data/SRC_stio_local,data/TRG_itos_local --model-file model.py

mv -f py_code_generator.mar data/

# start / stop the server
torchserve --start --ncs --model-store data/ --models py_code_generator=py_code_generator.mar

torchserve --stop

# health check
curl http://localhost:8080/ping

# sample prediction request
curl -X POST http://localhost:8080/predictions/py_code_generator -H "Content-Type: text/plain" -d "write a Python function to convert binary to Gray codeword"
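
The same prediction request from Python, as a minimal client sketch (assumes TorchServe is running on the default inference port 8080 and that the requests package is installed):

import requests

prompt = "write a Python function to convert binary to Gray codeword"
resp = requests.post(
    "http://localhost:8080/predictions/py_code_generator",
    data=prompt.encode("utf-8"),              # raw text body; the handler falls back to data[0]["body"]
    headers={"Content-Type": "text/plain"},
    timeout=30,
)
resp.raise_for_status()
print(resp.text)                              # JSON-encoded string containing the generated code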

custom_handler_for_deployment.py

+171
@@ -0,0 +1,171 @@
import json

from ts.torch_handler.base_handler import BaseHandler
from model import Seq2Seq
import spacy
import torch
import pickle
import re
import os
import logging

spacy_en = spacy.load('en_core_web_sm')
logger = logging.getLogger(__name__)

'''
One can use a simple module-level entry point as mentioned in
https://pytorch.org/serve/custom_service.html
but we use a class entry point because we have a lot to do
in preprocess and postprocess.
'''


class ModelHandler(BaseHandler):

    def __init__(self):
        self._context = None
        self.initialized = False
        self.explain = False
        self.target = 0

    def initialize(self, context):
        # this function is called whenever a worker is started or scaled up
        self.manifest = context.manifest

        source_file = self.manifest['model']['modelFile']
        properties = context.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")

        # Read the serialized TorchScript model (.pt) file
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt file")

        self.model = torch.jit.load(model_pt_path, map_location=torch.device('cpu'))
        self.model.to(self.device)

        self.initialized = True

        # source vocab (token -> index) and target vocab (index -> token)
        with open("SRC_stio_local", "rb") as f:
            self.stoi = pickle.load(f)
        with open("TRG_itos_local", "rb") as f:
            self.itos = pickle.load(f)

        self.trg_stoi = {j: i for i, j in enumerate(self.itos)}

        self.answer_max_len = 100

        self.src_pad_idx = self.stoi['<pad>']
        self.trg_pad_idx = self.trg_stoi['<pad>']

    def handle(self, data, context):
        # this function is called for every inference request
        # Refer to https://github.com/pytorch/serve/blob/master/examples/Huggingface_Transformers/Transformer_handler_generalized.py
        # for handling multiple requests

        # TODO: make it work for a batch of requests

        input_text = data[0].get("data")
        if input_text is None:
            input_text = data[0].get("body")
        if isinstance(input_text, (bytes, bytearray)):
            input_text = input_text.decode('utf-8')

        src = self.tokenize(input_text, self.stoi)

        # An earlier version re-ran the full Seq2Seq forward pass on every decoding
        # step (re-computing the encoder attentions each time); here the encoder is
        # run once and only the decoder is called inside the loop.
        src_mask = self.make_src_mask(src)

        enc_src = self.model.encoder.forward(src, src_mask)

        trg = '<sos>'
        trg_indexes = [self.stoi[trg]]

        # greedy decoding: feed the tokens generated so far, take the most likely next token
        decoder_outputs = []
        for i in range(self.answer_max_len):
            trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(self.device)
            trg_mask = self.make_trg_mask(trg_tensor)

            decoder_output, encoder_decoder_attention = self.model.decoder.forward(trg_tensor, enc_src, trg_mask, src_mask)

            pred_token = decoder_output.argmax(2)[:, -1].item()

            if pred_token == self.trg_stoi['<eos>']:
                break
            decoder_outputs.append(self.itos[pred_token])
            trg_indexes.append(pred_token)

        return self.prune_outputs(decoder_outputs)

    def tokenize(self, input, vocab):
        tokenized_input_ = [tok.text.lower() for tok in spacy_en.tokenizer(input)]
        tokenized_input = ['<sos>'] + tokenized_input_ + ['<eos>']

        numericalized_input = [vocab[i] for i in tokenized_input]

        tensor_input = torch.LongTensor([numericalized_input])

        return tensor_input.to(self.device)

    def prune_outputs(self, decoder_outputs):

        def variables_names_in_print(matchobj):
            statement = matchobj.group(1)
            statement = statement.replace(" ", "")
            return "{" + statement + "}"

        # removing the redundant empty tokens the tokenizer produces for indentation
        decoder_outputs = [i for i in decoder_outputs if i != '']

        combined_output = " ".join(decoder_outputs)
        # dropping the stray spaces that " ".join adds after newline tokens
        pruned_output = re.sub(r'\n |\n |\n ', r'\n', combined_output)
        # removing spaces inside f-string placeholders, e.g. print(f'{ var }') -> print(f'{var}')
        pruned_output = re.sub(r'{(.*?)}', variables_names_in_print, pruned_output)

        return [json.dumps(pruned_output)]

    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

        return src_mask

    def make_trg_mask(self, trg):
        # trg : [batch_size, trg_len]

        # Masking pad values
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        # trg_pad_mask : [batch_size, 1, 1, trg_len]

        # Masking future values
        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=self.device)).bool()
        # trg_sub_mask : [trg_len, trg_len]

        # combine both masks
        trg_mask = trg_pad_mask & trg_sub_mask
        # trg_mask : [batch_size, 1, trg_len, trg_len]

        return trg_mask
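
data/model_saved_by_jit.pt is the TorchScript artifact that initialize() loads with torch.jit.load. The commit does not include the export step; below is a minimal, self-contained sketch of the torch.jit round-trip using a stand-in module (TinyModel is hypothetical, not the real Seq2Seq from model.py; the actual export would script or trace the trained Seq2Seq instead):

import torch
import torch.nn as nn

# stand-in module just to illustrate the script/save/load round-trip
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

model = TinyModel().eval()
scripted = torch.jit.script(model)            # torch.jit.trace(model, example_input) is an alternative
scripted.save("model_saved_by_jit.pt")        # the real artifact lives at data/model_saved_by_jit.pt

# this mirrors what the handler's initialize() does with the archived file
restored = torch.jit.load("model_saved_by_jit.pt", map_location="cpu")
print(restored(torch.randn(1, 8)).shape)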

data/SRC_stio_local

109 KB
Binary file not shown.

data/TRG_itos_local

80 KB
Binary file not shown.

data/model_saved_by_jit.pt

36.6 MB
Binary file not shown.

data/py_code_generator.mar

33.7 MB
Binary file not shown.
