diff --git a/README.md b/README.md index e843e9b..aa2a313 100644 --- a/README.md +++ b/README.md @@ -104,6 +104,17 @@ The SDK automatically handles all dependency packaging for Data Cloud deployment **No need to worry about platform compatibility** - the SDK handles this automatically through the Docker-based packaging process. +## files directory + +``` +. +├── payload +│ ├── config.json +│ ├── entrypoint.py +├── files +│ ├── data.csv +``` + ## py-files directory Your Python dependencies can be packaged as .py files, .zip archives (containing multiple .py files or a Python package structure), or .egg files. @@ -124,6 +135,7 @@ Your Python dependencies can be packaged as .py files, .zip archives (containing Your entry point script will define logic using the `Client` object which wraps data access layers. You should only need the following methods: +* `find_file_path(file_name)` – Return the path to a bundled file (for example one placed in the `files` directory) by name * `read_dlo(name)` – Read from a Data Lake Object by name * `read_dmo(name)` – Read from a Data Model Object by name * `write_to_dlo(name, spark_dataframe, write_mode)` – Write to a Data Model Object by name with a Spark dataframe @@ -197,6 +209,7 @@ Argument: Options: - `--config-file TEXT`: Path to configuration file - `--dependencies TEXT`: Additional dependencies (can be specified multiple times) +- `--profile TEXT`: Credential profile name (default: "default") #### `datacustomcode zip` diff --git a/src/datacustomcode/client.py b/src/datacustomcode/client.py index 7a1f9a4..15d34a9 100644 --- a/src/datacustomcode/client.py +++ b/src/datacustomcode/client.py @@ -24,9 +24,12 @@ from pyspark.sql import SparkSession from datacustomcode.config import SparkConfig, config +from datacustomcode.file.path.default import DefaultFindFilePath from datacustomcode.io.reader.base import BaseDataCloudReader if TYPE_CHECKING: + from pathlib import Path + from pyspark.sql import DataFrame as PySparkDataFrame from datacustomcode.io.reader.base import BaseDataCloudReader @@ -100,11 
+103,13 @@ class Client: writing, we print to the console instead of writing to Data Cloud. Args: + finder: Find a file path reader: A custom reader to use for reading Data Cloud objects. writer: A custom writer to use for writing Data Cloud objects. Example: >>> client = Client() + >>> file_path = client.find_file_path("data.csv") >>> dlo = client.read_dlo("my_dlo") >>> client.write_to_dmo("my_dmo", dlo) """ @@ -112,6 +117,7 @@ class Client: _instance: ClassVar[Optional[Client]] = None _reader: BaseDataCloudReader _writer: BaseDataCloudWriter + _file: DefaultFindFilePath _data_layer_history: dict[DataCloudObjectType, set[str]] def __new__( @@ -154,6 +160,7 @@ def __new__( writer_init = writer cls._instance._reader = reader_init cls._instance._writer = writer_init + cls._instance._file = DefaultFindFilePath() cls._instance._data_layer_history = { DataCloudObjectType.DLO: set(), DataCloudObjectType.DMO: set(), @@ -212,6 +219,11 @@ def write_to_dmo( self._validate_data_layer_history_does_not_contain(DataCloudObjectType.DLO) return self._writer.write_to_dmo(name, dataframe, write_mode, **kwargs) + def find_file_path(self, file_name: str) -> Path: + """Return a file path""" + + return self._file.find_file_path(file_name) + def _validate_data_layer_history_does_not_contain( self, data_cloud_object_type: DataCloudObjectType ) -> None: diff --git a/src/datacustomcode/file/__init__.py b/src/datacustomcode/file/__init__.py new file mode 100644 index 0000000..93988ff --- /dev/null +++ b/src/datacustomcode/file/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ---------------------------------------------------------------------------
# src/datacustomcode/file/base.py (new file)
# ---------------------------------------------------------------------------
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations


class BaseDataAccessLayer:
    """Base class for data access layer implementations."""


# ---------------------------------------------------------------------------
# src/datacustomcode/file/path/__init__.py (new file, license header only)
# ---------------------------------------------------------------------------
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ---------------------------------------------------------------------------
# src/datacustomcode/file/path/default.py (new file)
# ---------------------------------------------------------------------------
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# In the package layout this module opens with its own
# ``from __future__ import annotations`` and obtains ``BaseDataAccessLayer``
# via ``from datacustomcode.file.base import BaseDataAccessLayer``; in this
# combined listing both are already provided above.
import builtins
import os
from pathlib import Path
from typing import Optional


class FileReaderError(Exception):
    """Base exception for file reader operations."""


class FileNotFoundError(FileReaderError, builtins.FileNotFoundError):
    """Raised when a file cannot be found.

    The class keeps its historical name, which shadows the builtin for this
    module and for anyone importing it from here.  It therefore also derives
    from the builtin ``FileNotFoundError`` so that handlers written against
    the builtins (``except FileNotFoundError`` / ``except OSError``) catch it
    too, while ``except FileReaderError`` keeps working.
    """


class DefaultFindFilePath(BaseDataAccessLayer):
    """Default strategy for turning a file name into an existing path.

    Candidate locations are tried in this order; every candidate is a
    relative path, i.e. it is interpreted against the current working
    directory:

    1. ``<code_package>/<file_folder>/<file_name>`` - only when
       ``code_package`` exists.
    2. ``<file_folder>/<file_name>`` - only when a file named
       ``config_file`` exists somewhere below the current working directory.
    3. ``<file_name>`` exactly as given.
    """

    # Default configuration values
    DEFAULT_CODE_PACKAGE = "payload"
    DEFAULT_FILE_FOLDER = "files"
    DEFAULT_CONFIG_FILE = "config.json"

    def __init__(
        self,
        code_package: Optional[str] = None,
        file_folder: Optional[str] = None,
        config_file: Optional[str] = None,
    ):
        """Store the search configuration.

        Any argument that is ``None`` or empty falls back to the matching
        ``DEFAULT_*`` class attribute.

        Args:
            code_package: Directory that holds the code package.
            file_folder: Folder, inside the code package, that holds the files.
            config_file: Name of the configuration file whose presence enables
                the second search location.
        """
        self.code_package = code_package or self.DEFAULT_CODE_PACKAGE
        self.file_folder = file_folder or self.DEFAULT_FILE_FOLDER
        self.config_file = config_file or self.DEFAULT_CONFIG_FILE

    def find_file_path(self, file_name: str) -> Path:
        """Return the path of an existing file called ``file_name``.

        Args:
            file_name: The name of the file to locate.

        Returns:
            The first candidate path that exists.

        Raises:
            ValueError: If ``file_name`` is empty or ``None``.
            FileNotFoundError: If no candidate location contains the file.
        """
        if not file_name:
            raise ValueError("file_name cannot be empty")

        file_path = self._resolve_file_path(file_name)

        # The resolver's last resort is the bare file name, which may not
        # exist either, so existence is verified once more here.
        if not file_path.exists():
            raise FileNotFoundError(
                f"File '{file_name}' not found in any search location"
            )

        return file_path

    def _resolve_file_path(self, file_name: str) -> Path:
        """Pick the first existing candidate, or fall back to the bare name.

        Args:
            file_name: The name of the file to resolve.

        Returns:
            An existing candidate path, or ``Path(file_name)`` (which is not
            checked for existence) when neither location holds the file.
        """
        # First try the default code package location
        if self._code_package_exists():
            file_path = self._get_code_package_file_path(file_name)
            if file_path.exists():
                return file_path

        # Fall back to the location enabled by the presence of a config file
        config_path = self._find_config_file()
        if config_path:
            file_path = self._get_config_based_file_path(file_name, config_path)
            if file_path.exists():
                return file_path

        # Return the file name as a Path if not found in any location
        return Path(file_name)

    def _code_package_exists(self) -> bool:
        """Report whether ``code_package`` exists on disk.

        Returns:
            True if the code package path exists.
        """
        return os.path.exists(self.code_package)

    def _get_code_package_file_path(self, file_name: str) -> Path:
        """Build ``<code_package>/<file_folder>/<file_name>``.

        Args:
            file_name: The name of the file.

        Returns:
            The candidate path inside the code package (not checked for
            existence).
        """
        # Joined as a string on purpose: with pathlib's ``/`` operator an
        # absolute ``file_name`` would discard the package prefix, whereas
        # here it stays underneath it.
        relative_path = f"{self.code_package}/{self.file_folder}/{file_name}"
        return Path(relative_path)

    def _find_config_file(self) -> Optional[Path]:
        """Search the current working directory tree for ``config_file``.

        Returns:
            The path to the config file, or None if not found.
        """
        return self._find_file_in_tree(self.config_file, Path.cwd())

    def _get_config_based_file_path(self, file_name: str, config_path: Path) -> Path:
        """Build ``<file_folder>/<file_name>``.

        Args:
            file_name: The name of the file.
            config_path: The path of the discovered config file.  It is not
                used: the result is relative to the current working
                directory, not to the config file's directory.
                NOTE(review): confirm whether anchoring at
                ``config_path.parent`` was intended.

        Returns:
            The candidate path (not checked for existence).
        """
        relative_path = f"{self.file_folder}/{file_name}"
        return Path(relative_path)

    def _find_file_in_tree(self, filename: str, search_path: Path) -> Optional[Path]:
        """Find a file by name anywhere below ``search_path``.

        Args:
            filename: The name of the file to find.
            search_path: The root directory to search from.

        Returns:
            The first match produced by ``search_path.rglob(filename)``, or
            None if there is none.  When several files match, which one comes
            first depends on the traversal order.
        """
        return next(search_path.rglob(filename), None)


# ---------------------------------------------------------------------------
# tests/file/__init__.py (new file, license header only)
# ---------------------------------------------------------------------------
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ---------------------------------------------------------------------------
# tests/file/path/__init__.py (new file, license header only)
# ---------------------------------------------------------------------------
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------------
# tests/file/test_path_default.py (new file)
# ---------------------------------------------------------------------------
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from pathlib import Path
import tempfile
from unittest.mock import MagicMock, patch

import pytest

from datacustomcode.file.path.default import (
    DefaultFindFilePath,
    FileNotFoundError,
    FileReaderError,
)


class TestDefaultFindFilePath:
    """Unit and integration tests for ``DefaultFindFilePath``."""

    def test_init_with_defaults(self):
        """A bare constructor picks up the class-level defaults."""
        finder = DefaultFindFilePath()

        observed = (finder.code_package, finder.file_folder, finder.config_file)
        assert observed == ("payload", "files", "config.json")

    def test_init_with_custom_values(self):
        """Explicit constructor arguments override the defaults."""
        overrides = {
            "code_package": "custom_package",
            "file_folder": "custom_files",
            "config_file": "custom_config.json",
        }

        finder = DefaultFindFilePath(**overrides)

        for attribute, expected in overrides.items():
            assert getattr(finder, attribute) == expected

    def test_find_file_path_empty_filename(self):
        """An empty or ``None`` file name is rejected with ValueError."""
        finder = DefaultFindFilePath()

        for bad_name in ("", None):
            with pytest.raises(ValueError, match="file_name cannot be empty"):
                finder.find_file_path(bad_name)

    def test_find_file_path_file_not_found(self):
        """A resolved path that does not exist raises FileNotFoundError."""
        finder = DefaultFindFilePath()
        missing = MagicMock()
        missing.exists.return_value = False

        with patch.object(finder, "_resolve_file_path", return_value=missing):
            with pytest.raises(
                FileNotFoundError,
                match="File 'test.txt' not found in any search location",
            ):
                finder.find_file_path("test.txt")

    def test_find_file_path_success(self):
        """A resolved path that exists is returned unchanged."""
        finder = DefaultFindFilePath()
        present = MagicMock()
        present.exists.return_value = True

        with patch.object(
            finder, "_resolve_file_path", return_value=present
        ) as resolver:
            outcome = finder.find_file_path("test.txt")

        assert outcome == present
        resolver.assert_called_once_with("test.txt")

    def test_resolve_file_path_code_package_exists(self):
        """The code-package candidate wins when it exists."""
        finder = DefaultFindFilePath()
        present = MagicMock()
        present.exists.return_value = True
        stubs = {
            "_code_package_exists": MagicMock(return_value=True),
            "_get_code_package_file_path": MagicMock(return_value=present),
        }

        with patch.multiple(finder, **stubs):
            outcome = finder._resolve_file_path("test.txt")

        assert outcome == present
        stubs["_code_package_exists"].assert_called_once()
        stubs["_get_code_package_file_path"].assert_called_once_with("test.txt")

    def test_resolve_file_path_code_package_exists_file_not_found(self):
        """A miss in the code package falls through to the config-based path."""
        finder = DefaultFindFilePath()

        # Candidate inside the code package: reported as absent.
        absent = MagicMock()
        absent.exists.return_value = False
        # Config file is found and the config-based candidate exists.
        config_location = MagicMock()
        located = MagicMock()
        located.exists.return_value = True

        stubs = {
            "_code_package_exists": MagicMock(return_value=True),
            "_get_code_package_file_path": MagicMock(return_value=absent),
            "_find_config_file": MagicMock(return_value=config_location),
            "_get_config_based_file_path": MagicMock(return_value=located),
        }

        with patch.multiple(finder, **stubs):
            outcome = finder._resolve_file_path("test.txt")

        assert outcome == located
        stubs["_find_config_file"].assert_called_once()
        stubs["_get_config_based_file_path"].assert_called_once_with(
            "test.txt", config_location
        )

    def test_resolve_file_path_fallback_to_filename(self):
        """With no package and no config file the bare name is returned."""
        finder = DefaultFindFilePath()
        stubs = {
            "_code_package_exists": MagicMock(return_value=False),
            "_find_config_file": MagicMock(return_value=None),
        }

        with patch.multiple(finder, **stubs):
            outcome = finder._resolve_file_path("test.txt")

        assert outcome == Path("test.txt")

    def test_code_package_exists_true(self):
        """``_code_package_exists`` mirrors a truthy ``os.path.exists``."""
        finder = DefaultFindFilePath()

        with patch("os.path.exists", return_value=True):
            answer = finder._code_package_exists()

        assert answer is True

    def test_code_package_exists_false(self):
        """``_code_package_exists`` mirrors a falsy ``os.path.exists``."""
        finder = DefaultFindFilePath()

        with patch("os.path.exists", return_value=False):
            answer = finder._code_package_exists()

        assert answer is False

    def test_get_code_package_file_path(self):
        """Default package and folder yield ``payload/files/<name>``."""
        candidate = DefaultFindFilePath()._get_code_package_file_path("test.txt")

        assert candidate == Path("payload/files/test.txt")

    def test_get_code_package_file_path_custom_values(self):
        """Custom package and folder names are reflected in the path."""
        finder = DefaultFindFilePath(
            code_package="custom_package", file_folder="custom_files"
        )

        candidate = finder._get_code_package_file_path("test.txt")

        assert candidate == Path("custom_package/custom_files/test.txt")

    def test_find_config_file_found(self):
        """The config lookup delegates to ``_find_file_in_tree`` from cwd."""
        finder = DefaultFindFilePath()
        located = MagicMock()

        with patch.object(
            finder, "_find_file_in_tree", return_value=located
        ) as tree_search:
            outcome = finder._find_config_file()

        assert outcome == located
        tree_search.assert_called_once_with("config.json", Path.cwd())

    def test_find_config_file_not_found(self):
        """A fruitless tree search is reported as ``None``."""
        finder = DefaultFindFilePath()

        with patch.object(finder, "_find_file_in_tree", return_value=None):
            outcome = finder._find_config_file()

        assert outcome is None

    def test_get_config_based_file_path(self):
        """The config-based candidate is ``files/<name>``."""
        finder = DefaultFindFilePath()

        candidate = finder._get_config_based_file_path(
            "test.txt", Path("/some/path/config.json")
        )

        assert candidate == Path("files/test.txt")

    def test_get_config_based_file_path_custom_folder(self):
        """A custom file folder replaces ``files`` in the candidate."""
        finder = DefaultFindFilePath(file_folder="custom_files")

        candidate = finder._get_config_based_file_path(
            "test.txt", Path("/some/path/config.json")
        )

        assert candidate == Path("custom_files/test.txt")

    def test_find_file_in_tree_found(self):
        """A file directly under the search root is discovered."""
        finder = DefaultFindFilePath()

        with tempfile.TemporaryDirectory() as temp_dir:
            root = Path(temp_dir)
            (root / "test.txt").write_text("test content")

            match = finder._find_file_in_tree("test.txt", root)

            assert match is not None
            assert match.name == "test.txt"

    def test_find_file_in_tree_not_found(self):
        """An empty tree yields ``None``."""
        finder = DefaultFindFilePath()

        with tempfile.TemporaryDirectory() as temp_dir:
            match = finder._find_file_in_tree("nonexistent.txt", Path(temp_dir))

            assert match is None

    def test_find_file_in_tree_multiple_matches(self):
        """With several matches, one of them is returned."""
        finder = DefaultFindFilePath()

        with tempfile.TemporaryDirectory() as temp_dir:
            root = Path(temp_dir)

            # Same file name in two sibling sub-directories.
            for folder, text in (("subdir1", "content1"), ("subdir2", "content2")):
                (root / folder).mkdir()
                (root / folder / "test.txt").write_text(text)

            match = finder._find_file_in_tree("test.txt", root)

            # Which of the two is returned is left to the implementation.
            assert match is not None
            assert match.name == "test.txt"

    def test_integration_find_file_path_success(self):
        """End to end: an existing file is found on the real file system."""
        finder = DefaultFindFilePath()

        with tempfile.TemporaryDirectory() as temp_dir:
            expected = Path(temp_dir) / "test.txt"
            expected.write_text("test content")

            # Point the code package at the temp dir, with no sub-folder.
            finder.code_package = temp_dir
            finder.file_folder = ""

            outcome = finder.find_file_path("test.txt")

            assert outcome == expected
            assert outcome.exists()

    def test_integration_find_file_path_not_found(self):
        """End to end: a missing file raises FileNotFoundError."""
        finder = DefaultFindFilePath()

        with tempfile.TemporaryDirectory() as temp_dir:
            # The directory is deliberately left empty.
            finder.code_package = temp_dir
            finder.file_folder = ""

            with pytest.raises(FileNotFoundError):
                finder.find_file_path("nonexistent.txt")


class TestFileReaderError:
    """Checks on the exception hierarchy exported by the module."""

    def test_file_reader_error_inheritance(self):
        """``FileReaderError`` is an ``Exception`` carrying its message."""
        error = FileReaderError("test message")

        assert isinstance(error, Exception)
        assert str(error) == "test message"

    def test_file_not_found_error_inheritance(self):
        """``FileNotFoundError`` specialises ``FileReaderError``."""
        error = FileNotFoundError("file not found")

        for parent in (FileReaderError, Exception):
            assert isinstance(error, parent)
        assert str(error) == "file not found"


# ---------------------------------------------------------------------------
# tests/test_credentials_profile_integration.py (new file) follows
# ---------------------------------------------------------------------------
"""
Integration tests for credentials profile functionality.

