Skip to content

Commit

Permalink
Merge pull request #5084 from quic/topic/sgemm_direct_sme1
Browse files Browse the repository at this point in the history
Support for SGEMM_DIRECT Kernel based on SME1
  • Loading branch information
martin-frbg authored Feb 19, 2025
2 parents abbd78a + f66ca05 commit eb84aac
Show file tree
Hide file tree
Showing 25 changed files with 617 additions and 14 deletions.
5 changes: 5 additions & 0 deletions Makefile.arm64
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ FCOMMON_OPT += -march=armv8-a+sve
endif
endif

ifeq ($(CORE), ARMV9SME)
CCOMMON_OPT += -march=armv9-a+sve2+sme
FCOMMON_OPT += -march=armv9-a+sve2
endif

ifeq ($(CORE), CORTEXA53)
CCOMMON_OPT += -march=armv8-a -mtune=cortex-a53
ifneq ($(F_COMPILER), NAG)
Expand Down
8 changes: 8 additions & 0 deletions Makefile.system
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ ifeq ($(ARCH), arm64)
export MACOSX_DEPLOYMENT_TARGET=11.0
ifeq ($(C_COMPILER), GCC)
export NO_SVE = 1
export NO_SME = 1
endif
else
export MACOSX_DEPLOYMENT_TARGET=10.8
Expand Down Expand Up @@ -709,6 +710,9 @@ DYNAMIC_CORE += NEOVERSEN2
DYNAMIC_CORE += ARMV8SVE
DYNAMIC_CORE += A64FX
endif
ifneq ($(NO_SME), 1)
DYNAMIC_CORE += ARMV9SME
endif
DYNAMIC_CORE += THUNDERX
DYNAMIC_CORE += THUNDERX2T99
DYNAMIC_CORE += TSV110
Expand Down Expand Up @@ -1472,6 +1476,10 @@ ifeq ($(NO_SVE), 1)
CCOMMON_OPT += -DNO_SVE
endif

ifeq ($(NO_SME), 1)
CCOMMON_OPT += -DNO_SME
endif

ifdef SMP
CCOMMON_OPT += -DSMP_SERVER

Expand Down
1 change: 1 addition & 0 deletions TargetList.txt
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ THUNDERX3T110
VORTEX
A64FX
ARMV8SVE
ARMV9SME
FT2000

9.System Z:
Expand Down
19 changes: 19 additions & 0 deletions c_check
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,24 @@ if [ "$architecture" = "arm64" ]; then
rm -rf "$tmpd"
fi

no_sme=0
if [ "$architecture" = "arm64" ]; then
tmpd=$(mktemp -d 2>/dev/null || mktemp -d -t 'OBC')
tmpf="$tmpd/a.S"
printf ".text \n.global sme_test\n\nsme_test:\nsmstart\nsmstop\nret\n">> "$tmpf"
args=" -march=armv9-a+sve2+sme -c -o $tmpf.o $tmpf"
no_sme=0
{
$compiler_name $flags $args >/dev/null 2>&1
} || {
args=" -march=armv9-a+sme -c -o $tmpf.o $tmpf"
$compiler_name $flags $args >/dev/null 2>&1
} || {
no_sme=1
}
rm -rf "$tmpd"
fi

c11_atomics=0
case "$data" in
*HAVE_C11*)
Expand Down Expand Up @@ -475,6 +493,7 @@ done
printf "CEXTRALIB=%s %s %s\n" "$linker_L" "$linker_l" "$linker_a"
[ "$no_msa" -eq 1 ] && printf "NO_MSA=1\n"
[ "$no_sve" -eq 1 ] && printf "NO_SVE=1\n"
[ "$no_sme" -eq 1 ] && printf "NO_SME=1\n"
[ "$no_rv64gv" -eq 1 ] && printf "NO_RV64GV=1\n"
[ "$no_avx512" -eq 1 ] && printf "NO_AVX512=1\n"
[ "$no_avx512bf" -eq 1 ] && printf "NO_AVX512BF16=1\n"
Expand Down
18 changes: 15 additions & 3 deletions cmake/arch.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,21 @@ endif ()

