-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclient.py
More file actions
36 lines (32 loc) · 1 KB
/
client.py
File metadata and controls
36 lines (32 loc) · 1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import argparse
import onnxruntime as ort
import json
import numpy as np
from config import *
from transformers import BertTokenizer
# Command-line client: classify one piece of text with the exported ONNX
# review classifier and write the result to a JSON file.


def parse_arguments(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the required ``input_text`` attribute.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_text", type=str, required=True)
    return parser.parse_args(argv)


def encode_text(tokenizer, text, max_len):
    """Tokenize *text* into int64 ``input_ids`` and ``attention_mask`` arrays.

    The text is padded and truncated to exactly *max_len* tokens.

    Returns:
        A ``(input_ids, attention_mask)`` tuple of ``np.int64`` arrays.
    """
    encoding = tokenizer(
        text,
        return_tensors="np",
        max_length=max_len,
        padding="max_length",
        truncation=True,
    )
    # Cast explicitly to int64. The exported graph presumably declares int64
    # inputs (TODO confirm against the model), and the platform's default
    # integer dtype is not guaranteed to be int64.
    input_ids = np.asarray(encoding["input_ids"], dtype=np.int64)
    attention_mask = np.asarray(encoding["attention_mask"], dtype=np.int64)
    return input_ids, attention_mask


def label_for_logits(logits):
    """Map one example's per-class scores to a label string.

    The highest-scoring class index 1 maps to ``"safe"``; any other index
    maps to ``"not-safe"``.
    """
    return "safe" if int(np.argmax(logits)) == 1 else "not-safe"


def main(
    argv=None,
    *,
    model_path="review_classifier_model.onnx",
    output_path="output.json",
):
    """Run the classifier on ``--input_text`` and write the result as JSON.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
        model_path: Path of the ONNX model to load.
        output_path: Where to write ``{"text": ..., "prediction": ...}``.
    """
    args = parse_arguments(argv)
    text = args.input_text

    # bert_model_name and max_length come from the star-imported config module.
    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    session = ort.InferenceSession(
        model_path,
        providers=["CPUExecutionProvider"],
    )

    input_ids, attention_mask = encode_text(tokenizer, text, max_length)
    outputs = session.run(
        None, {"input": input_ids, "attention_mask": attention_mask}
    )
    # outputs[0] is the first graph output. The trailing [0] selects the single
    # example in the batch (presumably shaped (batch, num_classes); confirm).
    result = {"text": text, "prediction": label_for_logits(outputs[0][0])}

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result, f)


if __name__ == "__main__":
    main()