72 changes: 36 additions & 36 deletions src/sycl/chunked_prefill.cpp → src/sycl/attention.cpp
@@ -11,11 +11,11 @@
#include "cutlass/util/device_memory.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/sycl_event_manager.hpp"
#include "kernels/chunk_prefill/fmha_fusion.hpp"
#include "kernels/chunk_prefill/tile_scheduler_chunk_prefill.hpp"
#include "kernels/chunk_prefill/xe_chunk_prefill.hpp"
#include "kernels/chunk_prefill/xe_flash_attn_chunk_prefill_epilogue.hpp"
#include "kernels/chunk_prefill/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp"
#include "kernels/flash_attention/fmha_fusion.hpp"
#include "kernels/flash_attention/tile_scheduler_split_kv.hpp"
#include "kernels/flash_attention/xe_flash_attn_split_kv.hpp"
#include "kernels/flash_attention/xe_flash_attn_split_kv_epilogue.hpp"
#include "kernels/flash_attention/xe_flash_attn_split_kv_softmax_epilogue.hpp"

using namespace cute;

@@ -166,25 +166,25 @@ using LayoutK = cutlass::layout::ColumnMajor;
using LayoutV = cutlass::layout::RowMajor;
using LayoutO = cutlass::layout::RowMajor;

template <class FMHAChunkPrefillKernel, bool isVarLen>
template <class FMHASplitKVKernel, bool isVarLen>
struct KernelRunner {
using StrideQ = typename FMHAChunkPrefillKernel::StrideQ;
using StrideK = typename FMHAChunkPrefillKernel::StrideK;
using StrideV = typename FMHAChunkPrefillKernel::StrideV;
using StrideO = typename FMHAChunkPrefillKernel::StrideO;

using ElementQ = typename FMHAChunkPrefillKernel::ElementQ;
using ElementK = typename FMHAChunkPrefillKernel::ElementK;
using ElementV = typename FMHAChunkPrefillKernel::ElementV;
using ElementAcc = typename FMHAChunkPrefillKernel::ElementAccumulator;
using ElementSink = typename FMHAChunkPrefillKernel::ElementSink;

using CollectiveEpilogue = typename FMHAChunkPrefillKernel::CollectiveEpilogue;
using StrideQ = typename FMHASplitKVKernel::StrideQ;
using StrideK = typename FMHASplitKVKernel::StrideK;
using StrideV = typename FMHASplitKVKernel::StrideV;
using StrideO = typename FMHASplitKVKernel::StrideO;

using ElementQ = typename FMHASplitKVKernel::ElementQ;
using ElementK = typename FMHASplitKVKernel::ElementK;
using ElementV = typename FMHASplitKVKernel::ElementV;
using ElementAcc = typename FMHASplitKVKernel::ElementAccumulator;
using ElementSink = typename FMHASplitKVKernel::ElementSink;

using CollectiveEpilogue = typename FMHASplitKVKernel::CollectiveEpilogue;
using ElementOutput = typename CollectiveEpilogue::ElementOutput;
using ElementCompute = typename CollectiveEpilogue::ElementCompute;
using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;

using ProblemShapeType = typename FMHAChunkPrefillKernel::ProblemShape;
using ProblemShapeType = typename FMHASplitKVKernel::ProblemShape;

//
// Data members
@@ -274,12 +274,12 @@ struct KernelRunner {

// Note that the GemmUniversalAdapter currently doesn't support flash attention, which is why this
// secondary `run` function is required to launch the kernel.
static void run(typename FMHAChunkPrefillKernel::Params params) {
dim3 const block = FMHAChunkPrefillKernel::get_block_shape();
dim3 const grid = FMHAChunkPrefillKernel::get_grid_shape(params);
static void run(typename FMHASplitKVKernel::Params params) {
dim3 const block = FMHASplitKVKernel::get_block_shape();
dim3 const grid = FMHASplitKVKernel::get_grid_shape(params);

// configure smem size and carveout
int smem_size = FMHAChunkPrefillKernel::SharedStorageSize;
int smem_size = FMHASplitKVKernel::SharedStorageSize;

const auto sycl_block = compat::dim3(block.x, block.y, block.z);
const auto sycl_grid = compat::dim3(grid.x, grid.y, grid.z);
@@ -289,18 +289,18 @@
sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size),
};
compat::experimental::kernel_properties kernel_props{
sycl::ext::oneapi::experimental::sub_group_size<FMHAChunkPrefillKernel::DispatchPolicy::SubgroupSize>};
sycl::ext::oneapi::experimental::sub_group_size<FMHASplitKVKernel::DispatchPolicy::SubgroupSize>};
compat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props};

sycl::ext::oneapi::experimental::launch_config config(policy.get_range(), policy.get_launch_properties());
auto cgf = [&](::sycl::handler& cgh) {
auto KernelFunctor =
compat::experimental::detail::build_kernel_functor<cutlass::device_kernel<FMHAChunkPrefillKernel>>(
compat::experimental::detail::build_kernel_functor<cutlass::device_kernel<FMHASplitKVKernel>>(
cgh, policy, params);
sycl::ext::oneapi::experimental::detail::
LaunchConfigAccess<sycl::nd_range<3>, decltype(policy.get_launch_properties())>
ConfigAccess(config);
cgh.parallel_for<KernelCur<FMHAChunkPrefillKernel>>(
cgh.parallel_for<KernelCur<FMHASplitKVKernel>>(
ConfigAccess.getRange(), ConfigAccess.getProperties(), KernelFunctor);
};
auto stream = at::xpu::getCurrentXPUStream();
@@ -311,7 +311,7 @@
cutlass::Status run(const Flash_fwd_params& params, const cutlass::KernelHardwareInfo& hw_info) {
ProblemShapeType problem_size = initialize(params);

typename FMHAChunkPrefillKernel::Arguments arguments{
typename FMHASplitKVKernel::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{// static_cast<const ElementQ*>(params.q_ptr),
@@ -337,18 +337,18 @@
hw_info};

// Define device-global scratch memory
size_t workspace_size = FMHAChunkPrefillKernel::get_workspace_size(arguments);
size_t workspace_size = FMHASplitKVKernel::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

if (!FMHAChunkPrefillKernel::can_implement(arguments)) {
if (!FMHASplitKVKernel::can_implement(arguments)) {
return cutlass::Status::kErrorInvalidProblem;
}

// Initialize the workspace
(FMHAChunkPrefillKernel::initialize_workspace(arguments, workspace.get()));
(FMHASplitKVKernel::initialize_workspace(arguments, workspace.get()));

// Convert host-side arguments to device-side arguments to be passed to the kernel
auto params_kernel = FMHAChunkPrefillKernel::to_underlying_arguments(arguments, workspace.get());
auto params_kernel = FMHASplitKVKernel::to_underlying_arguments(arguments, workspace.get());

// Run the Flash Attention implementation.
run(params_kernel);
@@ -386,7 +386,7 @@ struct FMHAConfig {

using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
using CollectiveEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillEpilogue<
using CollectiveEpilogue = cutlass::flash_attention::collective::FlashSplitKVEpilogue<
Sink,
EpilogueDispatchPolicy,
MMAOperation,
@@ -399,15 +399,15 @@
GmemTiledCopyStore,
ElementSink>;
using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::
FlashChunkPrefillSoftmaxEpilogue<Causal, LocalMask, EpilogueDispatchPolicy, ElementAccumulator>;
FlashSplitKVSoftmaxEpilogue<Causal, LocalMask, EpilogueDispatchPolicy, ElementAccumulator>;

using ProblemShapeRegular = cute::tuple<int, int, int, int, int, int, int, int>;
using namespace cutlass::fmha::collective;
using ProblemShapeVarlen = cute::tuple<int, int, int, VariableLength, VariableLength, VariableLength, int, int>;
using ProblemShapeType = std::conditional_t<isVarLen, ProblemShapeVarlen, ProblemShapeRegular>;

// Mainloop
using CollectiveMainloop = cutlass::flash_attention::collective::FlashChunkPrefillMma<
using CollectiveMainloop = cutlass::flash_attention::collective::FlashSplitKVMma<
GEMMDispatchPolicy,
ProblemShapeType,
ElementInputQ,
@@ -427,14 +427,14 @@
LocalMask,
PagedKV>;

using FMHAChunkPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefillChunk<
using FMHASplitKVKernel = cutlass::flash_attention::kernel::FMHASplitKV<
ProblemShapeType,
CollectiveMainloop,
CollectiveSoftmaxEpilogue,
CollectiveEpilogue,
Scheduler>;

KernelRunner<FMHAChunkPrefillKernel, isVarLen> runner;
KernelRunner<FMHASplitKVKernel, isVarLen> runner;

(runner.run(params, hw_info));
return 0;
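
For orientation, the host-side flow modified in the hunks above follows the usual CUTLASS 3.x static-kernel sequence: validate the arguments, size and initialize the device workspace, lower the arguments to kernel Params, then query the launch geometry. The following is a condensed sketch of that sequence, not part of this diff, assuming the same FMHASplitKVKernel type and an already-populated Arguments struct; the SYCL launch_policy details from the secondary run() overload are omitted.

#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"

// Condensed, illustrative restatement of KernelRunner::run above; not part of the PR.
template <class FMHASplitKVKernel>
cutlass::Status launch_split_kv(typename FMHASplitKVKernel::Arguments const& arguments) {
  // Reject problem shapes the kernel cannot handle before allocating anything.
  if (!FMHASplitKVKernel::can_implement(arguments)) {
    return cutlass::Status::kErrorInvalidProblem;
  }

  // Device-global scratch memory, sized by the kernel type itself.
  size_t workspace_size = FMHASplitKVKernel::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
  FMHASplitKVKernel::initialize_workspace(arguments, workspace.get());

  // Lower host-side Arguments to the device-side Params the kernel consumes.
  auto params = FMHASplitKVKernel::to_underlying_arguments(arguments, workspace.get());

  // Launch geometry and shared-memory carveout are queried from the kernel type;
  // in the diff these feed a compat::experimental::launch_policy for the SYCL launch.
  auto const block = FMHASplitKVKernel::get_block_shape();
  auto const grid = FMHASplitKVKernel::get_grid_shape(params);
  int smem_size = FMHASplitKVKernel::SharedStorageSize;
  (void)block; (void)grid; (void)smem_size;  // consumed by the actual SYCL launch

  return cutlass::Status::kSuccess;
}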
kernels/chunk_prefill/xe_chunk_prefill.hpp → kernels/flash_attention/xe_flash_attn_split_kv.hpp
@@ -34,7 +34,7 @@
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "xe_flash_attn_chunk_prefill_mma.hpp"
#include "xe_flash_attn_split_kv_mma.hpp"

namespace cutlass::flash_attention::kernel {

@@ -44,15 +44,15 @@ template <
class CollectiveSoftmaxEpilogue_,
class CollectiveEpilogue_,
class TileScheduler_ = void>
class FMHAPrefillChunk;
class FMHASplitKV;
///////////////////////////////////////////////////////////////////////////////
template <
class ProblemShape_,
class CollectiveMainloop_,
class CollectiveSoftmaxEpilogue_,
class CollectiveEpilogue_,
class TileScheduler_>
class FMHAPrefillChunk {
class FMHASplitKV {
public:
//
// Type Aliases
kernels/chunk_prefill/xe_flash_attn_chunk_prefill_epilogue.hpp → kernels/flash_attention/xe_flash_attn_split_kv_epilogue.hpp
@@ -56,7 +56,7 @@ template <
class TileShapeOutput_,
class SubgroupLayout_,
class... Args>
class FlashChunkPrefillEpilogue {
class FlashSplitKVEpilogue {
static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Could not find an epilogue specialization.");
};

@@ -71,7 +71,7 @@ template <
class ElementLSE_,
class CopyOpO_,
class ElementSink_>
class FlashChunkPrefillEpilogue<
class FlashSplitKVEpilogue<
Sink_,
epilogue::IntelXeXMX16,
MMAOperation_,
@@ -191,7 +191,7 @@ class FlashChunkPrefillEpilogue<
}

CUTLASS_HOST_DEVICE
FlashChunkPrefillEpilogue(Params const& params_, TensorStorage const&) : params(params_) {}
FlashSplitKVEpilogue(Params const& params_, TensorStorage const&) : params(params_) {}

template <
class ProblemShape,
kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp → kernels/flash_attention/xe_flash_attn_split_kv_mma.hpp
@@ -67,7 +67,7 @@ template <
bool CausalMask_,
bool LocalMask_,
bool PagedKV_>
struct FlashChunkPrefillMma {
struct FlashSplitKVMma {
static_assert(cutlass::detail::dependent_false<ElementQ_>, "Could not find a mainloop specialization.");
};

@@ -92,7 +92,7 @@ template <
bool CausalMask_,
bool LocalMask_,
bool PagedKV_>
struct FlashChunkPrefillMma<
struct FlashSplitKVMma<
gemm::MainloopIntelXeXMX16<Stages>,
ProblemShapeType_,
ElementQ_,
@@ -224,7 +224,7 @@ struct FlashChunkPrefillMma<
// Methods
//

FlashChunkPrefillMma() = default;
FlashSplitKVMma() = default;

static constexpr Params
to_underlying_arguments(ProblemShapeType const& problem_shape, Arguments const& args, void* workspace) {
kernels/chunk_prefill/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp → kernels/flash_attention/xe_flash_attn_split_kv_softmax_epilogue.hpp
@@ -51,12 +51,12 @@ namespace collective {
/////////////////////////////////////////////////////////////////////////////////////////////////

template <bool CausalMask_, bool LocalMask_, class DispatchPolicy, class... Args>
class FlashChunkPrefillSoftmaxEpilogue {
class FlashSplitKVSoftmaxEpilogue {
static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Could not find an epilogue specialization.");
};

template <bool CausalMask_, bool LocalMask_, class Element_>
class FlashChunkPrefillSoftmaxEpilogue<CausalMask_, LocalMask_, epilogue::IntelXeXMX16, Element_> {
class FlashSplitKVSoftmaxEpilogue<CausalMask_, LocalMask_, epilogue::IntelXeXMX16, Element_> {
public:
//
// Type Aliases
@@ -103,7 +103,7 @@ class FlashChunkPrefillSoftmaxEpilogue<CausalMask_, LocalMask_, epilogue::IntelX
}

CUTLASS_HOST_DEVICE
FlashChunkPrefillSoftmaxEpilogue(Params const& params_) : params(params_) {}
FlashSplitKVSoftmaxEpilogue(Params const& params_) : params(params_) {}

template <int Vec, int FragsM, int FragsN, class FragAcc, class FragMax, class FragSum>
CUTLASS_DEVICE void scale_exp_log2(FragAcc& frag_s, FragMax const& max, FragSum& sum) {