if (DYNAMIC_ARCH)
if (ARM64)
set(DYNAMIC_CORE ARMV8 CORTEXA53 CORTEXA57 THUNDERX THUNDERX2T99 TSV110 EMAG8180 NEOVERSEN1 THUNDERX3T110)
if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER 9.99)
set(DYNAMIC_CORE ${DYNAMIC_CORE} NEOVERSEV1 NEOVERSEN2 ARMV8SVE A64FX)
set(DYNAMIC_CORE ARMV8 CORTEXA53 CORTEXA57 THUNDERX THUNDERX2T99 TSV110 EMAG8180 NEOVERSEN1 THUNDERX3T110)
if (${CMAKE_C_COMPILER_ID} STREQUAL "GNU")
if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER_EQUAL 10) # SVE ACLE supported in GCC >= 10
set(DYNAMIC_CORE ${DYNAMIC_CORE} NEOVERSEV1 NEOVERSEN2 ARMV8SVE A64FX)
endif ()
if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER_EQUAL 14) # SME ACLE supported in GCC >= 14
set(DYNAMIC_CORE ${DYNAMIC_CORE} ARMV9SME)
endif()
elseif (${CMAKE_C_COMPILER_ID} MATCHES "Clang")
if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER_EQUAL 11) # SVE ACLE supported in LLVM >= 11
set(DYNAMIC_CORE ${DYNAMIC_CORE} NEOVERSEV1 NEOVERSEN2 ARMV8SVE A64FX)
endif ()
if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER_EQUAL 19) # SME ACLE supported in LLVM >= 19
set(DYNAMIC_CORE ${DYNAMIC_CORE} ARMV9SME)
endif()
endif ()
if (DYNAMIC_LIST)
set(DYNAMIC_CORE ARMV8 ${DYNAMIC_LIST})
Expand Down
6 changes: 6 additions & 0 deletions cmake/cc.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ if (${CORE} STREQUAL ARMV8SVE)
endif ()
endif ()

if (${CORE} STREQUAL ARMV9SME)
if (NOT DYNAMIC_ARCH)
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv9-a+sme")
endif ()
endif ()

if (${CORE} STREQUAL CORTEXA510)
if (NOT DYNAMIC_ARCH)
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv8-a+sve")
Expand Down
2 changes: 1 addition & 1 deletion cmake/prebuild.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,7 @@ endif ()
set(ZGEMM_UNROLL_M 4)
set(ZGEMM_UNROLL_N 4)
set(SYMV_P 16)
elseif ("${TCORE}" STREQUAL "NEOVERSEN2")
elseif ("${TCORE}" STREQUAL "NEOVERSEN2" or "${TCORE}" STREQUAL "ARMV9SME")
file(APPEND ${TARGET_CONF_TEMP}
"#define L1_CODE_SIZE\t65536\n"
"#define L1_CODE_LINESIZE\t64\n"
Expand Down
3 changes: 3 additions & 0 deletions cmake/system.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ if (${TARGET} STREQUAL NEOVERSEV1)
set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -march=armv8.2-a+sve")
endif()
endif()
if (${TARGET} STREQUAL ARMV9SME)
set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -march=armv9-a+sme -O3")
endif()
if (${TARGET} STREQUAL A64FX)
if (${CMAKE_C_COMPILER_ID} STREQUAL "PGI" AND NOT NO_SVE)
set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -Msve-intrinsics -march=armv8.2-a+sve -mtune=a64fx")
Expand Down
11 changes: 11 additions & 0 deletions cmake/system_check.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,17 @@ endif()
endif()
endif()

if (ARM64)
if (NOT NO_SME)
file(WRITE ${PROJECT_BINARY_DIR}/sme.c ".text \n.global sme_test\n\nsme_test:\nsmstart\nsmstop\nret\n")
execute_process(COMMAND ${CMAKE_C_COMPILER} -march=armv9-a+sve2+sme -c -v -o ${PROJECT_BINARY_DIR}/sme.o ${PROJECT_BINARY_DIR}/sme.c OUTPUT_QUIET ERROR_QUIET RESULT_VARIABLE NO_SME)
if (NO_SME EQUAL 1)
set (CCOMMON_OPT "${CCOMMON_OPT} -DNO_SME")
endif()
file(REMOVE "${PROJECT_BINARY_DIR}/sme.c" "${PROJECT_BINARY_DIR}/sme.o")
endif()
endif()

include(CheckIncludeFile)
CHECK_INCLUDE_FILE("stdatomic.h" HAVE_C11)
if (HAVE_C11 EQUAL 1)
Expand Down
1 change: 1 addition & 0 deletions common.h
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,7 @@ void gotoblas_profile_init(void);
void gotoblas_profile_quit(void);

int support_avx512(void);
int support_sme1(void);

#ifdef USE_OPENMP

Expand Down
2 changes: 1 addition & 1 deletion common_arm64.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ static inline int blas_quickdivide(blasint x, blasint y){
#define HUGE_PAGESIZE ( 4 << 20)

#ifndef BUFFERSIZE
#if defined(NEOVERSEN1) || defined(NEOVERSEN2) || defined(NEOVERSEV1) || defined(A64FX) || defined(ARMV8SVE)
#if defined(NEOVERSEN1) || defined(NEOVERSEN2) || defined(NEOVERSEV1) || defined(A64FX) || defined(ARMV8SVE) || defined(ARMV9SME)
#define BUFFER_SIZE (32 << 22)
#else
#define BUFFER_SIZE (32 << 20)
Expand Down
6 changes: 6 additions & 0 deletions common_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,12 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG);
int (*sgemm_direct_performant) (BLASLONG M, BLASLONG N, BLASLONG K);
#endif
#ifdef ARCH_ARM64
#ifdef HAVE_SME
void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG);
#endif
#endif


