1212 ToDtype ,
1313)
1414import logging
15+
1516logging .basicConfig (
1617 level = logging .INFO ,
1718 format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s" ,
1819)
1920logger = logging .getLogger (__name__ )
2021
22+
2123# Trained on COCOFake dataset
2224class BNext_M_ModelONNX :
2325 def __init__ (
@@ -29,14 +31,18 @@ def __init__(
2931 / "onnx_models"
3032 / "bnext_M_dffd_model.onnx"
3133 )
32- providers = [("CUDAExecutionProvider" , {"cudnn_conv_use_max_workspace" : '1' }),'CPUExecutionProvider' ]
34+ providers = [
35+ ("CUDAExecutionProvider" , {"cudnn_conv_use_max_workspace" : "1" }),
36+ "CPUExecutionProvider" ,
37+ ]
3338 sess_options = ort .SessionOptions ()
3439 sess_options .execution_mode = ort .ExecutionMode .ORT_PARALLEL
3540 # sess_options.log_severity_level = 0
3641 # sess_options.enable_profiling = True
3742 self .session = ort .InferenceSession (
3843 str (self .model_path ), # Convert Path object to string for onnxruntime
39- sess_options = sess_options , providers = providers
44+ sess_options = sess_options ,
45+ providers = providers ,
4046 )
4147 dev = ort .get_device ()
4248 logger .info ("BNext_M Model ONNX %s" , dev )
@@ -46,7 +52,6 @@ def __init__(
4652 logger .info ("ort available_providers %s" , available_providers )
4753 except Exception as e :
4854 logger .error (f"Error getting available providers: { e } " )
49-
5055
5156 self .resolution = resolution
5257 self .valid_extensions = (".jpg" , ".jpeg" , ".png" )
0 commit comments