@@ -35,12 +35,21 @@ def add_args(arg_parser): # pragma: no cover
35
35
type = str ,
36
36
required = False ,
37
37
)
38
+ arg_parser .add_argument (
39
+ "--mode" ,
40
+ dest = "mode" ,
41
+ metavar = "MODE" ,
42
+ help = "Execution mode - can be 'treated' or 'control'" ,
43
+ type = str ,
44
+ required = False ,
45
+ )
38
46
39
- def __init__ (self , context , flag_name , flag_resolution_methods , ignored_module_check_fn_path = None ):
47
+ def __init__ (self , context , flag_name , flag_resolution_methods , ignored_module_check_fn_path = None , mode = "treated" ):
40
48
super ().__init__ (context )
41
49
self .flag_name = flag_name
42
50
self .is_in_feature_flag_block = False
43
51
self .found_return_stmt_in_ff_block = False
52
+ self .running_in_treated_mode = mode == "treated"
44
53
45
54
if ignored_module_check_fn_path is None :
46
55
ignored_module_check_fn_path = self .DEFAULT_TEST_MODULE_CHECK_PATH
@@ -132,18 +141,20 @@ def leave_If(self, original_node, updated_node):
132
141
if not self .is_in_feature_flag_block :
133
142
return updated_node
134
143
135
- if matchers . matches ( updated_node . test , self .flag_resolution_matcher ):
136
- if self .is_treatment_method :
144
+ if self ._should_assume_that_flag_is_true ( ):
145
+ if matchers . matches ( updated_node . test , self .flag_resolution_matcher ) :
137
146
replaced_node = updated_node .body
138
- else :
139
- replaced_node = updated_node .orelse .body
140
- elif matchers .matches (updated_node .test , _inside_not_matcher (self .flag_resolution_matcher )):
141
- if self .is_treatment_method :
147
+ elif matchers .matches (updated_node .test , _inside_not_matcher (self .flag_resolution_matcher )):
142
148
replaced_node = updated_node .orelse .body
143
149
else :
144
- replaced_node = updated_node . body
150
+ return updated_node
145
151
else :
146
- return updated_node
152
+ if matchers .matches (updated_node .test , self .flag_resolution_matcher ):
153
+ replaced_node = updated_node .orelse .body
154
+ elif matchers .matches (updated_node .test , _inside_not_matcher (self .flag_resolution_matcher )):
155
+ replaced_node = updated_node .body
156
+ else :
157
+ return updated_node
147
158
148
159
return_statements = matchers .findall (replaced_node , matchers .Return ())
149
160
self .found_return_stmt_in_ff_block = len (return_statements ) > 0
@@ -196,6 +207,11 @@ def _updated_tuple_assignment(self, updated_node):
196
207
),
197
208
)
198
209
210
+ def _should_assume_that_flag_is_true (self ):
211
+ return (self .is_treatment_method and self .running_in_treated_mode ) or (
212
+ not self .is_treatment_method and not self .running_in_treated_mode
213
+ )
214
+
199
215
200
216
def _matches_flag_name (flag_name , n ):
201
217
return matchers .matches (n [0 ].value , matchers .Name (flag_name ))
0 commit comments