diff --git a/include/ShaderWriter/BaseTypes/ReturnWrapper.inl b/include/ShaderWriter/BaseTypes/ReturnWrapper.inl index b7933b83..8e803735 100644 --- a/include/ShaderWriter/BaseTypes/ReturnWrapper.inl +++ b/include/ShaderWriter/BaseTypes/ReturnWrapper.inl @@ -97,6 +97,7 @@ namespace sdw template< typename T > ReturnWrapperT< ValueT > & ReturnWrapperT< ValueT >::operator=( T const & rhs ) { + static_assert( areCompatible< ValueT, T > ); auto & writer = *this->getWriter(); this->m_expr = sdw::makeAssign( this->getType() , makeExpr( writer, *this ) diff --git a/source/CompilerHlsl/HlslAdaptStatements.cpp b/source/CompilerHlsl/HlslAdaptStatements.cpp index da9d3db5..60953f5c 100644 --- a/source/CompilerHlsl/HlslAdaptStatements.cpp +++ b/source/CompilerHlsl/HlslAdaptStatements.cpp @@ -472,7 +472,8 @@ namespace hlsl , m_exprCache.makeIdentifier( m_typesCache , ast::var::makeVariable( m_adaptationData.getNextVarId() , m_typesCache.getArray( ssboVar->getType(), 1u ) - , ssboVar->getName() ) ) + , ssboVar->getName() + , uint64_t( ast::var::Flag::eUniform ) ) ) , m_exprCache.makeLiteral( m_typesCache, 0 ) ) , mbrIndex , uint64_t( ast::var::Flag::eUniform ) ) ); diff --git a/source/CompilerHlsl/HlslGenerateStatements.cpp b/source/CompilerHlsl/HlslGenerateStatements.cpp index f0b74536..f4b95dcb 100644 --- a/source/CompilerHlsl/HlslGenerateStatements.cpp +++ b/source/CompilerHlsl/HlslGenerateStatements.cpp @@ -1282,19 +1282,20 @@ namespace hlsl void visitCompoundStmt( ast::stmt::Compound const * stmt )override { doAppendLineEnd(); - m_result += "\n" + m_indent + "{\n"; + m_result += "\n"; + if ( !m_allowSingleLineCompound || stmt->size() > 1u ) + m_result += m_indent + "{\n"; auto save = m_indent; m_indent += "\t"; visitContainerStmt( stmt ); m_indent = save; - if ( m_appendSemiColon ) + if ( !m_allowSingleLineCompound || stmt->size() > 1u ) { - m_result += m_indent + "};\n"; - } - else - { - m_result += m_indent + "}\n"; + if ( m_appendSemiColon ) + m_result += m_indent + 
"};\n"; + else + m_result += m_indent + "}\n"; } } @@ -1304,7 +1305,10 @@ namespace hlsl doAppendLineEnd(); m_result += m_indent + "do"; m_appendSemiColon = false; + auto save = m_allowSingleLineCompound; + m_allowSingleLineCompound = true; visitCompoundStmt( stmt ); + m_allowSingleLineCompound = save; m_result += m_indent + "while (" + doSubmit( *stmt->getCtrlExpr() ) + ");\n"; m_appendLineEnd = true; } @@ -1314,7 +1318,10 @@ namespace hlsl m_result += m_indent + "else if (" + doSubmit( *stmt->getCtrlExpr() ) + ")"; m_appendSemiColon = false; m_appendLineEnd = false; + auto save = m_allowSingleLineCompound; + m_allowSingleLineCompound = true; visitCompoundStmt( stmt ); + m_allowSingleLineCompound = save; m_appendLineEnd = true; } @@ -1323,7 +1330,10 @@ namespace hlsl m_result += m_indent + "else"; m_appendSemiColon = false; m_appendLineEnd = false; + auto save = m_allowSingleLineCompound; + m_allowSingleLineCompound = true; visitCompoundStmt( stmt ); + m_allowSingleLineCompound = save; m_appendLineEnd = true; } @@ -1334,9 +1344,12 @@ namespace hlsl m_result += m_indent + "for (" + doSubmit( *stmt->getInitExpr() ) + "; "; m_result += doSubmit( *stmt->getCtrlExpr() ) + "; "; m_result += doSubmit( *stmt->getIncrExpr() ) + ")"; + auto save = m_allowSingleLineCompound; + m_allowSingleLineCompound = true; m_appendSemiColon = false; visitCompoundStmt( stmt ); m_appendLineEnd = true; + m_allowSingleLineCompound = save; } void visitFragmentLayoutStmt( ast::stmt::FragmentLayout const * stmt )override @@ -1626,7 +1639,10 @@ namespace hlsl doAppendLineEnd(); m_result += m_indent + "if (" + doSubmit( *stmt->getCtrlExpr() ) + ")"; m_appendSemiColon = false; + auto save = m_allowSingleLineCompound; + m_allowSingleLineCompound = true; visitCompoundStmt( stmt ); + m_allowSingleLineCompound = save; for ( auto & elseIf : stmt->getElseIfList() ) { @@ -1882,7 +1898,10 @@ namespace hlsl doAppendLineEnd(); m_result += m_indent + "while (" + doSubmit( *stmt->getCtrlExpr() ) + ")"; 
m_appendSemiColon = false; + auto save = m_allowSingleLineCompound; + m_allowSingleLineCompound = true; visitCompoundStmt( stmt ); + m_allowSingleLineCompound = save; m_appendLineEnd = true; } @@ -1928,6 +1947,7 @@ namespace hlsl std::string & m_result; bool m_appendSemiColon{ false }; bool m_appendLineEnd{ false }; + bool m_allowSingleLineCompound{ false }; }; } diff --git a/source/CompilerHlsl/HlslIntrinsicConfig.hpp b/source/CompilerHlsl/HlslIntrinsicConfig.hpp index 385751b1..a64c312b 100644 --- a/source/CompilerHlsl/HlslIntrinsicConfig.hpp +++ b/source/CompilerHlsl/HlslIntrinsicConfig.hpp @@ -636,6 +636,28 @@ namespace hlsl case ast::expr::Intrinsic::eSubgroupBroadcastFirst3D: case ast::expr::Intrinsic::eSubgroupBroadcastFirst4D: case ast::expr::Intrinsic::eSubgroupBallot: + case ast::expr::Intrinsic::eSubgroupBallotBitCount: + case ast::expr::Intrinsic::eSubgroupBallotExclusiveBitCount: + case ast::expr::Intrinsic::eSubgroupShuffle1F: + case ast::expr::Intrinsic::eSubgroupShuffle2F: + case ast::expr::Intrinsic::eSubgroupShuffle3F: + case ast::expr::Intrinsic::eSubgroupShuffle4F: + case ast::expr::Intrinsic::eSubgroupShuffle1I: + case ast::expr::Intrinsic::eSubgroupShuffle2I: + case ast::expr::Intrinsic::eSubgroupShuffle3I: + case ast::expr::Intrinsic::eSubgroupShuffle4I: + case ast::expr::Intrinsic::eSubgroupShuffle1U: + case ast::expr::Intrinsic::eSubgroupShuffle2U: + case ast::expr::Intrinsic::eSubgroupShuffle3U: + case ast::expr::Intrinsic::eSubgroupShuffle4U: + case ast::expr::Intrinsic::eSubgroupShuffle1B: + case ast::expr::Intrinsic::eSubgroupShuffle2B: + case ast::expr::Intrinsic::eSubgroupShuffle3B: + case ast::expr::Intrinsic::eSubgroupShuffle4B: + case ast::expr::Intrinsic::eSubgroupShuffle1D: + case ast::expr::Intrinsic::eSubgroupShuffle2D: + case ast::expr::Intrinsic::eSubgroupShuffle3D: + case ast::expr::Intrinsic::eSubgroupShuffle4D: case ast::expr::Intrinsic::eSubgroupAdd1F: case ast::expr::Intrinsic::eSubgroupAdd2F: case 
ast::expr::Intrinsic::eSubgroupAdd3F: @@ -971,6 +993,66 @@ namespace hlsl case ast::expr::Intrinsic::eSubgroupQuadBroadcast2D: case ast::expr::Intrinsic::eSubgroupQuadBroadcast3D: case ast::expr::Intrinsic::eSubgroupQuadBroadcast4D: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal1F: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal2F: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal3F: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal4F: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal1I: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal2I: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal3I: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal4I: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal1U: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal2U: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal3U: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal4U: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal1B: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal2B: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal3B: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal4B: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal1D: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal2D: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal3D: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal4D: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical1F: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical2F: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical3F: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical4F: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical1I: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical2I: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical3I: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical4I: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical1U: + case 
ast::expr::Intrinsic::eSubgroupQuadSwapVertical2U: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical3U: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical4U: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical1B: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical2B: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical3B: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical4B: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical1D: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical2D: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical3D: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical4D: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal1F: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal2F: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal3F: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal4F: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal1I: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal2I: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal3I: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal4I: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal1U: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal2U: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal3U: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal4U: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal1B: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal2B: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal3B: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal4B: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal1D: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal2D: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal3D: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal4D: config.requiresWaveOps = true; break; @@ -1023,32 +1105,10 @@ namespace hlsl case ast::expr::Intrinsic::eAtomicExchange2H: case ast::expr::Intrinsic::eAtomicExchange4H: case 
ast::expr::Intrinsic::eSubgroupInverseBallot: - case ast::expr::Intrinsic::eSubgroupBallotBitCount: case ast::expr::Intrinsic::eSubgroupBallotBitExtract: - case ast::expr::Intrinsic::eSubgroupBallotExclusiveBitCount: case ast::expr::Intrinsic::eSubgroupBallotInclusiveBitCount: case ast::expr::Intrinsic::eSubgroupBallotFindLSB: case ast::expr::Intrinsic::eSubgroupBallotFindMSB: - case ast::expr::Intrinsic::eSubgroupShuffle1F: - case ast::expr::Intrinsic::eSubgroupShuffle2F: - case ast::expr::Intrinsic::eSubgroupShuffle3F: - case ast::expr::Intrinsic::eSubgroupShuffle4F: - case ast::expr::Intrinsic::eSubgroupShuffle1I: - case ast::expr::Intrinsic::eSubgroupShuffle2I: - case ast::expr::Intrinsic::eSubgroupShuffle3I: - case ast::expr::Intrinsic::eSubgroupShuffle4I: - case ast::expr::Intrinsic::eSubgroupShuffle1U: - case ast::expr::Intrinsic::eSubgroupShuffle2U: - case ast::expr::Intrinsic::eSubgroupShuffle3U: - case ast::expr::Intrinsic::eSubgroupShuffle4U: - case ast::expr::Intrinsic::eSubgroupShuffle1B: - case ast::expr::Intrinsic::eSubgroupShuffle2B: - case ast::expr::Intrinsic::eSubgroupShuffle3B: - case ast::expr::Intrinsic::eSubgroupShuffle4B: - case ast::expr::Intrinsic::eSubgroupShuffle1D: - case ast::expr::Intrinsic::eSubgroupShuffle2D: - case ast::expr::Intrinsic::eSubgroupShuffle3D: - case ast::expr::Intrinsic::eSubgroupShuffle4D: case ast::expr::Intrinsic::eSubgroupShuffleXor1F: case ast::expr::Intrinsic::eSubgroupShuffleXor2F: case ast::expr::Intrinsic::eSubgroupShuffleXor3F: diff --git a/source/CompilerHlsl/HlslIntrinsicNames.hpp b/source/CompilerHlsl/HlslIntrinsicNames.hpp index 79a28e3f..9c9d3576 100644 --- a/source/CompilerHlsl/HlslIntrinsicNames.hpp +++ b/source/CompilerHlsl/HlslIntrinsicNames.hpp @@ -1137,6 +1137,26 @@ namespace hlsl case ast::expr::Intrinsic::eSubgroupBroadcast2D: case ast::expr::Intrinsic::eSubgroupBroadcast3D: case ast::expr::Intrinsic::eSubgroupBroadcast4D: + case ast::expr::Intrinsic::eSubgroupShuffle1F: + case 
ast::expr::Intrinsic::eSubgroupShuffle2F: + case ast::expr::Intrinsic::eSubgroupShuffle3F: + case ast::expr::Intrinsic::eSubgroupShuffle4F: + case ast::expr::Intrinsic::eSubgroupShuffle1I: + case ast::expr::Intrinsic::eSubgroupShuffle2I: + case ast::expr::Intrinsic::eSubgroupShuffle3I: + case ast::expr::Intrinsic::eSubgroupShuffle4I: + case ast::expr::Intrinsic::eSubgroupShuffle1U: + case ast::expr::Intrinsic::eSubgroupShuffle2U: + case ast::expr::Intrinsic::eSubgroupShuffle3U: + case ast::expr::Intrinsic::eSubgroupShuffle4U: + case ast::expr::Intrinsic::eSubgroupShuffle1B: + case ast::expr::Intrinsic::eSubgroupShuffle2B: + case ast::expr::Intrinsic::eSubgroupShuffle3B: + case ast::expr::Intrinsic::eSubgroupShuffle4B: + case ast::expr::Intrinsic::eSubgroupShuffle1D: + case ast::expr::Intrinsic::eSubgroupShuffle2D: + case ast::expr::Intrinsic::eSubgroupShuffle3D: + case ast::expr::Intrinsic::eSubgroupShuffle4D: case ast::expr::Intrinsic::eReadInvocation1F: case ast::expr::Intrinsic::eReadInvocation2F: case ast::expr::Intrinsic::eReadInvocation3F: @@ -1398,6 +1418,74 @@ namespace hlsl result = "QuadReadLaneAt"; break; + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal1F: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal2F: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal3F: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal4F: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal1I: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal2I: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal3I: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal4I: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal1U: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal2U: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal3U: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal4U: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal1B: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal2B: + case 
ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal3B: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal4B: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal1D: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal2D: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal3D: + case ast::expr::Intrinsic::eSubgroupQuadSwapHorizontal4D: + result = "QuadReadAcrossX"; + break; + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical1F: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical2F: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical3F: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical4F: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical1I: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical2I: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical3I: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical4I: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical1U: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical2U: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical3U: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical4U: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical1B: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical2B: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical3B: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical4B: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical1D: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical2D: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical3D: + case ast::expr::Intrinsic::eSubgroupQuadSwapVertical4D: + result = "QuadReadAcrossY"; + break; + + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal1F: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal2F: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal3F: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal4F: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal1I: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal2I: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal3I: 
+ case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal4I: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal1U: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal2U: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal3U: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal4U: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal1B: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal2B: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal3B: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal4B: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal1D: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal2D: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal3D: + case ast::expr::Intrinsic::eSubgroupQuadSwapDiagonal4D: + result = "QuadReadAcrossDiagonal"; + break; + case ast::expr::Intrinsic::eSubgroupQuadAll: result = "QuadAll"; break; diff --git a/source/CompilerSpirV/SpirVGenerateStatements.cpp b/source/CompilerSpirV/SpirVGenerateStatements.cpp index 5b747ff7..604a4a72 100644 --- a/source/CompilerSpirV/SpirVGenerateStatements.cpp +++ b/source/CompilerSpirV/SpirVGenerateStatements.cpp @@ -421,6 +421,7 @@ namespace spirv , glsl::Statement * currentDebugStatement ) { bool allLiterals{ false }; + bool preventLoading{ false }; return submit( exprCache , expr , context @@ -428,6 +429,7 @@ namespace spirv , currentBlock , shaderModule , allLiterals + , preventLoading , currentDebugStatement ); } @@ -439,6 +441,7 @@ namespace spirv , Module & shaderModule , DebugId initialiser , bool hasFuncInit + , bool preventLoading , glsl::Statement * currentDebugStatement ) { bool allLiterals{ false }; @@ -451,6 +454,7 @@ namespace spirv , allLiterals , std::move( initialiser ) , hasFuncInit + , preventLoading , currentDebugStatement ); } @@ -461,6 +465,7 @@ namespace spirv , Block & currentBlock , Module & shaderModule , bool & allLiterals + , bool preventLoading , glsl::Statement * currentDebugStatement ) { DebugId result{ 0u, expr.getType() }; @@ -471,6 
+476,7 @@ namespace spirv , currentBlock , shaderModule , allLiterals + , preventLoading , currentDebugStatement }; expr.accept( &vis ); @@ -491,6 +497,7 @@ namespace spirv , bool & allLiterals , DebugId initialiser , bool hasFuncInit + , bool preventLoading , glsl::Statement * currentDebugStatement ) { DebugId result{ 0u, expr.getType() }; @@ -503,6 +510,7 @@ namespace spirv , allLiterals , std::move( initialiser ) , hasFuncInit + , preventLoading , currentDebugStatement }; expr.accept( &vis ); return result; @@ -516,6 +524,7 @@ namespace spirv , Block & currentBlock , Module & shaderModule , bool & allLiterals + , bool preventLoading , glsl::Statement * currentDebugStatement ) : m_exprCache{ exprCache } , m_context{ context } @@ -528,6 +537,7 @@ namespace spirv , m_allLiterals{ allLiterals } , m_allocator{ shaderModule.allocator } , m_initialiser{ 0u } + , m_preventLoading{ preventLoading } { } @@ -540,6 +550,7 @@ namespace spirv , bool & allLiterals , DebugId initialiser , bool hasFuncInit + , bool preventLoading , glsl::Statement * currentDebugStatement ) : m_exprCache{ exprCache } , m_context{ context } @@ -553,6 +564,7 @@ namespace spirv , m_allocator{ shaderModule.allocator } , m_initialiser{ std::move( initialiser ) } , m_hasFuncInit{ hasFuncInit } + , m_preventLoading{ preventLoading } { } @@ -565,13 +577,13 @@ namespace spirv , DebugId initialiser , bool hasFuncInit ) { - return submit( m_exprCache, expr, m_context, m_moduleConfig, m_currentBlock, m_module, std::move( initialiser ), hasFuncInit, m_currentDebugStatement ); + return submit( m_exprCache, expr, m_context, m_moduleConfig, m_currentBlock, m_module, std::move( initialiser ), hasFuncInit, false, m_currentDebugStatement ); } DebugId doSubmit( ast::expr::Expr const & expr , bool & allLiterals ) { - return submit( m_exprCache, expr, m_context, m_moduleConfig, m_currentBlock, m_module, allLiterals, m_currentDebugStatement ); + return submit( m_exprCache, expr, m_context, m_moduleConfig, 
m_currentBlock, m_module, allLiterals, false, m_currentDebugStatement ); } glsl::RangeInfo getColumnData( ast::expr::Expr const & expr )const @@ -680,6 +692,8 @@ namespace spirv DebugId loadVariable( DebugId const & variableId , ast::expr::Expr const & expr ) { + if ( m_preventLoading ) + return variableId; return m_module.loadVariable( variableId , m_currentBlock , m_currentDebugStatement @@ -1368,10 +1382,11 @@ namespace spirv { m_allLiterals = false; - if ( expr->getSwizzle().isOneComponent() - && expr->getOuterExpr()->getKind() == ast::expr::Kind::eIdentifier - && !static_cast< ast::expr::Identifier const & >( *expr->getOuterExpr() ).getVariable()->isTempVar() - && static_cast< ast::expr::Identifier const & >( *expr->getOuterExpr() ).getVariable()->getBuiltin() != ast::Builtin::eWorkGroupSize ) + if ( m_preventLoading + || ( expr->getSwizzle().isOneComponent() + && expr->getOuterExpr()->getKind() == ast::expr::Kind::eIdentifier + && !static_cast< ast::expr::Identifier const & >( *expr->getOuterExpr() ).getVariable()->isTempVar() + && static_cast< ast::expr::Identifier const & >( *expr->getOuterExpr() ).getVariable()->getBuiltin() != ast::Builtin::eWorkGroupSize ) ) { m_result = loadVariable( makeAccessChain( m_exprCache , *expr @@ -1673,7 +1688,11 @@ namespace spirv void handleAtomicIntrinsicCallExpr( spv::Op opCode, ast::expr::IntrinsicCall const * expr ) { DebugIdList params{ m_allocator }; - params.push_back( doSubmit( *expr->getArgList()[0].get() ) ); + { + bool allLiterals{ true }; + params.push_back( submit( m_exprCache, *expr->getArgList()[0] + , m_context, m_moduleConfig, m_currentBlock, m_module, allLiterals, true, m_currentDebugStatement ) ); + } auto scopeId = registerLiteral( uint32_t( spv::ScopeDevice ) ); auto memorySemanticsId = registerLiteral( uint32_t( spv::MemorySemanticsAcquireReleaseMask ) ); @@ -2140,6 +2159,7 @@ namespace spirv ast::ShaderAllocatorBlock * m_allocator; DebugId m_initialiser; bool m_hasFuncInit{ false }; + bool 
m_preventLoading{ false }; std::array< ast::type::BaseStructPtr, 4u > m_unsignedExtendedTypes{}; std::array< ast::type::BaseStructPtr, 4u > m_signedExtendedTypes{}; uint32_t m_aliasId{ 1u }; diff --git a/source/ShaderAST/Visitors/SelectEntryPoint.cpp b/source/ShaderAST/Visitors/SelectEntryPoint.cpp index 404212d3..f016311f 100644 --- a/source/ShaderAST/Visitors/SelectEntryPoint.cpp +++ b/source/ShaderAST/Visitors/SelectEntryPoint.cpp @@ -901,11 +901,9 @@ namespace ast void visitVariableDeclStmt( stmt::VariableDecl const * stmt )override { - if ( stmt->getVariable()->isLocale() ) - { - StmtCloner::visitVariableDeclStmt( stmt ); - } - else if ( isUsed( stmt->getVariable() ) ) + if ( stmt->getVariable()->isLocale() + || stmt->getVariable()->isUniform() + || isUsed( stmt->getVariable() ) ) { StmtCloner::visitVariableDeclStmt( stmt ); } diff --git a/test/ShaderWriter/TestWriterShaderStage_Compute.cpp b/test/ShaderWriter/TestWriterShaderStage_Compute.cpp index 8839c3dc..71812417 100644 --- a/test/ShaderWriter/TestWriterShaderStage_Compute.cpp +++ b/test/ShaderWriter/TestWriterShaderStage_Compute.cpp @@ -780,6 +780,44 @@ namespace , CurrentCompilers ); sdwTestEnd() } + + TEST_F( Compute, sharedVecArrayAtomic ) + { + sdwTestBegin( "sharedVecArrayAtomic" ); + sdw::ComputeWriter writer{ &testCounts.allocator }; + { + auto gs = writer.declSharedVariable< sdw::U32Vec4 >( "gs", 10u ); + + writer.implementMain( 1u + , [&]( sdw::ComputeIn in ) + { + atomicAdd( gs[0_u].x(), 1_u ); /* atomic add on the x component of element 0 of shared variable gs */ + } ); + } + test::writeShader( writer + , testCounts + , CurrentCompilers ); + sdwTestEnd() + } + + TEST_F( Compute, sharedVecAtomic ) + { + sdwTestBegin( "sharedVecAtomic" ); + sdw::ComputeWriter writer{ &testCounts.allocator }; + { + auto gs = writer.declSharedVariable< sdw::U32Vec4 >( "gs" ); + + writer.implementMain( 1u + , [&]( sdw::ComputeIn in ) + { + atomicAdd( gs.x(), 1_u ); /* atomic add on the x component of shared variable gs */ + } ); + } + test::writeShader( writer + , testCounts + , CurrentCompilers ); + sdwTestEnd() + } } sdwTestSuiteMain()