int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
Expand Down
4 changes: 2 additions & 2 deletions common_s.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@
#ifdef ARCH_X86_64
#define SGEMM_DIRECT_PERFORMANT gotoblas -> sgemm_direct_performant
#define SGEMM_DIRECT gotoblas -> sgemm_direct
#else
#elif ARCH_ARM64
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
#define SGEMM_DIRECT sgemm_direct
#define SGEMM_DIRECT gotoblas -> sgemm_direct
#endif

#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy
Expand Down
34 changes: 34 additions & 0 deletions driver/others/dynamic_arm64.c
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ extern gotoblas_t gotoblas_ARMV8SVE;
#else
#define gotoblas_ARMV8SVE gotoblas_ARMV8
#endif
#ifdef DYN_ARMV9SME
extern gotoblas_t gotoblas_ARMV9SME;
#else
#define gotoblas_ARMV9SME gotoblas_ARMV8
#endif
#ifdef DYN_CORTEX_A55
extern gotoblas_t gotoblas_CORTEXA55;
#else
Expand Down Expand Up @@ -148,6 +153,13 @@ extern gotoblas_t gotoblas_A64FX;
#define gotoblas_ARMV8SVE gotoblas_ARMV8
#define gotoblas_A64FX gotoblas_ARMV8
#endif

#ifndef NO_SME
extern gotoblas_t gotoblas_ARMV9SME;
#else
#define gotoblas_ARMV9SME gotoblas_ARMV8SVE
#endif

extern gotoblas_t gotoblas_THUNDERX3T110;
#endif
#define gotoblas_NEOVERSEV2 gotoblas_NEOVERSEV1
Expand All @@ -168,6 +180,9 @@ extern void openblas_warning(int verbose, const char * msg);
#ifndef HWCAP_SVE
#define HWCAP_SVE (1 << 22)
#endif
#ifndef HWCAP2_SME
#define HWCAP2_SME 1<<23
#endif

