@@ -130,7 +130,8 @@ def _client_service_logs_effect():
130
130
client_mock .remove_service .assert_called_once_with ("some_id" )
131
131
132
132
@mock .patch ("airflow.providers.docker.operators.docker_swarm.types" )
133
- def test_auto_remove (self , types_mock , docker_api_client_patcher ):
133
+ @pytest .mark .parametrize ("auto_remove" , ["success" , "force" ])
134
+ def test_auto_remove (self , types_mock , docker_api_client_patcher , auto_remove ):
134
135
mock_obj = mock .Mock ()
135
136
136
137
client_mock = mock .Mock (spec = APIClient )
@@ -148,12 +149,45 @@ def test_auto_remove(self, types_mock, docker_api_client_patcher):
148
149
docker_api_client_patcher .return_value = client_mock
149
150
150
151
operator = DockerSwarmOperator (
151
- image = "" , auto_remove = "success" , task_id = "unittest" , enable_logging = False
152
+ image = "" , auto_remove = auto_remove , task_id = "unittest" , enable_logging = False
152
153
)
153
154
operator .execute (None )
154
155
155
156
client_mock .remove_service .assert_called_once_with ("some_id" )
156
157
158
+ @mock .patch ("airflow.providers.docker.operators.docker_swarm.types" )
159
+ @pytest .mark .parametrize (
160
+ "auto_remove,expected_remove_call" , [("success" , False ), ("force" , True ), ("never" , False )]
161
+ )
162
+ def test_auto_remove_failed (
163
+ self , types_mock , docker_api_client_patcher , auto_remove , expected_remove_call
164
+ ):
165
+ mock_obj = mock .Mock ()
166
+
167
+ client_mock = mock .Mock (spec = APIClient )
168
+ client_mock .create_service .return_value = {"ID" : "some_id" }
169
+ client_mock .images .return_value = []
170
+ client_mock .pull .return_value = [b'{"status":"pull log"}' ]
171
+ client_mock .tasks .return_value = [
172
+ {"Status" : {"State" : "failed" , "ContainerStatus" : {"ContainerID" : "some_id" }}}
173
+ ]
174
+ types_mock .TaskTemplate .return_value = mock_obj
175
+ types_mock .ContainerSpec .return_value = mock_obj
176
+ types_mock .RestartPolicy .return_value = mock_obj
177
+ types_mock .Resources .return_value = mock_obj
178
+
179
+ docker_api_client_patcher .return_value = client_mock
180
+
181
+ operator = DockerSwarmOperator (
182
+ image = "" , auto_remove = auto_remove , task_id = "unittest" , enable_logging = False
183
+ )
184
+ try :
185
+ operator .execute (None )
186
+ except AirflowException :
187
+ pass
188
+
189
+ assert (client_mock .remove_service .call_count > 0 ) == expected_remove_call
190
+
157
191
@mock .patch ("airflow.providers.docker.operators.docker_swarm.types" )
158
192
def test_no_auto_remove (self , types_mock , docker_api_client_patcher ):
159
193
mock_obj = mock .Mock ()
0 commit comments