Skip to content

Commit 478441b

Browse files
authored
Merge branch 'master' into fw-and-version-bug
2 parents f498b3f + d089d40 commit 478441b

File tree

3 files changed

+132
-30
lines changed

3 files changed

+132
-30
lines changed

sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py

Lines changed: 76 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,50 @@ def from_dependency_file_path(dependency_file_path):
9494
class RuntimeEnvironmentManager:
9595
"""Runtime Environment Manager class to manage runtime environment."""
9696

97+
def _validate_path(self, path: str) -> str:
98+
"""Validate and sanitize file path to prevent path traversal attacks.
99+
100+
Args:
101+
path (str): The file path to validate
102+
103+
Returns:
104+
str: The validated absolute path
105+
106+
Raises:
107+
ValueError: If the path is invalid or contains suspicious patterns
108+
"""
109+
if not path:
110+
raise ValueError("Path cannot be empty")
111+
112+
# Get absolute path to prevent path traversal
113+
abs_path = os.path.abspath(path)
114+
115+
# Check for null bytes (common in path traversal attacks)
116+
if '\x00' in path:
117+
raise ValueError(f"Invalid path contains null byte: {path}")
118+
119+
return abs_path
120+
121+
def _validate_env_name(self, env_name: str) -> None:
122+
"""Validate conda environment name to prevent command injection.
123+
124+
Args:
125+
env_name (str): The environment name to validate
126+
127+
Raises:
128+
ValueError: If the environment name contains invalid characters
129+
"""
130+
if not env_name:
131+
raise ValueError("Environment name cannot be empty")
132+
133+
# Allow only alphanumeric, underscore, and hyphen
134+
import re
135+
if not re.match(r'^[a-zA-Z0-9_-]+$', env_name):
136+
raise ValueError(
137+
f"Invalid environment name '{env_name}'. "
138+
"Only alphanumeric characters, underscores, and hyphens are allowed."
139+
)
140+
97141
def snapshot(self, dependencies: str = None) -> str:
98142
"""Creates snapshot of the user's environment
99143
@@ -252,39 +296,50 @@ def _is_file_exists(self, dependencies):
252296

253297
def _install_requirements_txt(self, local_path, python_executable):
254298
"""Install requirements.txt file"""
255-
cmd = f"{python_executable} -m pip install -r {local_path} -U"
256-
logger.info("Running command: '%s' in the dir: '%s' ", cmd, os.getcwd())
299+
# Validate path to prevent command injection
300+
validated_path = self._validate_path(local_path)
301+
cmd = [python_executable, "-m", "pip", "install", "-r", validated_path, "-U"]
302+
logger.info("Running command: '%s' in the dir: '%s' ", " ".join(cmd), os.getcwd())
257303
_run_shell_cmd(cmd)
258-
logger.info("Command %s ran successfully", cmd)
304+
logger.info("Command %s ran successfully", " ".join(cmd))
259305

260306
def _create_conda_env(self, env_name, local_path):
261307
"""Create conda env using conda yml file"""
308+
# Validate inputs to prevent command injection
309+
self._validate_env_name(env_name)
310+
validated_path = self._validate_path(local_path)
262311

263-
cmd = f"{self._get_conda_exe()} env create -n {env_name} --file {local_path}"
264-
logger.info("Creating conda environment %s using: %s.", env_name, cmd)
312+
cmd = [self._get_conda_exe(), "env", "create", "-n", env_name, "--file", validated_path]
313+
logger.info("Creating conda environment %s using: %s.", env_name, " ".join(cmd))
265314
_run_shell_cmd(cmd)
266315
logger.info("Conda environment %s created successfully.", env_name)
267316

268317
def _install_req_txt_in_conda_env(self, env_name, local_path):
269318
"""Install requirements.txt in the given conda environment"""
319+
# Validate inputs to prevent command injection
320+
self._validate_env_name(env_name)
321+
validated_path = self._validate_path(local_path)
270322

271-
cmd = f"{self._get_conda_exe()} run -n {env_name} pip install -r {local_path} -U"
272-
logger.info("Activating conda env and installing requirements: %s", cmd)
323+
cmd = [self._get_conda_exe(), "run", "-n", env_name, "pip", "install", "-r", validated_path, "-U"]
324+
logger.info("Activating conda env and installing requirements: %s", " ".join(cmd))
273325
_run_shell_cmd(cmd)
274326
logger.info("Requirements installed successfully in conda env %s", env_name)
275327

276328
def _update_conda_env(self, env_name, local_path):
277329
"""Update conda env using conda yml file"""
330+
# Validate inputs to prevent command injection
331+
self._validate_env_name(env_name)
332+
validated_path = self._validate_path(local_path)
278333

279-
cmd = f"{self._get_conda_exe()} env update -n {env_name} --file {local_path}"
280-
logger.info("Updating conda env: %s", cmd)
334+
cmd = [self._get_conda_exe(), "env", "update", "-n", env_name, "--file", validated_path]
335+
logger.info("Updating conda env: %s", " ".join(cmd))
281336
_run_shell_cmd(cmd)
282337
logger.info("Conda env %s updated succesfully", env_name)
283338

284339
def _export_conda_env_from_prefix(self, prefix, local_path):
285340
"""Export the conda env to a conda yml file"""
286341

287-
cmd = f"{self._get_conda_exe()} env export -p {prefix} --no-builds > {local_path}"
342+
cmd = [self._get_conda_exe(), "env", "export", "-p", prefix, "--no-builds", ">", local_path]
288343
logger.info("Exporting conda environment: %s", cmd)
289344
_run_shell_cmd(cmd)
290345
logger.info("Conda environment %s exported successfully", prefix)
@@ -402,19 +457,26 @@ def _run_pre_execution_command_script(script_path: str):
402457
return return_code, error_logs
403458

404459

405-
def _run_shell_cmd(cmd: str):
460+
def _run_shell_cmd(cmd: list):
406461
"""This method runs a given shell command using subprocess
407462
408-
Raises RuntimeEnvironmentError if the command fails
463+
Args:
464+
cmd (list): Command and arguments as a list (e.g., ['pip', 'install', '-r', 'requirements.txt'])
465+
466+
Raises:
467+
RuntimeEnvironmentError: If the command fails
468+
ValueError: If cmd is not a list
409469
"""
470+
if not isinstance(cmd, list):
471+
raise ValueError("Command must be a list of arguments for security reasons")
410472

411-
process = subprocess.Popen((cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
473+
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
412474

413475
_log_output(process)
414476
error_logs = _log_error(process)
415477
return_code = process.wait()
416478
if return_code:
417-
error_message = f"Encountered error while running command '{cmd}'. Reason: {error_logs}"
479+
error_message = f"Encountered error while running command '{' '.join(cmd)}'. Reason: {error_logs}"
418480
raise RuntimeEnvironmentError(error_message)
419481

420482

sagemaker-train/tests/unit/ai_registry/test_dataset_domain_id.py

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,58 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Unit tests for domain-id tagging in DataSet."""
14+
import json
15+
import tempfile
16+
import os
1417
import pytest
1518
from unittest.mock import Mock, patch, MagicMock
1619
from sagemaker.ai_registry.dataset import DataSet
1720
from sagemaker.ai_registry.dataset_utils import CustomizationTechnique
1821

