Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-132113: Make EnvironmentVarGuard thread safe #132128

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 36 additions & 26 deletions Lib/test/support/os_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import stat
import string
import sys
import threading
import time
import unittest
import warnings
Expand Down Expand Up @@ -736,64 +737,73 @@ def temp_umask(umask):


class EnvironmentVarGuard(collections.abc.MutableMapping):
"""Class to help protect the environment variable properly.

"""Thread-safe class to help protect environment variables.
Can be used as a context manager.
"""

def __init__(self):
self._environ = os.environ
self._changed = {}
self._lock = threading.RLock()

def __getitem__(self, envvar):
return self._environ[envvar]
with self._lock:
return self._environ[envvar]

def __setitem__(self, envvar, value):
# Remember the initial value on the first access
if envvar not in self._changed:
self._changed[envvar] = self._environ.get(envvar)
self._environ[envvar] = value
with self._lock:
# Remember the initial value on the first access
if envvar not in self._changed:
self._changed[envvar] = self._environ.get(envvar)
self._environ[envvar] = value

def __delitem__(self, envvar):
# Remember the initial value on the first access
if envvar not in self._changed:
self._changed[envvar] = self._environ.get(envvar)
if envvar in self._environ:
del self._environ[envvar]
with self._lock:
# Remember the initial value on the first access
if envvar not in self._changed:
self._changed[envvar] = self._environ.get(envvar)
self._environ.pop(envvar, None)

def keys(self):
return self._environ.keys()
with self._lock:
return list(self._environ.keys())

def __iter__(self):
return iter(self._environ)
with self._lock:
return iter(dict(self._environ))

def __len__(self):
return len(self._environ)
with self._lock:
return len(self._environ)

def set(self, envvar, value):
self[envvar] = value

def unset(self, envvar, /, *envvars):
"""Unset one or more environment variables."""
for ev in (envvar, *envvars):
del self[ev]
with self._lock:
for ev in (envvar, *envvars):
del self[ev]

def copy(self):
# We do what os.environ.copy() does.
return dict(self)
with self._lock:
return dict(self._environ)

def __enter__(self):
return self

def __exit__(self, *ignore_exc):
for (k, v) in self._changed.items():
if v is None:
if k in self._environ:
del self._environ[k]
else:
self._environ[k] = v
os.environ = self._environ
with self._lock:
for (k, v) in self._changed.items():
if v is None:
self._environ.pop(k, None)
else:
self._environ[k] = v
self._changed.clear()
os.environ = self._environ

def __reduce__(self):
return (dict, (dict(self),))

try:
if support.MS_WINDOWS:
Expand Down
24 changes: 24 additions & 0 deletions Lib/test/test_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import sysconfig
import tempfile
import textwrap
import threading
import time
import unittest
import warnings

Expand All @@ -22,6 +24,7 @@
from test.support import script_helper
from test.support import socket_helper
from test.support import warnings_helper
from test.support.os_helper import EnvironmentVarGuard

TESTFN = os_helper.TESTFN

Expand Down Expand Up @@ -794,6 +797,27 @@ def test_linked_to_musl(self):
for v in linked:
self.assertIsInstance(v, int)

def test_threadsafe_environmentvarguard(self):
Copy link
Member

@picnixz picnixz Apr 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test should be executed more than once. I think, so I would suggest doing something like that (when I executed the script in my terminal, I needed more than one script execution to see the issue)

def test_threadsafe_environmentvarguard(self):
    for _ in range(100):
        self._test_threadsafe_environmentvarguard()

def _test_threadsafe_environmentvarguard(self):
    ...

def worker1(guard):
for i in range(1000):
guard['MY_VAR'] = 'value1'
time.sleep(0.0001) # Small delay to increase chance of thread switching

def worker2(guard):
for i in range(1000):
guard['MY_VAR'] = 'value2'
time.sleep(0.0001)

with EnvironmentVarGuard() as guard:
t1 = threading.Thread(target=worker1, args=(guard,))
t2 = threading.Thread(target=worker2, args=(guard,))
t1.start()
t2.start()
t1.join()
t2.join()
final_value = os.getenv('MY_VAR')
self.assertIn(final_value, ("value1", "value2"))


# XXX -follows a list of untested API
# make_legacy_pyc
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make :class:`EnvironmentVarGuard <test.support.os_helper.EnvironmentVarGuard>` thread safe
Loading