Skip to content

Commit 41b2c7c

Browse files
authored
Fix flash_attention test error
Differential Revision: D81511241 Pull Request resolved: #382
1 parent a8dfdf8 commit 41b2c7c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

test/test_gpu/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,10 @@ def _run_one_operator(args: List[str]):
8787
return
8888
if op.has_bwd():
8989
del op
90-
tb_args.mode = "bwd"
9190
if tb_args.op in BWD_ARGS_OPS:
9291
args.extend(BWD_ARGS_OPS[tb_args.op])
9392
tb_args, extra_args = parser.parse_known_args(args)
93+
tb_args.mode = "bwd"
9494
op = Operator(tb_args=tb_args, extra_args=extra_args)
9595
op.run()
9696
check_ci_output(op)

0 commit comments

Comments
 (0)