1922

23+
# Sample RLVR format dataset (GSM8K style)
24+
SAMPLE_DATASET = {
25+
"data_source": "openai/gsm8k",
26+
"prompt": [{"content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let's think step by step and output the final answer after \"####\".", "role": "user"}],
27+
"ability": "math",
28+
"reward_model": {"ground_truth": "72", "style": "rule"},
29+
"extra_info": {"answer": "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72", "index": 0, "question": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?", "split": "train"}
30+
}
31+
32+
33+
@pytest.fixture
34+
def sample_dataset_file():
35+
"""Create a temporary JSONL file with sample dataset."""
36+
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
37+
json.dump(SAMPLE_DATASET, f)
38+
temp_path = f.name
39+
40+
yield temp_path
41+
42+
# Cleanup
43+
if os.path.exists(temp_path):
44+
os.remove(temp_path)
45+
46+
2047
class TestDataSetDomainId:
2148
"""Test domain-id is added to SearchKeywords when available."""
2249

2350
@patch('sagemaker.core.helper.session_helper.Session')
2451
@patch('sagemaker.ai_registry.dataset._get_current_domain_id')
2552
@patch('sagemaker.ai_registry.dataset.AIRHub')
26-
@patch('sagemaker.ai_registry.dataset.validate_dataset')
53+
@patch('sagemaker.train.defaults.TrainDefaults.get_sagemaker_session')
54+
@patch('sagemaker.train.defaults.TrainDefaults.get_role')
2755
def test_domain_id_added_when_available(
28-
self, mock_validate, mock_air_hub, mock_get_domain_id, mock_session
56+
self, mock_get_role, mock_get_session, mock_air_hub, mock_get_domain_id, mock_session, sample_dataset_file
2957
):
3058
"""Test that domain-id is added to tags when available."""
3159
# Setup mocks
3260
mock_domain_id = "d-test123456"
3361
mock_get_domain_id.return_value = mock_domain_id
34-
mock_session.return_value = Mock()
62+
mock_session_instance = Mock()
63+
mock_session.return_value = mock_session_instance
64+
mock_get_session.return_value = mock_session_instance
65+
mock_get_role.return_value = "arn:aws:iam::123456789012:role/test-role"
3566

3667
# Mock AIRHub methods
3768
mock_air_hub.upload_to_s3 = Mock()
@@ -46,11 +77,11 @@ def test_domain_id_added_when_available(
4677
'HubContentDocument': '{"DatasetS3Bucket": "bucket", "DatasetS3Prefix": "prefix"}'
4778
})
4879

