Commit 2c00d4a

Merge pull request #47 from forcedotcom/spark-session-provider

Allow for spark session provider override

2 parents bf74399 + 3ee0473

8 files changed: +297 -51 lines changed


src/datacustomcode/client.py

Lines changed: 15 additions & 17 deletions
@@ -21,11 +21,10 @@
     Optional,
 )
 
-from pyspark.sql import SparkSession
-
-from datacustomcode.config import SparkConfig, config
+from datacustomcode.config import config
 from datacustomcode.file.path.default import DefaultFindFilePath
 from datacustomcode.io.reader.base import BaseDataCloudReader
+from datacustomcode.spark.default import DefaultSparkSessionProvider
 
 if TYPE_CHECKING:
     from pathlib import Path
@@ -34,18 +33,7 @@
 
     from datacustomcode.io.reader.base import BaseDataCloudReader
     from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
-
-
-def _setup_spark(spark_config: SparkConfig) -> SparkSession:
-    """Setup Spark session from config."""
-    builder = SparkSession.builder
-    if spark_config.master is not None:
-        builder = builder.master(spark_config.master)
-
-    builder = builder.appName(spark_config.app_name)
-    for key, value in spark_config.options.items():
-        builder = builder.config(key, value)
-    return builder.getOrCreate()
+    from datacustomcode.spark.base import BaseSparkSessionProvider
 
 
 class DataCloudObjectType(Enum):
@@ -123,7 +111,8 @@ class Client:
     def __new__(
         cls,
         reader: Optional[BaseDataCloudReader] = None,
-        writer: Optional[BaseDataCloudWriter] = None,
+        writer: Optional["BaseDataCloudWriter"] = None,
+        spark_provider: Optional["BaseSparkSessionProvider"] = None,
     ) -> Client:
         if cls._instance is None:
             cls._instance = super().__new__(cls)
@@ -136,7 +125,16 @@ def __new__(
                 raise ValueError(
                     "Spark config is required when reader/writer is not provided"
                 )
-            spark = _setup_spark(config.spark_config)
+
+            provider: BaseSparkSessionProvider
+            if spark_provider is not None:
+                provider = spark_provider
+            elif config.spark_provider_config is not None:
+                provider = config.spark_provider_config.to_object()
+            else:
+                provider = DefaultSparkSessionProvider()
+
+            spark = provider.get_session(config.spark_config)
 
             if config.reader_config is None and reader is None:
                 raise ValueError(
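
Net effect: the Spark session is now resolved in order of precedence, from an explicit spark_provider argument, to a spark_provider_config entry in the loaded config, to the built-in DefaultSparkSessionProvider. A minimal sketch of the explicit path, assuming reader and writer are already configured; the HiveEnabledProvider class and its Hive tweak are hypothetical, for illustration only:

from datacustomcode.client import Client
from datacustomcode.spark.base import BaseSparkSessionProvider


class HiveEnabledProvider(BaseSparkSessionProvider):
    # Hypothetical provider: builds the session like the default, plus Hive support.
    CONFIG_NAME = "HiveEnabledProvider"

    def get_session(self, spark_config):
        from pyspark.sql import SparkSession

        builder = SparkSession.builder.appName(spark_config.app_name)
        builder = builder.enableHiveSupport()
        for key, value in spark_config.options.items():
            builder = builder.config(key, value)
        return builder.getOrCreate()


# The explicit argument takes precedence over any spark_provider_config entry.
client = Client(spark_provider=HiveEnabledProvider())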

src/datacustomcode/config.py

Lines changed: 23 additions & 0 deletions
@@ -38,6 +38,7 @@
 from datacustomcode.io.base import BaseDataAccessLayer
 from datacustomcode.io.reader.base import BaseDataCloudReader  # noqa: TCH001
 from datacustomcode.io.writer.base import BaseDataCloudWriter  # noqa: TCH001
+from datacustomcode.spark.base import BaseSparkSessionProvider
 
 DEFAULT_CONFIG_NAME = "config.yaml"
 
@@ -89,10 +90,29 @@ class SparkConfig(ForceableConfig):
     )
 
 
+_P = TypeVar("_P", bound=BaseSparkSessionProvider)
+
+
+class SparkProviderConfig(ForceableConfig, Generic[_P]):
+    model_config = ConfigDict(validate_default=True, extra="forbid")
+    type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider
+    type_config_name: str = Field(
+        description="CONFIG_NAME of the Spark session provider."
+    )
+    options: dict[str, Any] = Field(default_factory=dict)
+
+    def to_object(self) -> _P:
+        type_ = self.type_base.subclass_from_config_name(self.type_config_name)
+        return cast(_P, type_(**self.options))
+
+
 class ClientConfig(BaseModel):
     reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
     writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
     spark_config: Union[SparkConfig, None] = None
+    spark_provider_config: Union[
+        SparkProviderConfig[BaseSparkSessionProvider], None
+    ] = None
 
     def update(self, other: ClientConfig) -> ClientConfig:
         """Merge this ClientConfig with another, respecting force flags.
@@ -117,6 +137,9 @@ def merge(
         self.reader_config = merge(self.reader_config, other.reader_config)
         self.writer_config = merge(self.writer_config, other.writer_config)
         self.spark_config = merge(self.spark_config, other.spark_config)
+        self.spark_provider_config = merge(
+            self.spark_provider_config, other.spark_provider_config
+        )
         return self
 
     def load(self, config_path: str) -> ClientConfig:
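
For reference, a sketch of populating the new field programmatically; the values are illustrative, and the same shape can be expressed in config.yaml. type_config_name is resolved via subclass_from_config_name against each provider's CONFIG_NAME, and options are forwarded to the provider's constructor:

from datacustomcode.config import ClientConfig, SparkConfig, SparkProviderConfig

cfg = ClientConfig(
    spark_config=SparkConfig(app_name="example-app", master="local[*]", options={}),
    spark_provider_config=SparkProviderConfig(
        # Must match the target provider class's CONFIG_NAME.
        type_config_name="DefaultSparkSessionProvider",
        options={},  # passed to the provider's __init__
    ),
)
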
src/datacustomcode/spark/__init__.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+# Copyright (c) 2025, Salesforce, Inc.
+# SPDX-License-Identifier: Apache-2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from datacustomcode.spark.base import BaseSparkSessionProvider
+from datacustomcode.spark.default import DefaultSparkSessionProvider
+
+__all__ = ["BaseSparkSessionProvider", "DefaultSparkSessionProvider"]

src/datacustomcode/spark/base.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+# Copyright (c) 2025, Salesforce, Inc.
+# SPDX-License-Identifier: Apache-2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from datacustomcode.mixin import UserExtendableNamedConfigMixin
+
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
+    from datacustomcode.config import SparkConfig
+
+
+class BaseSparkSessionProvider(UserExtendableNamedConfigMixin):
+    def get_session(self, spark_config: SparkConfig) -> "SparkSession":
+        raise NotImplementedError
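
Since the contract is a single method, a provider can source the session from anywhere. A hypothetical sketch (the class name is illustrative) that prefers a session the host environment has already started:

from pyspark.sql import SparkSession

from datacustomcode.spark.base import BaseSparkSessionProvider


class ActiveSessionProvider(BaseSparkSessionProvider):
    # Hypothetical: reuse an already-running session instead of building one.
    CONFIG_NAME = "ActiveSessionProvider"

    def get_session(self, spark_config) -> SparkSession:
        active = SparkSession.getActiveSession()
        if active is not None:
            return active
        return SparkSession.builder.appName(spark_config.app_name).getOrCreate()
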
src/datacustomcode/spark/default.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# Copyright (c) 2025, Salesforce, Inc.
+# SPDX-License-Identifier: Apache-2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from datacustomcode.spark.base import BaseSparkSessionProvider
+
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
+    from datacustomcode.config import SparkConfig
+
+
+class DefaultSparkSessionProvider(BaseSparkSessionProvider):
+    CONFIG_NAME = "DefaultSparkSessionProvider"
+
+    def get_session(self, spark_config: SparkConfig) -> "SparkSession":
+        from pyspark.sql import SparkSession
+
+        builder = SparkSession.builder
+        if spark_config.master is not None:
+            builder = builder.master(spark_config.master)
+        builder = builder.appName(spark_config.app_name)
+        for key, value in spark_config.options.items():
+            builder = builder.config(key, value)
+        return builder.getOrCreate()
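
This provider reproduces the behavior of the _setup_spark helper removed from client.py. Standalone usage looks roughly like this (the app name, master, and option values are illustrative):

from datacustomcode.config import SparkConfig
from datacustomcode.spark.default import DefaultSparkSessionProvider

spark = DefaultSparkSessionProvider().get_session(
    SparkConfig(
        app_name="example-app",
        master="local[*]",  # optional; skipped in the builder when None
        options={"spark.sql.session.timeZone": "UTC"},
    )
)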

tests/spark/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+# Package initialization file
Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from datacustomcode.client import Client
+from datacustomcode.config import (
+    AccessLayerObjectConfig,
+    ClientConfig,
+    SparkConfig,
+    SparkProviderConfig,
+)
+from datacustomcode.io.reader.base import BaseDataCloudReader
+from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
+from datacustomcode.spark.base import BaseSparkSessionProvider
+
+if TYPE_CHECKING:
+    from pyspark.sql import DataFrame as PySparkDataFrame
+
+
+class _Sentinel:
+    pass
+
+
+SENTINEL_SPARK = _Sentinel()
+
+
+class MockReader(BaseDataCloudReader):
+    CONFIG_NAME = "MockReader"
+    last_spark: Any | None = None
+
+    def __init__(self, spark):
+        super().__init__(spark)
+        MockReader.last_spark = spark
+
+    def read_dlo(self, name: str):  # type: ignore[override]
+        raise NotImplementedError
+
+    def read_dmo(self, name: str):  # type: ignore[override]
+        raise NotImplementedError
+
+
+class MockWriter(BaseDataCloudWriter):
+    CONFIG_NAME = "MockWriter"
+    last_spark: Any | None = None
+
+    def __init__(self, spark):
+        super().__init__(spark)
+        MockWriter.last_spark = spark
+
+    def write_to_dlo(
+        self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode
+    ) -> None:  # type: ignore[override]
+        raise NotImplementedError
+
+    def write_to_dmo(
+        self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode
+    ) -> None:  # type: ignore[override]
+        raise NotImplementedError
+
+
+class FakeProvider(BaseSparkSessionProvider):
+    CONFIG_NAME = "FakeProvider"
+
+    def get_session(self, spark_config: SparkConfig):  # type: ignore[override]
+        return SENTINEL_SPARK
+
+
+def _reset_singleton():
+    # Reset Client singleton between tests
+    Client._instance = None  # type: ignore[attr-defined]
+
+
+def test_client_uses_provider_from_config(monkeypatch):
+    _reset_singleton()
+
+    cfg = ClientConfig(
+        reader_config=AccessLayerObjectConfig(
+            type_config_name=MockReader.CONFIG_NAME, options={}
+        ),
+        writer_config=AccessLayerObjectConfig(
+            type_config_name=MockWriter.CONFIG_NAME, options={}
+        ),
+        spark_config=SparkConfig(app_name="test-app", master=None, options={}),
+        spark_provider_config=SparkProviderConfig(
+            type_config_name=FakeProvider.CONFIG_NAME, options={}
+        ),
+    )
+
+    from datacustomcode.config import config as global_config
+
+    global_config.update(cfg)
+
+    Client()
+    assert MockReader.last_spark is SENTINEL_SPARK
+    assert MockWriter.last_spark is SENTINEL_SPARK
+
+
+class ExplicitProvider(BaseSparkSessionProvider):
+    CONFIG_NAME = "ExplicitProvider"
+
+    def get_session(self, spark_config: SparkConfig):  # type: ignore[override]
+        return SENTINEL_SPARK
+
+
+def test_client_explicit_provider_overrides_config(monkeypatch):
+    _reset_singleton()
+
+    cfg = ClientConfig(
+        reader_config=AccessLayerObjectConfig(
+            type_config_name=MockReader.CONFIG_NAME, options={}
+        ),
+        writer_config=AccessLayerObjectConfig(
+            type_config_name=MockWriter.CONFIG_NAME, options={}
+        ),
+        spark_config=SparkConfig(app_name="test-app", master=None, options={}),
+        spark_provider_config=None,
+    )
+
+    from datacustomcode.config import config as global_config
+
+    global_config.update(cfg)
+
+    provider = ExplicitProvider()
+    Client(spark_provider=provider)
+    assert MockReader.last_spark is SENTINEL_SPARK
+    assert MockWriter.last_spark is SENTINEL_SPARK
