@@ -119,16 +119,29 @@ def reset_override(self, provider_id: str | None = None) -> None:
119
119
else :
120
120
self ._overrides .pop (provider_id , None )
121
121
122
- async def __aenter__ (self ) -> "Container" :
122
+ def async_enter (self ) -> "Container" :
123
123
self ._is_async = True
124
124
return self
125
125
126
+ def sync_enter (self ) -> "Container" :
127
+ self ._is_async = False
128
+ return self
129
+
126
130
async def async_close (self ) -> None :
127
131
self ._check_entered ()
128
132
for provider_state in reversed (self ._provider_states .values ()):
129
133
await provider_state .async_tear_down ()
130
134
self ._exit ()
131
135
136
+ def sync_close (self ) -> None :
137
+ self ._check_entered ()
138
+ for provider_state in reversed (self ._provider_states .values ()):
139
+ provider_state .sync_tear_down ()
140
+ self ._exit ()
141
+
142
+ async def __aenter__ (self ) -> "Container" :
143
+ return self .async_enter ()
144
+
132
145
async def __aexit__ (
133
146
self ,
134
147
exc_type : type [BaseException ] | None ,
@@ -138,16 +151,12 @@ async def __aexit__(
138
151
await self .async_close ()
139
152
140
153
def __enter__ (self ) -> "Container" :
141
- self ._is_async = False
142
- return self
154
+ return self .sync_enter ()
143
155
144
156
def __exit__ (
145
157
self ,
146
158
exc_type : type [BaseException ] | None ,
147
159
exc_value : BaseException | None ,
148
160
traceback : types .TracebackType | None ,
149
161
) -> None :
150
- self ._check_entered ()
151
- for provider_state in reversed (self ._provider_states .values ()):
152
- provider_state .sync_tear_down ()
153
- self ._exit ()
162
+ self .sync_close ()
0 commit comments