#define get_cpu_ftr(id, var) ({ \
__asm__ __volatile__ ("mrs %0, "#id : "=r" (var)); \
Expand Down Expand Up @@ -430,6 +445,13 @@ static gotoblas_t *get_coretype(void) {
snprintf(coremsg, 128, "Unknown CPU model - implementer %x part %x\n",implementer,part);
openblas_warning(1, coremsg);
}

#if !defined(NO_SME) && defined(HWCAP2_SME)
if ((getauxval(AT_HWCAP2) & HWCAP2_SME)) {
return &gotoblas_ARMV9SME;
}
#endif

#ifndef NO_SVE
if ((getauxval(AT_HWCAP) & HWCAP_SVE)) {
return &gotoblas_ARMV8SVE;
Expand Down Expand Up @@ -480,3 +502,15 @@ void gotoblas_dynamic_init(void) {
void gotoblas_dynamic_quit(void) {
gotoblas = NULL;
}

int support_sme1(void) {
int ret = 0;

#if (defined OS_LINUX || defined OS_ANDROID)
ret = getauxval(AT_HWCAP2) & HWCAP2_SME;
if(getauxval(AT_HWCAP2) & HWCAP2_SME){
ret = 1;
}
#endif
return ret;
}
13 changes: 13 additions & 0 deletions getarch.c
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,19 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define CORENAME "ARMV8SVE"
#endif

#ifdef FORCE_ARMV9SME
#define FORCE
#define ARCHITECTURE "ARM64"
#define SUBARCHITECTURE "ARMV9SME"
#define SUBDIRNAME "arm64"
#define ARCHCONFIG "-DARMV9SME " \
"-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \
"-DL2_SIZE=262144 -DL2_LINESIZE=64 " \
"-DDTB_DEFAULT_ENTRIES=64 -DDTB_SIZE=4096 -DL2_ASSOCIATIVE=32 " \
"-DHAVE_VFPV4 -DHAVE_VFPV3 -DHAVE_VFP -DHAVE_NEON -DHAVE_SVE -DHAVE_SME -DARMV8 -DARMV9"
#define LIBNAME "armv9sme"
#define CORENAME "ARMV9SME"
#endif

#ifdef FORCE_ARMV8
#define FORCE
Expand Down
13 changes: 10 additions & 3 deletions interface/gemm.c
Original file line number Diff line number Diff line change
Expand Up @@ -393,14 +393,21 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
PRINT_DEBUG_CNAME;

#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && defined(USE_SGEMM_KERNEL_DIRECT)
#ifdef DYNAMIC_ARCH
#if defined(DYNAMIC_ARCH) && defined(ARCH_x86)
if (support_avx512() )
#endif
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) {
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
return;
}

#endif
#if defined(DYNAMIC_ARCH) && defined(ARCH_ARM64)
if (support_sme1()){
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) {
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
return;
}
}
#endif
#endif

#ifndef COMPLEX
Expand Down
12 changes: 10 additions & 2 deletions kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -207,19 +207,27 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
if (ZARCH OR (UC_TARGET_CORE MATCHES POWER8) OR (UC_TARGET_CORE MATCHES POWER9) OR (UC_TARGET_CORE MATCHES POWER10))
set(USE_TRMM true)
endif ()

set(USE_DIRECT_SGEMM false)
if (X86_64)
if (X86_64 OR (ARM64 AND (UC_TARGET_CORE MATCHES ARMV9SME)))
set(USE_DIRECT_SGEMM true)
endif()

if (USE_DIRECT_SGEMM)
# if (NOT DEFINED SGEMMDIRECTKERNEL)
if (X86_64)
set (SGEMMDIRECTKERNEL sgemm_direct_skylakex.c)
set (SGEMMDIRECTPERFORMANT sgemm_direct_performant.c)
# endif()
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE)
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPERFORMANT}" "" "gemm_direct_performant" false "" "" false SINGLE)
elseif (ARM64)
set (SGEMMDIRECTKERNEL sgemm_direct_arm64_sme1.c)
set (SGEMMDIRECTSMEKERNEL sgemm_direct_sme1.S)
set (SGEMMDIRECTPREKERNEL sgemm_direct_sme1_preprocess.S)
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE)
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTSMEKERNEL}" "" "gemm_direct_sme1" false "" "" false SINGLE)
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPREKERNEL}" "" "gemm_direct_sme1_preprocess" false "" "" false SINGLE)
endif ()
endif()

foreach (float_type SINGLE DOUBLE)
Expand Down
4 changes: 4 additions & 0 deletions kernel/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ ifdef NO_AVX2
AVX2OPT=
endif


ifdef TARGET_CORE
ifeq ($(TARGET_CORE), ARMV9SME)
override CFLAGS += -DBUILD_KERNEL -DTABLE_NAME=gotoblas_$(TARGET_CORE) -DHAVE_SME -march=armv9-a+sve2+sme
endif
ifeq ($(TARGET_CORE), SAPPHIRERAPIDS)
override CFLAGS += -DBUILD_KERNEL -DTABLE_NAME=gotoblas_$(TARGET_CORE)
ifeq (1, $(filter 1,$(GCCVERSIONGTEQ11) $(CLANGVERSIONGTEQ12)))
Expand Down
Loading

0 comments on commit eb84aac

Please sign in to comment.