Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
{CapabilityCooperativeMatrixKHR});
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixBFloat16ComponentTypeINTEL,
{CapabilityCooperativeMatrixKHR});
ADD_VEC_INIT(internal::CapabilityPackedCooperativeMatrixINTEL,
{CapabilityCooperativeMatrixKHR});
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixPrefetchINTEL,
{CapabilityCooperativeMatrixKHR});
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixInvocationInstructionsINTEL,
Expand Down
90 changes: 66 additions & 24 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3691,27 +3691,6 @@ _SPIRV_OP(ConvertFToBF16INTEL)
_SPIRV_OP(ConvertBF16ToFINTEL)
#undef _SPIRV_OP

class SPIRVJointMatrixINTELInstBase : public SPIRVInstTemplateBase {
protected:
std::optional<ExtensionID> getRequiredExtension() const override {
return ExtensionID::SPV_INTEL_joint_matrix;
}
};

/// Joint-matrix instruction base whose required capabilities are the vector
/// that getVec() returns for JointMatrixWIInstructionsINTEL.
class SPIRVJointMatrixINTELWorkItemInst : public SPIRVJointMatrixINTELInstBase {
protected:
  /// Returns getVec(internal::CapabilityJointMatrixWIInstructionsINTEL).
  SPIRVCapVec getRequiredCapability() const override {
    SPIRVCapVec Caps =
        getVec(internal::CapabilityJointMatrixWIInstructionsINTEL);
    return Caps;
  }
};

// Defines the type SPIRV<x>INTEL as a SPIRVInstTemplate instantiation built on
// SPIRVJointMatrixINTELWorkItemInst for the opcode internal::Op<x>INTEL. Any
// arguments after `x` are forwarded unchanged as further template arguments.
#define _SPIRV_OP(x, ...) \
typedef SPIRVInstTemplate<SPIRVJointMatrixINTELWorkItemInst, \
internal::Op##x##INTEL, __VA_ARGS__> \
SPIRV##x##INTEL;
// Expands to the typedef SPIRVJointMatrixGetElementCoordINTEL.
// NOTE(review): the meaning of `true, 5` is defined by SPIRVInstTemplate, which
// is not visible here - confirm against its declaration.
_SPIRV_OP(JointMatrixGetElementCoord, true, 5)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To double check, what instruction is used in SYCL headers right now?

Copy link
Contributor Author

@vmaksimo vmaksimo Dec 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, then I'd suggest first replacing the joint matrix version with the cooperative matrix version in the headers and seeing if anything fails.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created a draft PR: intel/llvm#20855

#undef _SPIRV_OP

class SPIRVCooperativeMatrixPrefetchINTELInstBase
: public SPIRVInstTemplateBase {
protected:
Expand All @@ -3737,8 +3716,28 @@ class SPIRVCooperativeMatrixCheckedInstructionsINTELInstBase
return ExtensionID::SPV_INTEL_joint_matrix;
}
// Reports the capabilities this checked cooperative-matrix instruction needs.
// NOTE(review): this diff view shows the replaced statement and its
// replacement together; the first `return getVec(...)` below looks like the
// pre-change body that the statements after it supersede - confirm against
// the actual patch.
SPIRVCapVec getRequiredCapability() const override {
return getVec(
internal::CapabilityCooperativeMatrixCheckedInstructionsINTEL);
// Start from the vector registered for the checked-instructions capability.
auto CV =
getVec(internal::CapabilityCooperativeMatrixCheckedInstructionsINTEL);
// getMemoryLayout() yields nullptr for opcodes other than the checked
// load/store, in which case no layout-dependent capability is added.
if (SPIRVValue *LayoutVal = getMemoryLayout()) {
// The layout value is only inspected when its opcode is a constant opcode.
// NOTE(review): the static_cast assumes every opcode accepted by
// isConstantOpCode is backed by a SPIRVConstant object - confirm.
if (isConstantOpCode(LayoutVal->getOpCode())) {
uint64_t Layout =
static_cast<SPIRVConstant *>(LayoutVal)->getZExtIntValue();
// The packed layout additionally requires PackedCooperativeMatrixINTEL.
if (Layout == internal::CooperativeMatrixLayoutPackedINTEL)
CV.push_back(internal::CapabilityPackedCooperativeMatrixINTEL);
}
}
return CV;
}
/// Returns the memory-layout operand of a checked cooperative-matrix load
/// (operand 3) or store (operand 4); returns nullptr for any other opcode.
SPIRVValue *getMemoryLayout() const {
  using SelfTy = SPIRVCooperativeMatrixCheckedInstructionsINTELInstBase;
  const bool IsLoad = OpCode == internal::OpCooperativeMatrixLoadCheckedINTEL;
  const bool IsStore =
      OpCode == internal::OpCooperativeMatrixStoreCheckedINTEL;
  if (!IsLoad && !IsStore)
    return nullptr;
  // Constness is cast away only to reach getOperand(), as the original did.
  return const_cast<SelfTy *>(this)->getOperand(IsLoad ? 3 : 4);
}
};

