@@ -800,7 +800,11 @@ def forward(
             scale_a=inv_scale_x,
             scale_b=inv_scale_w,
             use_fast_accum=True,
-        )[0]
+        )
+
+        if isinstance(out, tuple):
+            out = out[0]
+
         return out.reshape(*ctx.x_shape[:-1], w.shape[0])
 
     @staticmethod
@@ -814,15 +818,23 @@ def backward(ctx: Any, out_grad) -> Any:
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_w,
             use_fast_accum=True,
-        )[0]
+        )
+
+        if isinstance(x_grad, tuple):
+            x_grad = x_grad[0]
+
         w_grad = torch._scaled_mm(
             out_grad_fp8.t().contiguous(),
             ctx.x_fp8.t().contiguous().t(),
             out_dtype=ctx.out_dtype,
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_x,
             use_fast_accum=True,
-        )[0]
+        )
+
+        if isinstance(w_grad, tuple):
+            w_grad = w_grad[0]
+
         bias_grad = None
         if ctx.has_bias:
             bias_grad = out_grad.sum(0)
@@ -835,8 +847,14 @@ class _LinearFp8DeepGemm(torch.autograd.Function):
835847 """
836848
837849 def forward (ctx : Any , x : torch .Tensor , w : torch .Tensor ) -> torch .Tensor :
838- if not (x .dim () == 2 and w .dim () == 2 ):
839- raise ValueError ("Batched fp8 deep_gemm is not supported" )
850+ has_batch_dim = False
851+ if x .dim () == 3 :
852+ has_batch_dim = True
853+ if x .size (1 ) != 1 :
854+ raise ValueError (f"Batched fp8 deep_gemm is not supported, found x shape: { x .shape } " )
855+ x = x .squeeze (1 )
856+ ctx .has_batch_dim = has_batch_dim
857+
840858 # x: (m, k), w: (n, k)
841859 # x @ w_t -> (m, k) @ (k, n) -> deep_gemm((m, k), (n, k))
842860 (m , k ), (n , _ ) = x .shape , w .shape
@@ -848,12 +866,17 @@ def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
         ctx.w_t_per_plk = per_block_cast_to_fp8(w.t())
         ctx.x_t_per_blk = per_block_cast_to_fp8(x.t())
         ctx.mnk = (m, n, k)
+        if has_batch_dim:
+            out = out.unsqueeze(1)
         return out
 
     def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         # o_grad: (m, n)
         # x_grad: (m, k) -> (m, n) @ (n, k) -> deep_gemm((m, n), (k, n))
         # w_grad: (n, k) -> (m, n).t() @ (m, k) -> deep_gemm((m, n).t(), (k, m))
+        if ctx.has_batch_dim:
+            o_grad = o_grad.squeeze(1)
+
         m, n, k = ctx.mnk
         o_per_tok = per_token_cast_to_fp8(o_grad)
 
@@ -864,6 +887,9 @@ def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor
         w_grad = torch.empty((n, k), dtype=torch.bfloat16, device=o_grad.device)
         deepgemm_fp8_gemm(o_grad_t_per_tok, ctx.x_t_per_blk, w_grad)
 
+        if ctx.has_batch_dim:
+            x_grad = x_grad.unsqueeze(1)
+
         return x_grad, w_grad
 
 