@@ -195,6 +195,19 @@ void dispatch(
195195 );
196196}
197197
// Do-nothing counterpart of dispatch<Kernel>: takes the dispatch operator's
// arguments and neither reads nor writes any of them.
//
// register_all_to_all_ops binds this function under c10::kMeta for both
// "all_to_all_internode_dispatch" and "all_to_all_intranode_dispatch", while
// the c10::kCUDA key is bound to dispatch<AllToAllInterNode> /
// dispatch<AllToAllIntraNode>. Both operator schemas return "()".
//
// Parameters follow the order of the registered schema:
//   ptr                -> "int fptr"
//   outExpertNumTokens -> "Tensor! out_expert_num_tokens"
//   outExpertX         -> "Tensor! out_expert_x"
//   outExpertXScale    -> "Tensor!? out_expert_x_scale" (mutable in the schema,
//                         taken by const reference here)
//   dpX                -> "Tensor dp_x"
//   dpXScale           -> "Tensor? dp_x_scale"
//   indices            -> "Tensor indices"
//   boundM             -> "Tensor? bound_m"
//   doSend / doRecv    -> "bool do_send" / "bool do_recv"
//
// NOTE(review): presumably an empty body is enough for the Meta backend because
// the outputs are caller-allocated and nothing is returned — confirm against
// the Python-side callers.
198+ void fake_dispatch (
199+ fptr_t ptr,
200+ at::Tensor &outExpertNumTokens,
201+ at::Tensor &outExpertX,
202+ const std::optional<at::Tensor> &outExpertXScale,
203+ const at::Tensor &dpX,
204+ const std::optional<at::Tensor> &dpXScale,
205+ const at::Tensor &indices,
206+ const std::optional<at::Tensor> &boundM,
207+ bool doSend,
208+ bool doRecv
209+ ) {} // intentionally empty: every argument is ignored
210+
198211template <typename Kernel, typename T, typename U>
199212void combineImpl (
200213 Kernel *all_to_all,
@@ -297,6 +310,17 @@ void combine(
297310 }
298311}
299312
// Do-nothing counterpart of combine<Kernel>: takes the combine operator's
// arguments and neither reads nor writes any of them.
//
// register_all_to_all_ops binds this function under c10::kMeta for both
// "all_to_all_internode_combine" and "all_to_all_intranode_combine", while
// the c10::kCUDA key is bound to combine<AllToAllInterNode> /
// combine<AllToAllIntraNode>. Both operator schemas return "()".
//
// Parameters follow the order of the registered schema:
//   ptr             -> "int fptr"
//   outTokens       -> "Tensor! out_tokens"
//   indices         -> "Tensor indices"
//   weights         -> "Tensor weights"
//   expertY         -> "Tensor expert_y"
//   boundM          -> "Tensor? bound_m"
//   doSend / doRecv -> "bool do_send" / "bool do_recv"
//
// NOTE(review): presumably an empty body is enough for the Meta backend because
// the output tensor is caller-allocated and nothing is returned — confirm
// against the Python-side callers.
313+ void fake_combine (
314+ fptr_t ptr,
315+ at::Tensor &outTokens,
316+ const at::Tensor &indices,
317+ const at::Tensor &weights,
318+ const at::Tensor &expertY,
319+ const std::optional<at::Tensor> &boundM,
320+ bool doSend,
321+ bool doRecv
322+ ) {} // intentionally empty: every argument is ignored
323+
300324#undef _CHECK_TENSOR
301325
302326} // namespace
@@ -306,11 +330,64 @@ void register_all_to_all_ops(torch::Library &m) {
306330 m.def (" all_to_all_destroy" , &destroy);
307331
308332 m.def (" all_to_all_internode_create" , &create_internode);
309- m.def (" all_to_all_internode_dispatch" , &dispatch<AllToAllInterNode>);
310- m.def (" all_to_all_internode_combine" , &combine<AllToAllInterNode>);
333+
334+ m.def (" all_to_all_internode_dispatch("
335+ " int fptr,"
336+ " Tensor! out_expert_num_tokens,"
337+ " Tensor! out_expert_x,"
338+ " Tensor!? out_expert_x_scale,"
339+ " Tensor dp_x,"
340+ " Tensor? dp_x_scale,"
341+ " Tensor indices,"
342+ " Tensor? bound_m,"
343+ " bool do_send,"
344+ " bool do_recv"
345+ " ) -> ()" );
346+ m.impl (" all_to_all_internode_dispatch" , c10::kCUDA , &dispatch<AllToAllInterNode>);
347+ m.impl (" all_to_all_internode_dispatch" , c10::kMeta , &fake_dispatch);
348+
349+ m.def (" all_to_all_internode_combine("
350+ " int fptr,"
351+ " Tensor! out_tokens,"
352+ " Tensor indices,"
353+ " Tensor weights,"
354+ " Tensor expert_y,"
355+ " Tensor? bound_m,"
356+ " bool do_send,"
357+ " bool do_recv"
358+ " ) -> ()" );
359+ m.impl (" all_to_all_internode_combine" , c10::kCUDA , &combine<AllToAllInterNode>);
360+ m.impl (" all_to_all_internode_combine" , c10::kMeta , &fake_combine);
311361
312362 m.def (" all_to_all_intranode_create" , &create_intranode);
313- m.def (" all_to_all_intranode_dispatch" , &dispatch<AllToAllIntraNode>);
314- m.def (" all_to_all_intranode_combine" , &combine<AllToAllIntraNode>);
363+
364+ m.def (" all_to_all_intranode_dispatch("
365+ " int fptr,"
366+ " Tensor! out_expert_num_tokens,"
367+ " Tensor! out_expert_x,"
368+ " Tensor!? out_expert_x_scale,"
369+ " Tensor dp_x,"
370+ " Tensor? dp_x_scale,"
371+ " Tensor indices,"
372+ " Tensor? bound_m,"
373+ " bool do_send,"
374+ " bool do_recv"
375+ " ) -> ()" );
376+ m.impl (" all_to_all_intranode_dispatch" , c10::kCUDA , &dispatch<AllToAllIntraNode>);
377+ m.impl (" all_to_all_intranode_dispatch" , c10::kMeta , &fake_dispatch);
378+
379+ m.def (" all_to_all_intranode_combine("
380+ " int fptr,"
381+ " Tensor! out_tokens,"
382+ " Tensor indices,"
383+ " Tensor weights,"
384+ " Tensor expert_y,"
385+ " Tensor? bound_m,"
386+ " bool do_send,"
387+ " bool do_recv"
388+ " ) -> ()" );
389+ m.impl (" all_to_all_intranode_combine" , c10::kCUDA , &combine<AllToAllIntraNode>);
390+ m.impl (" all_to_all_intranode_combine" , c10::kMeta , &fake_combine);
315391}
392+
316393} // namespace pplx
0 commit comments