This module tests the complete flow of using different credentials profiles
with the DataCloud Custom Code Python SDK components.
"""

from __future__ import annotations

import os
from unittest.mock import MagicMock, patch

from datacustomcode.config import config
from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
from datacustomcode.io.writer.print import PrintDataCloudWriter

# Patch targets shared by the tests below.
_FROM_AVAILABLE = "datacustomcode.credentials.Credentials.from_available"
_CDP_CONNECTION = "datacustomcode.io.reader.query_api.SalesforceCDPConnection"


def _fake_credentials(profile: str = "default") -> MagicMock:
    """Build a mock credentials object whose fields are derived from ``profile``.

    The signature matches how the tests call ``Credentials.from_available``
    (``profile=...``), so the function can be used directly as a mock
    ``side_effect``.
    """
    credentials = MagicMock()
    credentials.login_url = f"https://{profile}.salesforce.com"
    credentials.username = f"{profile}@example.com"
    credentials.password = f"{profile}_password"
    credentials.client_id = f"{profile}_client_id"
    credentials.client_secret = f"{profile}_secret"
    return credentials


def _assert_profile_override(profile: str) -> None:
    """Apply ``profile`` to the reader and writer options and verify it.

    ``patch.dict`` restores both option mappings on exit, so the shared
    ``config`` object is left exactly as it was found and tests cannot
    influence each other through it.
    NOTE(review): assumes ``options`` is a dict-like mapping - confirm.
    """
    reader_options = config.reader_config.options
    writer_options = config.writer_config.options
    override = {"credentials_profile": profile}

    with patch.dict(reader_options, override), patch.dict(writer_options, override):
        assert reader_options["credentials_profile"] == profile
        assert writer_options["credentials_profile"] == profile


class TestCredentialsProfileIntegration:
    """Test integration of credentials profile functionality across components."""

    def test_query_api_reader_with_custom_profile(self):
        """QueryAPIDataCloudReader loads and uses the requested profile."""
        mock_spark = MagicMock()
        credentials = _fake_credentials("custom")

        with patch(_FROM_AVAILABLE, return_value=credentials) as from_available:
            with patch(_CDP_CONNECTION) as connection_cls:
                QueryAPIDataCloudReader(
                    mock_spark, credentials_profile="custom_profile"
                )

                # The requested profile is the one that was loaded ...
                from_available.assert_called_with(profile="custom_profile")

                # ... and its values are what the connection was built from.
                connection_cls.assert_called_once_with(
                    "https://custom.salesforce.com",
                    "custom@example.com",
                    "custom_password",
                    "custom_client_id",
                    "custom_secret",
                )

    def test_print_writer_with_custom_profile(self):
        """PrintDataCloudWriter passes the profile on to its reader."""
        mock_spark = MagicMock()
        credentials = _fake_credentials("custom")

        with patch(_FROM_AVAILABLE, return_value=credentials) as from_available:
            with patch(_CDP_CONNECTION):
                writer = PrintDataCloudWriter(
                    mock_spark, credentials_profile="custom_profile"
                )

                from_available.assert_called_with(profile="custom_profile")

                # The writer owns a query-API reader built with that profile.
                assert writer.reader is not None
                assert isinstance(writer.reader, QueryAPIDataCloudReader)

    def test_config_override_with_environment_variable(self):
        """The environment variable value ends up in both configs."""
        # patch.dict puts back whatever value (or absence) was there before.
        with patch.dict(os.environ, {"SFDC_CREDENTIALS_PROFILE": "env_profile"}):
            # Simulate what happens in entrypoint.py
            credentials_profile = os.environ.get("SFDC_CREDENTIALS_PROFILE", "default")
            assert credentials_profile == "env_profile"

            _assert_profile_override(credentials_profile)

    def test_config_override_programmatically(self):
        """A profile set from code ends up in both configs."""
        _assert_profile_override("programmatic_profile")

    def test_default_profile_behavior(self):
        """Both configs accept and report the ``default`` profile."""
        _assert_profile_override("default")

    def test_credentials_profile_consistency(self):
        """Reader and writer built with one profile both load that profile."""
        mock_spark = MagicMock()
        test_profile = "consistent_profile"
        credentials = _fake_credentials("consistent")

        with patch(_FROM_AVAILABLE, return_value=credentials) as from_available:
            with patch(_CDP_CONNECTION):
                reader = QueryAPIDataCloudReader(
                    mock_spark, credentials_profile=test_profile
                )
                writer = PrintDataCloudWriter(
                    mock_spark, credentials_profile=test_profile
                )

                # One credentials lookup per component, same profile each time.
                assert from_available.call_count == 2
                for _, kwargs in from_available.call_args_list:
                    assert kwargs["profile"] == test_profile

                # Both components ended up with a connection.
                assert reader._conn is not None
                assert writer.reader._conn is not None

    def test_multiple_profiles_isolation(self):
        """Readers built with different profiles each load their own."""
        mock_spark = MagicMock()
        requested = ("profile1", "profile2", "default")

        with patch(_FROM_AVAILABLE, side_effect=_fake_credentials) as from_available:
            with patch(_CDP_CONNECTION):
                for profile in requested:
                    QueryAPIDataCloudReader(mock_spark, credentials_profile=profile)

                calls = from_available.call_args_list
                assert len(calls) == 3

                profiles_used = [kwargs["profile"] for _, kwargs in calls]
                for profile in requested:
                    assert profile in profiles_used