Skip to content

Commit 11ddbfd

Browse files
authored
Update: Download the ONNX model if not found main.py
1 parent c837549 commit 11ddbfd

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

main.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,38 @@ def _ensure_opset15(original_path: str) -> str:
125125
onnx.save(converted, conv_path)
126126
return conv_path
127127

128+
def download_model(url, save_path):
129+
"""Download the ONNX model from the provided URL and save it to the specified path."""
130+
import urllib.request
131+
132+
print(f"Downloading model from {url}...")
133+
try:
134+
# Create the directory if it doesn't exist
135+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
136+
137+
# Download the file
138+
urllib.request.urlretrieve(url, save_path)
139+
print(f"Model downloaded successfully to {save_path}")
140+
return True
141+
except Exception as e:
142+
print(f"Error downloading model: {str(e)}")
143+
return False
144+
128145
class NudeDetector:
129146
def __init__(self, providers=None):
147+
# Set up model paths
148+
model_dir = os.path.join(os.path.dirname(__file__), "Models")
149+
model_orig = os.path.join(model_dir, "best.onnx")
150+
151+
# Check if model exists, if not download it
152+
if not os.path.exists(model_orig):
153+
print("Model file not found. Creating Models directory and downloading model...")
154+
model_url = "https://github.com/im-syn/SafeVision/raw/refs/heads/main/Models/best.onnx"
155+
success = download_model(model_url, model_orig)
156+
if not success:
157+
raise FileNotFoundError(f"Could not download model from {model_url}. Please download manually and place in {model_dir}")
158+
130159
# Convert best.onnx → best_opset15.onnx at runtime
131-
model_orig = os.path.join(os.path.dirname(__file__), "Models/best.onnx")
132160
model_to_load = _ensure_opset15(model_orig)
133161

134162
self.onnx_session = onnxruntime.InferenceSession(

0 commit comments

Comments
 (0)