Skip to content

Commit 6bedd75

Browse files
committed
fixed failing test cases
Signed-off-by: Nikhil Suri <[email protected]>
1 parent efbeb1a commit 6bedd75

File tree

4 files changed

+130
-90
lines changed

4 files changed

+130
-90
lines changed

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,20 @@ def __init__(
200200

201201
# Create telemetry push client based on circuit breaker enabled flag
202202
if client_context.telemetry_circuit_breaker_enabled:
203-
# Create circuit breaker configuration with hardcoded values
204-
# These values are optimized for telemetry batching and network resilience
205-
circuit_breaker_config = CircuitBreakerConfig(
206-
failure_threshold=0.5, # Opens if 50%+ of calls fail
207-
minimum_calls=20, # Minimum sample size before circuit can open
208-
timeout=30, # Time window for counting failures (seconds)
209-
reset_timeout=30, # Cool-down period before retrying (seconds)
203+
# Create circuit breaker configuration from client context or use defaults
204+
self._circuit_breaker_config = CircuitBreakerConfig(
205+
failure_threshold=getattr(
206+
client_context, "telemetry_circuit_breaker_failure_threshold", 0.5
207+
),
208+
minimum_calls=getattr(
209+
client_context, "telemetry_circuit_breaker_minimum_calls", 20
210+
),
211+
timeout=getattr(
212+
client_context, "telemetry_circuit_breaker_timeout", 30
213+
),
214+
reset_timeout=getattr(
215+
client_context, "telemetry_circuit_breaker_reset_timeout", 30
216+
),
210217
name=f"telemetry-circuit-breaker-{session_id_hex}",
211218
)
212219

@@ -215,11 +222,12 @@ def __init__(
215222
CircuitBreakerTelemetryPushClient(
216223
TelemetryPushClient(self._http_client),
217224
host_url,
218-
circuit_breaker_config,
225+
self._circuit_breaker_config,
219226
)
220227
)
221228
else:
222229
# Circuit breaker disabled - use direct telemetry push client
230+
self._circuit_breaker_config = None
223231
self._telemetry_push_client: ITelemetryPushClient = TelemetryPushClient(
224232
self._http_client
225233
)
@@ -402,6 +410,18 @@ def close(self):
402410
logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex)
403411
self._flush()
404412

413+
def get_circuit_breaker_state(self) -> str:
414+
"""Get the current state of the circuit breaker."""
415+
return self._telemetry_push_client.get_circuit_breaker_state()
416+
417+
def is_circuit_breaker_open(self) -> bool:
418+
"""Check if the circuit breaker is currently open."""
419+
return self._telemetry_push_client.is_circuit_breaker_open()
420+
421+
def reset_circuit_breaker(self) -> None:
422+
"""Reset the circuit breaker."""
423+
self._telemetry_push_client.reset_circuit_breaker()
424+
405425

406426
class TelemetryClientFactory:
407427
"""

tests/unit/test_circuit_breaker_http_client.py

Lines changed: 53 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -71,42 +71,25 @@ def test_initialization(self):
7171
assert self.client._config == self.config
7272
assert self.client._circuit_breaker is not None
7373

74-
def test_initialization_disabled(self):
75-
"""Test client initialization with circuit breaker disabled."""
76-
config = CircuitBreakerConfig(enabled=False)
77-
client = CircuitBreakerHttpClient(self.mock_delegate, self.host, config)
78-
79-
assert client._config.enabled is False
8074

81-
def test_request_context_disabled(self):
82-
"""Test request context when circuit breaker is disabled."""
83-
config = CircuitBreakerConfig(enabled=False)
84-
client = CircuitBreakerHttpClient(self.mock_delegate, self.host, config)
85-
86-
mock_response = Mock()
87-
self.mock_delegate.request_context.return_value.__enter__.return_value = mock_response
88-
self.mock_delegate.request_context.return_value.__exit__.return_value = None
89-
90-
with client.request_context(HttpMethod.POST, "https://test.com", {}) as response:
91-
assert response == mock_response
92-
93-
self.mock_delegate.request_context.assert_called_once()
9475

9576
def test_request_context_enabled_success(self):
9677
"""Test successful request context when circuit breaker is enabled."""
9778
mock_response = Mock()
98-
self.mock_delegate.request_context.return_value.__enter__.return_value = mock_response
99-
self.mock_delegate.request_context.return_value.__exit__.return_value = None
79+
mock_context = MagicMock()
80+
mock_context.__enter__.return_value = mock_response
81+
mock_context.__exit__.return_value = None
82+
self.mock_delegate.request_context.return_value = mock_context
10083

101-
with client.request_context(HttpMethod.POST, "https://test.com", {}) as response:
84+
with self.client.request_context(HttpMethod.POST, "https://test.com", {}) as response:
10285
assert response == mock_response
10386

10487
self.mock_delegate.request_context.assert_called_once()
10588

10689
def test_request_context_enabled_circuit_breaker_error(self):
10790
"""Test request context when circuit breaker is open."""
10891
# Mock circuit breaker to raise CircuitBreakerError
109-
with patch.object(self.client._circuit_breaker, '__enter__', side_effect=CircuitBreakerError("Circuit is open")):
92+
with patch.object(self.client._circuit_breaker, 'call', side_effect=CircuitBreakerError("Circuit is open")):
11093
with pytest.raises(CircuitBreakerError):
11194
with self.client.request_context(HttpMethod.POST, "https://test.com", {}):
11295
pass
@@ -120,18 +103,6 @@ def test_request_context_enabled_other_error(self):
120103
with self.client.request_context(HttpMethod.POST, "https://test.com", {}):
121104
pass
122105

123-
def test_request_disabled(self):
124-
"""Test request method when circuit breaker is disabled."""
125-
config = CircuitBreakerConfig(enabled=False)
126-
client = CircuitBreakerHttpClient(self.mock_delegate, self.host, config)
127-
128-
mock_response = Mock()
129-
self.mock_delegate.request.return_value = mock_response
130-
131-
response = client.request(HttpMethod.POST, "https://test.com", {})
132-
133-
assert response == mock_response
134-
self.mock_delegate.request.assert_called_once()
135106

136107
def test_request_enabled_success(self):
137108
"""Test successful request when circuit breaker is enabled."""
@@ -146,7 +117,7 @@ def test_request_enabled_success(self):
146117
def test_request_enabled_circuit_breaker_error(self):
147118
"""Test request when circuit breaker is open."""
148119
# Mock circuit breaker to raise CircuitBreakerError
149-
with patch.object(self.client._circuit_breaker, '__enter__', side_effect=CircuitBreakerError("Circuit is open")):
120+
with patch.object(self.client._circuit_breaker, 'call', side_effect=CircuitBreakerError("Circuit is open")):
150121
with pytest.raises(CircuitBreakerError):
151122
self.client.request(HttpMethod.POST, "https://test.com", {})
152123

@@ -160,15 +131,15 @@ def test_request_enabled_other_error(self):
160131

161132
def test_get_circuit_breaker_state(self):
162133
"""Test getting circuit breaker state."""
163-
with patch.object(self.client._circuit_breaker, 'current_state', 'open'):
134+
with patch('databricks.sql.telemetry.telemetry_push_client.CircuitBreakerManager.get_circuit_breaker_state', return_value='open'):
164135
state = self.client.get_circuit_breaker_state()
165136
assert state == 'open'
166137

167138
def test_reset_circuit_breaker(self):
168139
"""Test resetting circuit breaker."""
169-
with patch.object(self.client._circuit_breaker, 'reset') as mock_reset:
140+
with patch('databricks.sql.telemetry.telemetry_push_client.CircuitBreakerManager.reset_circuit_breaker') as mock_reset:
170141
self.client.reset_circuit_breaker()
171-
mock_reset.assert_called_once()
142+
mock_reset.assert_called_once_with(self.client._host)
172143

173144
def test_is_circuit_breaker_open(self):
174145
"""Test checking if circuit breaker is open."""
@@ -180,42 +151,38 @@ def test_is_circuit_breaker_open(self):
180151

181152
def test_is_circuit_breaker_enabled(self):
182153
"""Test checking if circuit breaker is enabled."""
183-
assert self.client.is_circuit_breaker_enabled() is True
184-
185-
config = CircuitBreakerConfig(enabled=False)
186-
client = CircuitBreakerHttpClient(self.mock_delegate, self.host, config)
187-
assert client.is_circuit_breaker_enabled() is False
154+
assert self.client._circuit_breaker is not None
188155

189156
def test_circuit_breaker_state_logging(self):
190157
"""Test that circuit breaker state changes are logged."""
191-
with patch('databricks.sql.telemetry.circuit_breaker_http_client.logger') as mock_logger:
192-
with patch.object(self.client._circuit_breaker, '__enter__', side_effect=CircuitBreakerError("Circuit is open")):
158+
with patch('databricks.sql.telemetry.telemetry_push_client.logger') as mock_logger:
159+
with patch.object(self.client._circuit_breaker, 'call', side_effect=CircuitBreakerError("Circuit is open")):
193160
with pytest.raises(CircuitBreakerError):
194161
self.client.request(HttpMethod.POST, "https://test.com", {})
195-
196-
# Check that warning was logged
197-
mock_logger.warning.assert_called()
198-
warning_call = mock_logger.warning.call_args[0][0]
199-
assert "Circuit breaker is open" in warning_call
200-
assert self.host in warning_call
162+
163+
# Check that warning was logged
164+
mock_logger.warning.assert_called()
165+
warning_call = mock_logger.warning.call_args[0]
166+
assert "Circuit breaker is open" in warning_call[0]
167+
assert self.host in warning_call[1]
201168

202169
def test_other_error_logging(self):
203170
"""Test that other errors are logged appropriately."""
204-
with patch('databricks.sql.telemetry.circuit_breaker_http_client.logger') as mock_logger:
171+
with patch('databricks.sql.telemetry.telemetry_push_client.logger') as mock_logger:
205172
self.mock_delegate.request.side_effect = ValueError("Network error")
206173

207174
with pytest.raises(ValueError):
208175
self.client.request(HttpMethod.POST, "https://test.com", {})
209176

210177
# Check that debug was logged
211178
mock_logger.debug.assert_called()
212-
debug_call = mock_logger.debug.call_args[0][0]
213-
assert "Telemetry request failed" in debug_call
214-
assert self.host in debug_call
179+
debug_call = mock_logger.debug.call_args[0]
180+
assert "Telemetry request failed" in debug_call[0]
181+
assert self.host in debug_call[1]
215182

216183

217-
class TestCircuitBreakerHttpClientIntegration:
218-
"""Integration tests for CircuitBreakerHttpClient."""
184+
class TestCircuitBreakerTelemetryPushClientIntegration:
185+
"""Integration tests for CircuitBreakerTelemetryPushClient."""
219186

220187
def setup_method(self):
221188
"""Set up test fixtures."""
@@ -224,42 +191,59 @@ def setup_method(self):
224191

225192
def test_circuit_breaker_opens_after_failures(self):
226193
"""Test that circuit breaker opens after repeated failures."""
194+
from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager
195+
196+
# Clear any existing state
197+
CircuitBreakerManager.clear_all_circuit_breakers()
198+
227199
config = CircuitBreakerConfig(
228200
failure_threshold=0.1, # 10% failure rate
229201
minimum_calls=2, # Only 2 calls needed
230202
reset_timeout=1 # 1 second reset timeout
231203
)
232-
client = CircuitBreakerHttpClient(self.mock_delegate, self.host, config)
204+
205+
# Initialize the manager
206+
CircuitBreakerManager.initialize(config)
207+
208+
client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host, config)
233209

234210
# Simulate failures
235211
self.mock_delegate.request.side_effect = Exception("Network error")
236212

237-
# First few calls should fail with the original exception
238-
for _ in range(2):
239-
with pytest.raises(Exception, match="Network error"):
240-
client.request(HttpMethod.POST, "https://test.com", {})
213+
# First call should fail with the original exception
214+
with pytest.raises(Exception, match="Network error"):
215+
client.request(HttpMethod.POST, "https://test.com", {})
241216

242-
# After enough failures, circuit breaker should open
217+
# Second call should open the circuit breaker and raise CircuitBreakerError
243218
with pytest.raises(CircuitBreakerError):
244219
client.request(HttpMethod.POST, "https://test.com", {})
245220

246221
def test_circuit_breaker_recovers_after_success(self):
247222
"""Test that circuit breaker recovers after successful calls."""
223+
from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager
224+
225+
# Clear any existing state
226+
CircuitBreakerManager.clear_all_circuit_breakers()
227+
248228
config = CircuitBreakerConfig(
249229
failure_threshold=0.1,
250230
minimum_calls=2,
251231
reset_timeout=1
252232
)
253-
client = CircuitBreakerHttpClient(self.mock_delegate, self.host, config)
233+
234+
# Initialize the manager
235+
CircuitBreakerManager.initialize(config)
236+
237+
client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host, config)
254238

255239
# Simulate failures first
256240
self.mock_delegate.request.side_effect = Exception("Network error")
257241

258-
for _ in range(2):
259-
with pytest.raises(Exception):
260-
client.request(HttpMethod.POST, "https://test.com", {})
242+
# First call should fail with the original exception
243+
with pytest.raises(Exception):
244+
client.request(HttpMethod.POST, "https://test.com", {})
261245

262-
# Circuit breaker should be open now
246+
# Second call should open the circuit breaker
263247
with pytest.raises(CircuitBreakerError):
264248
client.request(HttpMethod.POST, "https://test.com", {})
265249

tests/unit/test_circuit_breaker_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_get_circuit_breaker_not_initialized(self):
7878

7979
# Should return a no-op circuit breaker
8080
assert breaker.name == "noop-circuit-breaker"
81-
assert breaker.failure_threshold == 1.0
81+
assert breaker.fail_max == 1000000 # Very high threshold for no-op
8282

8383
def test_get_circuit_breaker_enabled(self):
8484
"""Test getting circuit breaker when enabled."""

0 commit comments

Comments
 (0)