Skip to content

Commit

Permalink
Add pickling support to HttpError class (#571)
Browse files Browse the repository at this point in the history
  • Loading branch information
tasansal authored Jul 25, 2023
1 parent 1966d78 commit 4541e40
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
8 changes: 8 additions & 0 deletions gcsfs/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class HttpError(Exception):
"""Holds the message and code from cloud errors."""

def __init__(self, error_response=None):
# Save error_response for potential pickle.
self._error_response = error_response
if error_response:
self.code = error_response.get("code", None)
self.message = error_response.get("message", "")
Expand All @@ -29,6 +31,12 @@ def __init__(self, error_response=None):
# Call the base class constructor with the parameters it needs
super().__init__(self.message)

def __reduce__(self):
"""This makes the Exception pickleable."""

# This is basically deconstructing the HttpError when pickled.
return HttpError, (self._error_response,)


class ChecksumError(Exception):
"""Raised when the md5 hash of the content does not match the header."""
Expand Down
34 changes: 34 additions & 0 deletions gcsfs/tests/test_retry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import multiprocessing
import os
import pickle
from concurrent.futures import ProcessPoolExecutor

import pytest
import requests
Expand Down Expand Up @@ -39,6 +42,37 @@ def test_retriable_exception():
assert is_retriable(e)


def test_pickle_serialization():
expected = HttpError({"message": "", "code": 400})

# Serialize/Deserialize
serialized = pickle.dumps(expected)
actual = pickle.loads(serialized)

is_same_type = type(expected) is type(actual)
is_same_args = expected.args == actual.args

assert is_same_type and is_same_args


def conditional_exception(process_id):
# Raise only on second process (id=1)
if process_id == 1:
raise HttpError({"message": "", "code": 400})


def test_multiprocessing_error_handling():
# Ensure spawn context to avoid forking issues
ctx = multiprocessing.get_context("spawn")

# Run on two processes
with ProcessPoolExecutor(2, mp_context=ctx) as p:
results = p.map(conditional_exception, range(2))

with pytest.raises(HttpError):
_ = [result for result in results]


def test_validate_response():
validate_response(200, None, "/path")

Expand Down

0 comments on commit 4541e40

Please sign in to comment.