Expand Down Expand Up @@ -3790,6 +3789,7 @@ class SPIRVCooperativeMatrixInvocationInstructionsINTELInstBase
internal::Op##x##INTEL, __VA_ARGS__> \
SPIRV##x##INTEL;
_SPIRV_OP(CooperativeMatrixApplyFunction, true, 5)
_SPIRV_OP(CooperativeMatrixGetElementCoord, true, 5)
#undef _SPIRV_OP

class SPIRVCooperativeMatrixKHRInstBase : public SPIRVInstTemplateBase {
Expand All @@ -3798,7 +3798,49 @@ class SPIRVCooperativeMatrixKHRInstBase : public SPIRVInstTemplateBase {
return ExtensionID::SPV_KHR_cooperative_matrix;
}
// Reports the capabilities this KHR cooperative-matrix instruction needs:
// CooperativeMatrixKHR, plus INTEL capabilities selected by a packed memory
// layout or by the cooperative matrix operands literal of a MulAdd.
// NOTE(review): this diff view shows the replaced statement and its
// replacement together; the first `return getVec(...)` below looks like the
// pre-change body that the statements after it supersede - confirm against
// the actual patch.
SPIRVCapVec getRequiredCapability() const override {
return getVec(CapabilityCooperativeMatrixKHR);
auto CV = getVec(CapabilityCooperativeMatrixKHR);
// getMemoryLayout() yields nullptr for opcodes other than load/store, in
// which case no layout-dependent capability is added.
if (SPIRVValue *LayoutVal = getMemoryLayout()) {
// The layout value is only inspected when its opcode is a constant opcode.
// NOTE(review): the static_cast assumes every opcode accepted by
// isConstantOpCode is backed by a SPIRVConstant object - confirm.
if (isConstantOpCode(LayoutVal->getOpCode())) {
uint64_t Layout =
static_cast<SPIRVConstant *>(LayoutVal)->getZExtIntValue();
// The packed layout additionally requires PackedCooperativeMatrixINTEL.
if (Layout == internal::CooperativeMatrixLayoutPackedINTEL)
CV.push_back(internal::CapabilityPackedCooperativeMatrixINTEL);
}
}
// Only a MulAdd with exactly four entries in Ops is examined; Ops[3] is then
// read as the operands literal.
if (OpCode == OpCooperativeMatrixMulAddKHR && Ops.size() == 4) {
// If Cooperative Matrix Operand literal is present, check for the
// additional capabilities it may require.
uint64_t CoopOperands = Ops[3];
// TF32 bit: add the TF32 component-type capability. Note the side effect:
// SPV_INTEL_joint_matrix is also registered on the module from this getter.
if (CoopOperands &
internal::
CooperativeMatrixOperandsMatrixAAndBTF32ComponentsINTELMask) {
CV.push_back(
internal::CapabilityCooperativeMatrixTF32ComponentTypeINTEL);
Module->addExtension(ExtensionID::SPV_INTEL_joint_matrix);
}
// Any of the three BFloat16 bits (A/B, C, or result): add the BFloat16
// component-type capability and register the extension the same way.
if (CoopOperands &
(internal::
CooperativeMatrixOperandsMatrixAAndBBFloat16ComponentsINTELMask |
internal::
CooperativeMatrixOperandsMatrixCBFloat16ComponentsINTELMask |
internal::
CooperativeMatrixOperandsMatrixResultBFloat16ComponentsINTELMask)) {
CV.push_back(
internal::CapabilityCooperativeMatrixBFloat16ComponentTypeINTEL);
Module->addExtension(ExtensionID::SPV_INTEL_joint_matrix);
}
}
return CV;
}

/// Returns the memory-layout operand of OpCooperativeMatrixLoadKHR (operand 1)
/// or OpCooperativeMatrixStoreKHR (operand 2); returns nullptr for any other
/// opcode.
SPIRVValue *getMemoryLayout() const {
  // Constness is cast away only to reach getOperand(), as the original did.
  auto *Self = const_cast<SPIRVCooperativeMatrixKHRInstBase *>(this);
  if (OpCode == OpCooperativeMatrixLoadKHR)
    return Self->getOperand(1);
  if (OpCode == OpCooperativeMatrixStoreKHR)
    return Self->getOperand(2);
  return nullptr;
}
};

Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,8 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
"CooperativeMatrixTF32ComponentTypeINTEL");
add(internal::CapabilityCooperativeMatrixBFloat16ComponentTypeINTEL,
"CooperativeMatrixBFloat16ComponentTypeINTEL");
add(internal::CapabilityPackedCooperativeMatrixINTEL,
"PackedCooperativeMatrixINTEL");
add(internal::CapabilityCooperativeMatrixPrefetchINTEL,
"CooperativeMatrixPrefetchINTEL");
add(internal::CapabilityCooperativeMatrixInvocationInstructionsINTEL,
Expand Down
4 changes: 2 additions & 2 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

_SPIRV_OP_INTERNAL(Forward, internal::OpForward)
_SPIRV_OP_INTERNAL(TypeTokenINTEL, internal::OpTypeTokenINTEL)
_SPIRV_OP_INTERNAL(JointMatrixGetElementCoordINTEL,
internal::OpJointMatrixGetElementCoordINTEL)
_SPIRV_OP_INTERNAL(CooperativeMatrixGetElementCoordINTEL,
internal::OpCooperativeMatrixGetElementCoordINTEL)
_SPIRV_OP_INTERNAL(CooperativeMatrixPrefetchINTEL,
internal::OpCooperativeMatrixPrefetchINTEL)
_SPIRV_OP_INTERNAL(CooperativeMatrixLoadCheckedINTEL,
Expand Down
21 changes: 14 additions & 7 deletions lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ enum InternalOp {
IOpClampConvertFToSINTEL = 6424,
IOpMaskedGatherINTEL = 6428,
IOpMaskedScatterINTEL = 6429,
IOpJointMatrixGetElementCoordINTEL = 6440,
IOpCooperativeMatrixGetElementCoordINTEL = 6440,
IOpCooperativeMatrixApplyFunctionINTEL = 6448,
IOpCooperativeMatrixPrefetchINTEL = 6449,
IOpConvertHandleToImageINTEL = 6529,
Expand Down Expand Up @@ -118,6 +118,7 @@ enum InternalCapability {
ICapabilityAtomicBFloat16LoadStoreINTEL = 6262,
ICapabilityCooperativeMatrixPrefetchINTEL = 6411,
ICapabilityMaskedGatherScatterINTEL = 6427,
ICapabilityPackedCooperativeMatrixINTEL = 6434,
ICapabilityJointMatrixWIInstructionsINTEL = 6435,
ICapabilityCooperativeMatrixInvocationInstructionsINTEL = 6435,
ICapabilityCooperativeMatrixTF32ComponentTypeINTEL = 6436,
Expand All @@ -133,11 +134,16 @@ enum InternalExecutionMode {
constexpr LinkageType LinkageTypeInternal =
static_cast<LinkageType>(ILTInternal);

enum InternalJointMatrixLayout {
RowMajor = 0,
ColumnMajor = 1,
PackedA = 2,
PackedB = 3
// Cooperative Matrix Layout value added by the SPV_INTEL_joint_matrix
// extension.
enum InternalCooperativeMatrixLayout {
  CooperativeMatrixLayoutPackedINTEL = 2
};

// Cooperative Matrix Operands bits added by the SPV_INTEL_joint_matrix
// extension. Each enumerator is a single bit flag selecting how matrix
// components are interpreted.
enum InternalCooperativeMatrixOperandsMask {
  CooperativeMatrixOperandsMatrixAAndBTF32ComponentsINTELMask = 1 << 5, // 0x20
  CooperativeMatrixOperandsMatrixAAndBBFloat16ComponentsINTELMask =
      1 << 6, // 0x40
  CooperativeMatrixOperandsMatrixCBFloat16ComponentsINTELMask = 1 << 7, // 0x80
  CooperativeMatrixOperandsMatrixResultBFloat16ComponentsINTELMask =
      1 << 8 // 0x100
};

enum InternalFPEncoding {
Expand All @@ -154,7 +160,8 @@ enum InternalBuiltIn {
_SPIRV_OP(Capability, JointMatrixWIInstructionsINTEL)
_SPIRV_OP(Capability, CooperativeMatrixTF32ComponentTypeINTEL)
_SPIRV_OP(Capability, CooperativeMatrixBFloat16ComponentTypeINTEL)
_SPIRV_OP(Op, JointMatrixGetElementCoordINTEL)
_SPIRV_OP(Capability, PackedCooperativeMatrixINTEL)
_SPIRV_OP(Op, CooperativeMatrixGetElementCoordINTEL)

_SPIRV_OP(Capability, CooperativeMatrixPrefetchINTEL)
_SPIRV_OP(Op, CooperativeMatrixPrefetchINTEL)
Expand Down
45 changes: 45 additions & 0 deletions test/extensions/INTEL/SPV_INTEL_joint_matrix/bf16.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
; RUN: llvm-as < %s -o %t.bc
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix -o %t.spv
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have a spirv-val check? Or is it not prepared yet?


; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM

; RUN: not llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix 2>&1 \
; RUN: | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension
; CHECK-ERROR-NEXT: SPV_INTEL_joint_matrix

; CHECK-SPIRV-DAG: Capability CooperativeMatrixKHR
; CHECK-SPIRV-DAG: Extension "SPV_KHR_cooperative_matrix"
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix"
; CHECK-SPIRV-DAG: Capability CooperativeMatrixBFloat16ComponentTypeINTEL
; 64 stands for MatrixAAndBBFloat16ComponentsINTEL (0x40)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also test 0x80 and 0x100? I don't see them tested now.

; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#]] [[#]] [[#]] [[#]] [[#]] 64

; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHRPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_48_0PU3AS145__spirv_CooperativeMatrixKHR__short_2_48_12_1PU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_2i({{.*}}, i32 64)

target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "spir64-unknown-unknown"

; Kernel under test: builds an accumulator, loads two i16 cooperative matrices,
; multiplies them with the operands literal 64, and stores the float result.
define spir_kernel void @matrix_multiply(ptr addrspace(1) noundef align 1 %_arg_accA, ptr addrspace(1) noundef align 1 %_arg_accB, ptr addrspace(1) noundef align 1 %_arg_accC, i64 noundef %_arg_N, i64 noundef %_arg_K, i32 noundef %_arg_Initvalue) {
entry:
; Accumulator matrix constructed from the constant 0.0.
%matrixC = tail call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(float 0.0)
; Load A from %_arg_accA; the second argument (0) is the memory layout.
%matrixA = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(1) noundef %_arg_accA, i32 noundef 0, i64 noundef %_arg_K, i32 noundef 1)
; Load B from %_arg_accB; the second argument (1) is the memory layout.
%matrixB = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 2, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(1) noundef %_arg_accB, i32 noundef 1, i64 noundef %_arg_K)
; Multiply-add of A, B and the accumulator; the trailing 64 (0x40) is the
; cooperative matrix operands literal, MatrixAAndBBFloat16ComponentsINTEL.
%res = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 48, 0) noundef %matrixA, target("spirv.CooperativeMatrixKHR", i16, 2, 48, 12, 1) noundef %matrixB, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) noundef %matrixC, i32 noundef 64)
; Store the result to %_arg_accC; the third argument (0) is the memory layout.
tail call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(1) noundef %_arg_accC, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) noundef %res, i32 noundef 0, i64 noundef %_arg_N, i32 noundef 1)
ret void
}

declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(float noundef)

declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 48, 0) noundef, target("spirv.CooperativeMatrixKHR", i16, 2, 48, 12, 1) noundef, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) noundef, i32 noundef)

declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(1), i32, i64, i32)

declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i16, 2, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(1), i32, i64)

declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(1), target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2), i32, i64, i32)
Loading
Loading