49-
# Create dataset
80+
# Create dataset with real file
5081
with patch('sagemaker.ai_registry.dataset.DataSet.wait'):
5182
dataset = DataSet.create(
5283
name="test-dataset",
53-
source="test-data.jsonl",
84+
source=sample_dataset_file,
5485
customization_technique=CustomizationTechnique.SFT
5586
)
5687

@@ -67,14 +98,18 @@ def test_domain_id_added_when_available(
6798
@patch('sagemaker.core.helper.session_helper.Session')
6899
@patch('sagemaker.ai_registry.dataset._get_current_domain_id')
69100
@patch('sagemaker.ai_registry.dataset.AIRHub')
70-
@patch('sagemaker.ai_registry.dataset.validate_dataset')
101+
@patch('sagemaker.train.defaults.TrainDefaults.get_sagemaker_session')
102+
@patch('sagemaker.train.defaults.TrainDefaults.get_role')
71103
def test_domain_id_not_added_when_unavailable(
72-
self, mock_validate, mock_air_hub, mock_get_domain_id, mock_session
104+
self, mock_get_role, mock_get_session, mock_air_hub, mock_get_domain_id, mock_session, sample_dataset_file
73105
):
74106
"""Test that domain-id is not added when unavailable (non-Studio)."""
75107
# Setup mocks - domain_id returns None
76108
mock_get_domain_id.return_value = None
77-
mock_session.return_value = Mock()
109+
mock_session_instance = Mock()
110+
mock_session.return_value = mock_session_instance
111+
mock_get_session.return_value = mock_session_instance
112+
mock_get_role.return_value = "arn:aws:iam::123456789012:role/test-role"
78113

79114
# Mock AIRHub methods
80115
mock_air_hub.upload_to_s3 = Mock()
@@ -89,11 +124,11 @@ def test_domain_id_not_added_when_unavailable(
89124
'HubContentDocument': '{"DatasetS3Bucket": "bucket", "DatasetS3Prefix": "prefix"}'
90125
})
91126

92-
# Create dataset
127+
# Create dataset with real file
93128
with patch('sagemaker.ai_registry.dataset.DataSet.wait'):
94129
dataset = DataSet.create(
95130
name="test-dataset",
96-
source="test-data.jsonl",
131+
source=sample_dataset_file,
97132
customization_technique=CustomizationTechnique.SFT
98133
)
99134

@@ -110,14 +145,19 @@ def test_domain_id_not_added_when_unavailable(
110145
@patch('sagemaker.core.helper.session_helper.Session')
111146
@patch('sagemaker.ai_registry.dataset._get_current_domain_id')
112147
@patch('sagemaker.ai_registry.dataset.AIRHub')
148+
@patch('sagemaker.train.defaults.TrainDefaults.get_sagemaker_session')
149+
@patch('sagemaker.train.defaults.TrainDefaults.get_role')
113150
def test_domain_id_added_without_customization_technique(
114-
self, mock_air_hub, mock_get_domain_id, mock_session
151+
self, mock_get_role, mock_get_session, mock_air_hub, mock_get_domain_id, mock_session, sample_dataset_file
115152
):
116153
"""Test that domain-id is added even without customization_technique."""
117154
# Setup mocks
118155
mock_domain_id = "d-test789"
119156
mock_get_domain_id.return_value = mock_domain_id
120-
mock_session.return_value = Mock()
157+
mock_session_instance = Mock()
158+
mock_session.return_value = mock_session_instance
159+
mock_get_session.return_value = mock_session_instance
160+
mock_get_role.return_value = "arn:aws:iam::123456789012:role/test-role"
121161

122162
# Mock AIRHub methods
123163
mock_air_hub.upload_to_s3 = Mock()
@@ -132,11 +172,11 @@ def test_domain_id_added_without_customization_technique(
132172
'HubContentDocument': '{"DatasetS3Bucket": "bucket", "DatasetS3Prefix": "prefix"}'
133173
})
134174

135-
# Create dataset WITHOUT customization_technique
175+
# Create dataset WITHOUT customization_technique using real file
136176
with patch('sagemaker.ai_registry.dataset.DataSet.wait'):
137177
dataset = DataSet.create(
138178
name="test-dataset",
139-
source="test-data.jsonl"
179+
source=sample_dataset_file
140180
# No customization_technique
141181
)
142182

sagemaker-train/tests/unit/train/remote_function/test_runtime_environment_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ def test_runs_command_successfully(self, mock_popen, mock_log_output, mock_log_e
490490
mock_popen.return_value = mock_process
491491
mock_log_error.return_value = ""
492492

493-
_run_shell_cmd("echo test")
493+
_run_shell_cmd(["echo", "test"])
494494

495495
mock_popen.assert_called_once()
496496

@@ -505,7 +505,7 @@ def test_runs_command_raises_error_on_failure(self, mock_popen, mock_log_output,
505505
mock_log_error.return_value = "Error message"
506506

507507
with pytest.raises(RuntimeEnvironmentError):
508-
_run_shell_cmd("false")
508+
_run_shell_cmd(["false"])
509509

510510

511511
class TestLogOutput:

0 commit comments

Comments
 (0)