@@ -24,22 +24,22 @@ def test_transformer_engine_no_config(feature_dirs):
         # FP8 enabled - true by the default
         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="fprop", iteration=0
-        )
+        )[0]

-        # modify_tensor_enabled - False by default
+        # modify_tensor_enabled - (False, None) by default
         assert not debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
-        )
+        )[0]

-        # inspect_tensor_enabled - False by default
+        # inspect_tensor_enabled - (False, None) by default
         assert not debug_api.transformer_engine.inspect_tensor_enabled(
             "decoder.1.attn.qkv", tensor_name="activation", iteration=0
-        )
+        )[0]

-        # inspect_tensor_postquantize - False by default
+        # inspect_tensor_postquantize - (False, None) by default
         assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
             "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
-        )
+        )[0]

     finally:
         debug_api.end_debug()
@@ -51,24 +51,24 @@ def test_disable_fp8_gemm(configs_dir, feature_dirs):

         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="fprop", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="dgrad", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="wgrad", iteration=0
-        )
+        )[0]

         # caching
         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="fprop", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="dgrad", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="wgrad", iteration=0
-        )
+        )[0]

     finally:
         debug_api.end_debug()
@@ -80,22 +80,22 @@ def test_disable_fp8_layer(configs_dir, feature_dirs):

         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.mlp.fc1", gemm="fprop", iteration=0
-        )
+        )[0]
         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.mlp.fc1", gemm="wgrad", iteration=0
-        )
+        )[0]
         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.mlp.fc1", gemm="dgrad", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="fprop", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="wgrad", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.attn.qkv", gemm="dgrad", iteration=0
-        )
+        )[0]

     finally:
         debug_api.end_debug()
@@ -111,22 +111,22 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
         # check modify_tensor_enabled
         assert debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
-        )
+        )[0]
         assert debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0
-        )
+        )[0]
         assert debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0
-        )
+        )[0]
         assert not debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0
-        )
+        )[0]

         # check modify_tensor

@@ -168,14 +168,14 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
             gemm="wgrad",
             tensor_name="gradient",
             iteration=0,
-        )
+        )[0]

         assert not debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc4",
             gemm="fprop",
             tensor_name="activation",
             iteration=0,
-        )
+        )[0]
     finally:
         debug_api.end_debug()

@@ -191,11 +191,11 @@ def test_fake_quant(configs_dir, feature_dirs):
         # modify_tensor_enabled
         assert debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
-        )
+        )[0]

         assert debug_api.transformer_engine.modify_tensor_enabled(
             "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
-        )
+        )[0]

         # modify_tensor
         debug_api.transformer_engine.modify_tensor(
@@ -218,11 +218,11 @@ def test_fake_quant(configs_dir, feature_dirs):

         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.fc2", gemm="wgrad", iteration=0
-        )
+        )[0]
         # caching
         assert debug_api.transformer_engine.fp8_gemm_enabled(
             "decoder.1.fc2", gemm="wgrad", iteration=0
-        )
+        )[0]
     finally:
         debug_api.end_debug()

@@ -265,21 +265,20 @@ def assert_empty():
         assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max()
         assert not debug_api.transformer_engine.inspect_tensor_enabled(
             "decoder.1.mlp.fc1", tensor_name="activation", iteration=201
-        )
+        )[0]
         assert not debug_api.transformer_engine.inspect_tensor_enabled(
             "decoder.2.mlp.fc1", tensor_name="activation", iteration=200
-        )
+        )[0]
         assert not debug_api.transformer_engine.inspect_tensor_enabled(
             "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
-        )
+        )[0]

         expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5)
-        expected_overflows = (tensor_fp8._data == 126).sum() * 100 / (100 * 100 * 5)

         # TE FP8 tensor stats --
         assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
             "decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
-        )
+        )[0]
         debug_api.transformer_engine.inspect_tensor_postquantize(
             "decoder.1.mlp.fc1",
             tensor=tensor_fp8,
@@ -295,10 +294,10 @@ def assert_empty():

         assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
             "decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201
-        )
+        )[0]
         assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
             "decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
-        )
+        )[0]

         # Second config in same yaml
         tensor = torch.rand((100, 100, 5))
@@ -328,7 +327,7 @@ def assert_empty():

         assert not debug_api.transformer_engine.inspect_tensor_enabled(
             "decoder.7.mlp.fc1", tensor_name="weight", iteration=201
-        )
+        )[0]
         assert_empty()

     finally:
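
For reference, the call pattern these updated assertions rely on: each enablement query now returns a tuple whose first element is the boolean flag (the updated comments indicate (False, None) is the disabled default), so call sites either index [0] or unpack the result. A minimal sketch of the unpacked form, reusing a call from the tests above; the variable names and the branch are illustrative only, not taken from the diff:

    # Assumed post-change convention: the query returns a tuple such as
    # (False, None) when the feature is disabled, so take the boolean
    # before branching rather than truth-testing the tuple itself.
    enabled, _ = debug_api.transformer_engine.fp8_gemm_enabled(
        "decoder.1.attn.qkv", gemm="fprop", iteration=0
    )
    if enabled:
        ...  # FP8 path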