|  | 
| 1 | 1 | #!/usr/bin/env python3 | 
| 2 | 2 | 
 | 
| 3 |  | -# Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
|  | 3 | +# Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
| 4 | 4 | # | 
| 5 | 5 | # Redistribution and use in source and binary forms, with or without | 
| 6 | 6 | # modification, are permitted provided that the following conditions | 
|  | 
| 26 | 26 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
| 27 | 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
| 28 | 28 | 
 | 
|  | 29 | +import base64 | 
|  | 30 | +import json | 
| 29 | 31 | import os | 
|  | 32 | +import subprocess | 
| 30 | 33 | import sys | 
| 31 | 34 | 
 | 
| 32 | 35 | sys.path.append("../../common") | 
| @@ -83,5 +86,241 @@ def test_model_reload(self): | 
| 83 | 86 |                 self.assertFalse(client.is_model_ready(ensemble_model_name)) | 
| 84 | 87 | 
 | 
| 85 | 88 | 
 | 
|  | 89 | +class ModelIDValidationTest(unittest.TestCase): | 
|  | 90 | +    """ | 
|  | 91 | +    Test model ID validation for user-provided model names. | 
|  | 92 | +
 | 
|  | 93 | +    Verifies that model names containing dangerous characters are properly rejected. | 
|  | 94 | +    Uses raw HTTP requests via curl instead of the Triton client to test server-side | 
|  | 95 | +    validation without the Triton client encoding special characters. | 
|  | 96 | +    """ | 
|  | 97 | + | 
|  | 98 | +    def setUp(self): | 
|  | 99 | +        self._shm_leak_detector = shm_util.ShmLeakDetector() | 
|  | 100 | +        self._client = httpclient.InferenceServerClient(f"{_tritonserver_ipaddr}:8000") | 
|  | 101 | +        self._triton_host = _tritonserver_ipaddr | 
|  | 102 | +        self._triton_port = 8000 | 
|  | 103 | + | 
|  | 104 | +        # Check if curl is available | 
|  | 105 | +        try: | 
|  | 106 | +            subprocess.run(["curl", "--version"], capture_output=True, check=True) | 
|  | 107 | +        except (subprocess.CalledProcessError, FileNotFoundError): | 
|  | 108 | +            self.skipTest("curl command not available - required for raw HTTP testing") | 
|  | 109 | + | 
|  | 110 | +    def _send_load_model_request(self, model_name): | 
|  | 111 | +        """Send HTTP request to load model for testing input validation using curl""" | 
|  | 112 | + | 
|  | 113 | +        # Create simple Triton Python model code | 
|  | 114 | +        python_model_code = f"""import triton_python_backend_utils as pb_utils | 
|  | 115 | +
 | 
|  | 116 | +class TritonPythonModel: | 
|  | 117 | +    def execute(self, requests): | 
|  | 118 | +        print('Hello world from model {model_name}') | 
|  | 119 | +        responses = [] | 
|  | 120 | +        for request in requests: | 
|  | 121 | +            # Simple identity function | 
|  | 122 | +            input_tensor = pb_utils.get_input_tensor_by_name(request, "INPUT0") | 
|  | 123 | +            out_tensor = pb_utils.Tensor("OUTPUT0", input_tensor.as_numpy()) | 
|  | 124 | +            responses.append(pb_utils.InferenceResponse([out_tensor])) | 
|  | 125 | +        return responses""" | 
|  | 126 | + | 
|  | 127 | +        # Base64 encode the Python code (as required by Triton server) | 
|  | 128 | +        python_code_b64 = base64.b64encode(python_model_code.encode("utf-8")).decode( | 
|  | 129 | +            "ascii" | 
|  | 130 | +        ) | 
|  | 131 | + | 
|  | 132 | +        # Create simple config | 
|  | 133 | +        config = { | 
|  | 134 | +            "name": model_name, | 
|  | 135 | +            "backend": "python", | 
|  | 136 | +            "max_batch_size": 4, | 
|  | 137 | +            "input": [{"name": "INPUT0", "data_type": "TYPE_FP32", "dims": [-1]}], | 
|  | 138 | +            "output": [{"name": "OUTPUT0", "data_type": "TYPE_FP32", "dims": [-1]}], | 
|  | 139 | +        } | 
|  | 140 | + | 
|  | 141 | +        payload = { | 
|  | 142 | +            "parameters": { | 
|  | 143 | +                "config": json.dumps(config), | 
|  | 144 | +                "file:/1/model.py": python_code_b64, | 
|  | 145 | +            } | 
|  | 146 | +        } | 
|  | 147 | + | 
|  | 148 | +        url = f"http://{self._triton_host}:{self._triton_port}/v2/repository/models/{model_name}/load" | 
|  | 149 | + | 
|  | 150 | +        # Convert payload to JSON string | 
|  | 151 | +        payload_json = json.dumps(payload) | 
|  | 152 | + | 
|  | 153 | +        try: | 
|  | 154 | +            # Use curl to send the request | 
|  | 155 | +            curl_cmd = [ | 
|  | 156 | +                "curl", | 
|  | 157 | +                "-s", | 
|  | 158 | +                "-w", | 
|  | 159 | +                "\n%{http_code}",  # Write HTTP status code on separate line | 
|  | 160 | +                "-X", | 
|  | 161 | +                "POST", | 
|  | 162 | +                "-H", | 
|  | 163 | +                "Content-Type: application/json", | 
|  | 164 | +                "-d", | 
|  | 165 | +                payload_json, | 
|  | 166 | +                url, | 
|  | 167 | +            ] | 
|  | 168 | + | 
|  | 169 | +            result = subprocess.run( | 
|  | 170 | +                curl_cmd, capture_output=True, text=True, timeout=10 | 
|  | 171 | +            ) | 
|  | 172 | + | 
|  | 173 | +            # Parse curl output - last line is status code, rest is response body | 
|  | 174 | +            output_lines = ( | 
|  | 175 | +                result.stdout.strip().split("\n") if result.stdout.strip() else [] | 
|  | 176 | +            ) | 
|  | 177 | +            if len(output_lines) >= 2: | 
|  | 178 | +                try: | 
|  | 179 | +                    status_code = int(output_lines[-1]) | 
|  | 180 | +                    response_text = "\n".join(output_lines[:-1]) | 
|  | 181 | +                except ValueError: | 
|  | 182 | +                    status_code = 0 | 
|  | 183 | +                    response_text = result.stdout or result.stderr or "Invalid response" | 
|  | 184 | +            elif len(output_lines) == 1 and output_lines[0].isdigit(): | 
|  | 185 | +                status_code = int(output_lines[0]) | 
|  | 186 | +                response_text = result.stderr or "No response body" | 
|  | 187 | +            else: | 
|  | 188 | +                status_code = 0 | 
|  | 189 | +                response_text = result.stdout or result.stderr or "No response" | 
|  | 190 | + | 
|  | 191 | +            # Return an object similar to requests.Response | 
|  | 192 | +            class CurlResponse: | 
|  | 193 | +                def __init__(self, status_code, text): | 
|  | 194 | +                    self.status_code = status_code | 
|  | 195 | +                    self.text = text | 
|  | 196 | +                    self.content = text.encode() | 
|  | 197 | + | 
|  | 198 | +            return CurlResponse(status_code, response_text) | 
|  | 199 | + | 
|  | 200 | +        except ( | 
|  | 201 | +            subprocess.TimeoutExpired, | 
|  | 202 | +            subprocess.CalledProcessError, | 
|  | 203 | +            ValueError, | 
|  | 204 | +        ) as e: | 
|  | 205 | +            # Return a mock response for errors | 
|  | 206 | +            class ErrorResponse: | 
|  | 207 | +                def __init__(self, error_msg): | 
|  | 208 | +                    self.status_code = 0 | 
|  | 209 | +                    self.text = f"Error: {error_msg}" | 
|  | 210 | +                    self.content = self.text.encode() | 
|  | 211 | + | 
|  | 212 | +            return ErrorResponse(str(e)) | 
|  | 213 | + | 
|  | 214 | +    def test_invalid_character_model_names(self): | 
|  | 215 | +        """Test that model names with invalid characters are properly rejected""" | 
|  | 216 | + | 
|  | 217 | +        # Based on INVALID_CHARS = ";|&$`<>()[]{}\\\"'*?~#!" | 
|  | 218 | +        invalid_model_names = [ | 
|  | 219 | +            r"model;test", | 
|  | 220 | +            r"model|test", | 
|  | 221 | +            r"model&test", | 
|  | 222 | +            r"model$test", | 
|  | 223 | +            r"model`test`", | 
|  | 224 | +            r"model<test>", | 
|  | 225 | +            r"model(test)", | 
|  | 226 | +            # r"model[test]", # request fails to send unencoded | 
|  | 227 | +            r"model{test}", | 
|  | 228 | +            r"model\test", | 
|  | 229 | +            r'model"test"', | 
|  | 230 | +            r"model'test'", | 
|  | 231 | +            r"model*test", | 
|  | 232 | +            # r"model?test", # request fails to send unencoded | 
|  | 233 | +            r"model~test", | 
|  | 234 | +            # r"model#test", # request fails to send unencoded | 
|  | 235 | +            r"model!test", | 
|  | 236 | +        ] | 
|  | 237 | + | 
|  | 238 | +        for invalid_name in invalid_model_names: | 
|  | 239 | +            with self.subTest(model_name=invalid_name): | 
|  | 240 | +                print(f"Testing invalid model name: {invalid_name}") | 
|  | 241 | + | 
|  | 242 | +                response = self._send_load_model_request(invalid_name) | 
|  | 243 | +                print( | 
|  | 244 | +                    f"Response for '{invalid_name}': Status {response.status_code}, Text: {response.text[:200]}..." | 
|  | 245 | +                ) | 
|  | 246 | + | 
|  | 247 | +                # Should not get a successful 200 response | 
|  | 248 | +                self.assertNotEqual( | 
|  | 249 | +                    200, | 
|  | 250 | +                    response.status_code, | 
|  | 251 | +                    f"Invalid model name '{invalid_name}' should not get 200 OK response", | 
|  | 252 | +                ) | 
|  | 253 | + | 
|  | 254 | +                # Special case for curly braces - they get stripped and cause load failures prior to the validation check | 
|  | 255 | +                if "{" in invalid_name or "}" in invalid_name: | 
|  | 256 | +                    self.assertIn( | 
|  | 257 | +                        "failed to load", | 
|  | 258 | +                        response.text, | 
|  | 259 | +                        f"Model with curly braces '{invalid_name}' should fail to load", | 
|  | 260 | +                    ) | 
|  | 261 | +                else: | 
|  | 262 | +                    # Normal case - should get character validation error | 
|  | 263 | +                    self.assertIn( | 
|  | 264 | +                        "Invalid stub name: contains invalid characters", | 
|  | 265 | +                        response.text, | 
|  | 266 | +                        f"invalid response for '{invalid_name}' should contain 'Invalid stub name: contains invalid characters'", | 
|  | 267 | +                    ) | 
|  | 268 | + | 
|  | 269 | +                # Verify the model is not loaded/ready since it was rejected | 
|  | 270 | +                try: | 
|  | 271 | +                    self.assertFalse( | 
|  | 272 | +                        self._client.is_model_ready(invalid_name), | 
|  | 273 | +                        f"Model '{invalid_name}' should not be ready after failed load attempt", | 
|  | 274 | +                    ) | 
|  | 275 | +                except Exception as e: | 
|  | 276 | +                    # If checking model readiness fails, that's also acceptable since the model name is invalid | 
|  | 277 | +                    print( | 
|  | 278 | +                        f"Note: Could not check model readiness for '{invalid_name}': {e}" | 
|  | 279 | +                    ) | 
|  | 280 | + | 
|  | 281 | +    def test_valid_model_names(self): | 
|  | 282 | +        """Test that valid model names work""" | 
|  | 283 | + | 
|  | 284 | +        valid_model_names = [ | 
|  | 285 | +            "TestModel123", | 
|  | 286 | +            "model-with-hyphens", | 
|  | 287 | +            "model_with_underscores", | 
|  | 288 | +        ] | 
|  | 289 | + | 
|  | 290 | +        for valid_name in valid_model_names: | 
|  | 291 | +            with self.subTest(model_name=valid_name): | 
|  | 292 | +                print(f"Testing valid model name: {valid_name}") | 
|  | 293 | + | 
|  | 294 | +                response = self._send_load_model_request(valid_name) | 
|  | 295 | +                print( | 
|  | 296 | +                    f"Response for valid '{valid_name}': Status {response.status_code}, Text: {response.text[:100]}..." | 
|  | 297 | +                ) | 
|  | 298 | + | 
|  | 299 | +                # Valid model names should be accepted and load successfully | 
|  | 300 | +                self.assertEqual( | 
|  | 301 | +                    200, | 
|  | 302 | +                    response.status_code, | 
|  | 303 | +                    f"Valid model name '{valid_name}' should get 200 OK response, got {response.status_code}. Response: {response.text}", | 
|  | 304 | +                ) | 
|  | 305 | + | 
|  | 306 | +                # Should not contain validation error message | 
|  | 307 | +                self.assertNotIn( | 
|  | 308 | +                    "Invalid stub name: contains invalid characters", | 
|  | 309 | +                    response.text, | 
|  | 310 | +                    f"Valid model name '{valid_name}' should not contain validation error message", | 
|  | 311 | +                ) | 
|  | 312 | + | 
|  | 313 | +                # Verify the model is actually loaded by checking if it's ready | 
|  | 314 | +                try: | 
|  | 315 | +                    self.assertTrue( | 
|  | 316 | +                        self._client.is_model_ready(valid_name), | 
|  | 317 | +                        f"Model '{valid_name}' should be ready after successful load", | 
|  | 318 | +                    ) | 
|  | 319 | +                    # Clean up - unload the model after testing | 
|  | 320 | +                    self._client.unload_model(valid_name) | 
|  | 321 | +                except Exception as e: | 
|  | 322 | +                    self.fail(f"Failed to check if model '{valid_name}' is ready: {e}") | 
|  | 323 | + | 
|  | 324 | + | 
| 86 | 325 | if __name__ == "__main__": | 
| 87 | 326 |     unittest.main() | 
0 commit comments