diff --git a/.github/workflows/utests-minimal.yml b/.github/workflows/utests-minimal.yml index 6e2328de..6a416e45 100644 --- a/.github/workflows/utests-minimal.yml +++ b/.github/workflows/utests-minimal.yml @@ -19,7 +19,7 @@ jobs: submodules: true - name: Configure CMake - run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}} -DLIBSCRATCHCPP_BUILD_UNIT_TESTS=ON -DLIBSCRATCHCPP_ENABLE_SANITIZER=ON -DLIBSCRATCHCPP_AUDIO_SUPPORT=OFF -DLIBSCRATCHCPP_NETWORK_SUPPORT=OFF + run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}} -DLIBSCRATCHCPP_BUILD_UNIT_TESTS=ON -DLIBSCRATCHCPP_ENABLE_SANITIZER=ON -DLIBSCRATCHCPP_AUDIO_SUPPORT=OFF -DLIBSCRATCHCPP_NETWORK_SUPPORT=OFF -DLIBSCRATCHCPP_ENABLE_CODE_ANALYZER=OFF -DLIBSCRATCHCPP_LLVM_INTEGER_SUPPORT=OFF - name: Build run: cmake --build ${{github.workspace}}/build --config ${{env.BUILD_TYPE}} -j$(nproc --all) diff --git a/CMakeLists.txt b/CMakeLists.txt index e986c93e..c36a42dd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,6 +10,7 @@ option(LIBSCRATCHCPP_BUILD_UNIT_TESTS "Build unit tests" ON) option(LIBSCRATCHCPP_NETWORK_SUPPORT "Support for downloading projects" ON) option(LIBSCRATCHCPP_PRINT_LLVM_IR "Print LLVM IR of compiled Scratch scripts (for debugging)" OFF) option(LIBSCRATCHCPP_ENABLE_CODE_ANALYZER "Analyze Scratch scripts to enable various optimizations" ON) +option(LIBSCRATCHCPP_LLVM_INTEGER_SUPPORT "Use integers when possible to enable various optimizations" ON) option(LIBSCRATCHCPP_ENABLE_SANITIZER "Enable sanitizer to detect memory issues" OFF) if (LIBSCRATCHCPP_ENABLE_SANITIZER) @@ -126,6 +127,10 @@ if(LIBSCRATCHCPP_ENABLE_CODE_ANALYZER) target_compile_definitions(scratchcpp PRIVATE ENABLE_CODE_ANALYZER) endif() +if(LIBSCRATCHCPP_LLVM_INTEGER_SUPPORT) + target_compile_definitions(scratchcpp PRIVATE LLVM_INTEGER_SUPPORT) +endif() + # Macros target_compile_definitions(scratchcpp PRIVATE LIBSCRATCHCPP_LIBRARY) target_compile_definitions(scratchcpp PRIVATE LIBSCRATCHCPP_VERSION="${PROJECT_VERSION}") diff --git a/include/scratchcpp/compiler.h b/include/scratchcpp/compiler.h index cf534783..2b48866d 100644 --- a/include/scratchcpp/compiler.h +++ b/include/scratchcpp/compiler.h @@ -148,6 +148,7 @@ class LIBSCRATCHCPP_EXPORT Compiler void createYield(); void createStop(); + void createStopWithoutSync(); void createProcedureCall(BlockPrototype *prototype, const Compiler::Args &args); diff --git a/include/scratchcpp/value_functions.h b/include/scratchcpp/value_functions.h index 79738ee2..3849d5dc 100644 --- a/include/scratchcpp/value_functions.h +++ b/include/scratchcpp/value_functions.h @@ -85,6 +85,7 @@ extern "C" LIBSCRATCHCPP_EXPORT StringPtr *value_doubleToStringPtr(double v); LIBSCRATCHCPP_EXPORT const StringPtr *value_boolToStringPtr(bool v); LIBSCRATCHCPP_EXPORT double value_stringToDouble(const StringPtr *s); + LIBSCRATCHCPP_EXPORT double value_stringToDoubleWithCheck(const StringPtr *s, bool *ok); LIBSCRATCHCPP_EXPORT bool value_stringToBool(const StringPtr *s); LIBSCRATCHCPP_EXPORT void value_add(const ValueData *v1, const ValueData *v2, ValueData *dst); diff --git a/src/blocks/controlblocks.cpp b/src/blocks/controlblocks.cpp index c29134c4..53d7f90f 100644 --- a/src/blocks/controlblocks.cpp +++ b/src/blocks/controlblocks.cpp @@ -178,7 +178,10 @@ CompilerValue *ControlBlocks::compileCreateCloneOf(Compiler *compiler) CompilerValue *ControlBlocks::compileDeleteThisClone(Compiler *compiler) { - compiler->addTargetFunctionCall("control_delete_this_clone"); + CompilerValue *deleted = compiler->addTargetFunctionCall("control_delete_this_clone", Compiler::StaticType::Bool); + compiler->beginIfStatement(deleted); + compiler->createStopWithoutSync(); // sync happens before the function call + compiler->endIf(); return nullptr; } @@ -234,10 +237,17 @@ extern "C" void control_create_clone(ExecutionContext *ctx, const StringPtr *spr } } -extern "C" void control_delete_this_clone(Target *target) +extern "C" bool control_delete_this_clone(Target *target) { if (!target->isStage()) { - target->engine()->stopTarget(target, nullptr); - static_cast(target)->deleteClone(); + Sprite *sprite = static_cast(target); + + if (sprite->isClone()) { + target->engine()->stopTarget(target, nullptr); + sprite->deleteClone(); + return true; + } } + + return false; } diff --git a/src/engine/compiler.cpp b/src/engine/compiler.cpp index 90f95afe..33d97516 100644 --- a/src/engine/compiler.cpp +++ b/src/engine/compiler.cpp @@ -136,7 +136,7 @@ void Compiler::preoptimize() /*! * Adds a call to the given function.\n - * For example: extern "C" bool some_block(double arg1, const char *arg2) + * For example: extern "C" bool some_block(double arg1, const StringPtr *arg2) */ CompilerValue *Compiler::addFunctionCall(const std::string &functionName, StaticType returnType, const ArgTypes &argTypes, const Args &args) { @@ -146,7 +146,7 @@ CompilerValue *Compiler::addFunctionCall(const std::string &functionName, Static /*! * Adds a call to the given function with a target parameter.\n - * For example: extern "C" bool some_block(Target *target, double arg1, const char *arg2) + * For example: extern "C" bool some_block(Target *target, double arg1, const StringPtr *arg2) */ CompilerValue *Compiler::addTargetFunctionCall(const std::string &functionName, StaticType returnType, const ArgTypes &argTypes, const Args &args) { @@ -156,7 +156,7 @@ CompilerValue *Compiler::addTargetFunctionCall(const std::string &functionName, /*! * Adds a call to the given function with an execution context parameter.\n - * For example: extern "C" bool some_block(ExecutionContext *ctx, double arg1, const char *arg2) + * For example: extern "C" bool some_block(ExecutionContext *ctx, double arg1, const StringPtr *arg2) */ CompilerValue *Compiler::addFunctionCallWithCtx(const std::string &functionName, StaticType returnType, const ArgTypes &argTypes, const Args &args) { @@ -705,6 +705,16 @@ void Compiler::createStop() impl->builder->createStop(); } +/*! + * Creates a stop script without synchronization instruction.\n + * Use this if synchronization is not possible at the stop point. + * \note Only use this when everything is synchronized, e. g. after a function call. + */ +void Compiler::createStopWithoutSync() +{ + impl->builder->createStopWithoutSync(); +} + /*! Creates a call to the procedure with the given prototype. */ void Compiler::createProcedureCall(BlockPrototype *prototype, const libscratchcpp::Compiler::Args &args) { diff --git a/src/engine/internal/icodebuilder.h b/src/engine/internal/icodebuilder.h index d175c850..1c793c1d 100644 --- a/src/engine/internal/icodebuilder.h +++ b/src/engine/internal/icodebuilder.h @@ -99,6 +99,7 @@ class ICodeBuilder virtual void yield() = 0; virtual void createStop() = 0; + virtual void createStopWithoutSync() = 0; virtual void createProcedureCall(BlockPrototype *prototype, const Compiler::Args &args) = 0; }; diff --git a/src/engine/internal/llvm/CMakeLists.txt b/src/engine/internal/llvm/CMakeLists.txt index 54b79e88..818286c2 100644 --- a/src/engine/internal/llvm/CMakeLists.txt +++ b/src/engine/internal/llvm/CMakeLists.txt @@ -15,6 +15,7 @@ target_sources(scratchcpp llvmloop.h llvmcoroutine.cpp llvmcoroutine.h + llvmlocalvariableinfo.h llvmvariableptr.h llvmlistptr.h llvmtypes.cpp diff --git a/src/engine/internal/llvm/instructions/control.cpp b/src/engine/internal/llvm/instructions/control.cpp index 75c4d9fc..b41dd292 100644 --- a/src/engine/internal/llvm/instructions/control.cpp +++ b/src/engine/internal/llvm/instructions/control.cpp @@ -60,6 +60,10 @@ ProcessResult Control::process(LLVMInstruction *ins) ret.next = buildStop(ins); break; + case LLVMInstruction::Type::StopWithoutSync: + ret.next = buildStopWithoutSync(ins); + break; + default: ret.match = false; break; @@ -88,6 +92,8 @@ LLVMInstruction *Control::buildSelect(LLVMInstruction *ins) } ins->functionReturnReg->value = m_builder.CreateSelect(cond, trueValue, falseValue); + ins->functionReturnReg->isInt = m_builder.CreateSelect(cond, arg2.second->isInt, arg3.second->isInt); + ins->functionReturnReg->intValue = m_builder.CreateSelect(cond, arg2.second->intValue, arg3.second->intValue); return ins->next; } @@ -241,6 +247,8 @@ LLVMInstruction *Control::buildLoopIndex(LLVMInstruction *ins) LLVMLoop &loop = m_utils.loops().back(); llvm::Value *index = m_builder.CreateLoad(m_builder.getInt64Ty(), loop.index); ins->functionReturnReg->value = m_builder.CreateUIToFP(index, m_builder.getDoubleTy()); + ins->functionReturnReg->intValue = index; + ins->functionReturnReg->isInt = m_builder.getInt1(true); return ins->next; } @@ -336,6 +344,12 @@ LLVMInstruction *Control::buildEndLoop(LLVMInstruction *ins) } LLVMInstruction *Control::buildStop(LLVMInstruction *ins) +{ + m_utils.syncVariables(); + return buildStopWithoutSync(ins); +} + +LLVMInstruction *Control::buildStopWithoutSync(LLVMInstruction *ins) { m_utils.freeScopeHeap(); m_builder.CreateBr(m_utils.endBranch()); diff --git a/src/engine/internal/llvm/instructions/control.h b/src/engine/internal/llvm/instructions/control.h index 0c80f817..bb294f9b 100644 --- a/src/engine/internal/llvm/instructions/control.h +++ b/src/engine/internal/llvm/instructions/control.h @@ -27,6 +27,7 @@ class Control : public InstructionGroup LLVMInstruction *buildBeginLoopCondition(LLVMInstruction *ins); LLVMInstruction *buildEndLoop(LLVMInstruction *ins); LLVMInstruction *buildStop(LLVMInstruction *ins); + LLVMInstruction *buildStopWithoutSync(LLVMInstruction *ins); }; } // namespace libscratchcpp::llvmins diff --git a/src/engine/internal/llvm/instructions/functions.cpp b/src/engine/internal/llvm/instructions/functions.cpp index 94cc4bc0..a40a8c79 100644 --- a/src/engine/internal/llvm/instructions/functions.cpp +++ b/src/engine/internal/llvm/instructions/functions.cpp @@ -30,7 +30,7 @@ LLVMInstruction *Functions::buildFunctionCall(LLVMInstruction *ins) std::vector args; // Variables must be synchronized because the function can read them - m_utils.syncVariables(m_utils.targetVariables()); + m_utils.syncVariables(); // Add execution context arg if (ins->functionCtxArg) { diff --git a/src/engine/internal/llvm/instructions/instructionbuilder.cpp b/src/engine/internal/llvm/instructions/instructionbuilder.cpp index 2263f521..87977e32 100644 --- a/src/engine/internal/llvm/instructions/instructionbuilder.cpp +++ b/src/engine/internal/llvm/instructions/instructionbuilder.cpp @@ -10,11 +10,14 @@ #include "variables.h" #include "lists.h" #include "procedures.h" +#include "../llvminstruction.h" +#include "../llvmbuildutils.h" using namespace libscratchcpp; using namespace libscratchcpp::llvmins; -InstructionBuilder::InstructionBuilder(LLVMBuildUtils &utils) +InstructionBuilder::InstructionBuilder(LLVMBuildUtils &utils) : + m_utils(utils) { // Create groups m_groups.push_back(std::make_shared(utils)); @@ -34,8 +37,14 @@ LLVMInstruction *InstructionBuilder::process(LLVMInstruction *ins) for (const auto &group : m_groups) { ProcessResult result = group->process(ins); - if (result.match) + if (result.match) { +#ifndef LLVM_INTEGER_SUPPORT + if (ins->functionReturnReg) + ins->functionReturnReg->isInt = m_utils.builder().getInt1(false); +#endif + return result.next; + } } assert(false); // instruction not found diff --git a/src/engine/internal/llvm/instructions/instructionbuilder.h b/src/engine/internal/llvm/instructions/instructionbuilder.h index 4c1190ae..88356a37 100644 --- a/src/engine/internal/llvm/instructions/instructionbuilder.h +++ b/src/engine/internal/llvm/instructions/instructionbuilder.h @@ -15,6 +15,7 @@ class InstructionBuilder LLVMInstruction *process(LLVMInstruction *ins); private: + LLVMBuildUtils &m_utils; std::vector> m_groups; }; diff --git a/src/engine/internal/llvm/instructions/lists.cpp b/src/engine/internal/llvm/instructions/lists.cpp index 5c0927a5..88a75a4d 100644 --- a/src/engine/internal/llvm/instructions/lists.cpp +++ b/src/engine/internal/llvm/instructions/lists.cpp @@ -4,6 +4,7 @@ #include "../llvminstruction.h" #include "../llvmbuildutils.h" #include "../llvmconstantregister.h" +#include "../llvmcompilercontext.h" using namespace libscratchcpp; using namespace libscratchcpp::llvmins; @@ -72,6 +73,13 @@ LLVMInstruction *Lists::buildClearList(LLVMInstruction *ins) // Update size m_builder.CreateStore(m_builder.getInt64(0), listPtr.size); } + + if (listPtr.hasNumber && listPtr.hasBool && listPtr.hasString) { + // Reset type info + m_builder.CreateStore(m_builder.getInt1(false), listPtr.hasNumber); + m_builder.CreateStore(m_builder.getInt1(false), listPtr.hasBool); + m_builder.CreateStore(m_builder.getInt1(false), listPtr.hasString); + } } return ins->next; @@ -92,18 +100,15 @@ LLVMInstruction *Lists::buildRemoveListItem(LLVMInstruction *ins) LLVMListPtr &listPtr = m_utils.listPtr(ins->targetList); // Range check - llvm::Value *min = llvm::ConstantFP::get(llvmCtx, llvm::APFloat(0.0)); - llvm::Value *size = m_utils.getListSize(listPtr); - size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy()); - llvm::Value *index = m_utils.castValue(arg.second, arg.first); - llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size)); + llvm::Value *index = m_utils.castValue(arg.second, Compiler::StaticType::Number, LLVMBuildUtils::NumberType::Int); + llvm::Value *inRange = createIndexRangeCheck(listPtr, index, "removeListItem.indexInRange"); + llvm::BasicBlock *removeBlock = llvm::BasicBlock::Create(llvmCtx, "", function); llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(llvmCtx, "", function); m_builder.CreateCondBr(inRange, removeBlock, nextBlock); // Remove m_builder.SetInsertPoint(removeBlock); - index = m_builder.CreateFPToUI(m_utils.castValue(arg.second, arg.first), m_builder.getInt64Ty()); m_builder.CreateCall(m_utils.functions().resolve_list_remove(), { listPtr.ptr, index }); if (listPtr.size) { @@ -138,18 +143,21 @@ LLVMInstruction *Lists::buildAppendToList(LLVMInstruction *ins) llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(llvmCtx, "", function); m_builder.CreateCondBr(isAllocated, ifBlock, elseBlock); + // TODO: Add integer support for lists + llvm::Value *isIntVar = m_utils.addAlloca(m_builder.getInt1Ty()); + llvm::Value *intVar = m_utils.addAlloca(m_builder.getInt64Ty()); + // If there's enough space, use the allocated memory m_builder.SetInsertPoint(ifBlock); llvm::Value *itemPtr = m_utils.getListItem(listPtr, size); - m_utils.createValueStore(arg.second, itemPtr, type); + m_utils.createValueStore(itemPtr, m_utils.getValueTypePtr(itemPtr), isIntVar, intVar, arg.second, type); m_builder.CreateStore(m_builder.CreateAdd(size, m_builder.getInt64(1)), listPtr.sizePtr); // update size stored in *sizePtr m_builder.CreateBr(nextBlock); // Otherwise call appendEmpty() m_builder.SetInsertPoint(elseBlock); itemPtr = m_builder.CreateCall(m_utils.functions().resolve_list_append_empty(), listPtr.ptr); - // NOTE: Items created using appendEmpty() are always numbers - m_utils.createValueStore(arg.second, itemPtr, Compiler::StaticType::Number, type); + m_utils.createValueStore(itemPtr, m_utils.getValueTypePtr(itemPtr), isIntVar, intVar, arg.second, type); m_builder.CreateBr(nextBlock); m_builder.SetInsertPoint(nextBlock); @@ -161,6 +169,7 @@ LLVMInstruction *Lists::buildAppendToList(LLVMInstruction *ins) m_builder.CreateStore(size, listPtr.size); } + createListTypeUpdate(listPtr, arg.second, type); return ins->next; } @@ -176,20 +185,21 @@ LLVMInstruction *Lists::buildInsertToList(LLVMInstruction *ins) LLVMListPtr &listPtr = m_utils.listPtr(ins->targetList); // Range check - llvm::Value *size = m_utils.getListSize(listPtr); - llvm::Value *min = llvm::ConstantFP::get(llvmCtx, llvm::APFloat(0.0)); - size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy()); - llvm::Value *index = m_utils.castValue(indexArg.second, indexArg.first); - llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLE(index, size)); + llvm::Value *index = m_utils.castValue(indexArg.second, Compiler::StaticType::Number, LLVMBuildUtils::NumberType::Int); + llvm::Value *inRange = createIndexRangeCheck(listPtr, index, "insertToList.indexInRange", true); + llvm::BasicBlock *insertBlock = llvm::BasicBlock::Create(llvmCtx, "", function); llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(llvmCtx, "", function); m_builder.CreateCondBr(inRange, insertBlock, nextBlock); + // TODO: Add integer support for lists + llvm::Value *isIntVar = m_utils.addAlloca(m_builder.getInt1Ty()); + llvm::Value *intVar = m_utils.addAlloca(m_builder.getInt64Ty()); + // Insert m_builder.SetInsertPoint(insertBlock); - index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty()); llvm::Value *itemPtr = m_builder.CreateCall(m_utils.functions().resolve_list_insert_empty(), { listPtr.ptr, index }); - m_utils.createValueStore(valueArg.second, itemPtr, type); + m_utils.createValueStore(itemPtr, m_utils.getValueTypePtr(itemPtr), isIntVar, intVar, valueArg.second, type); if (listPtr.size) { // Update size @@ -198,6 +208,7 @@ LLVMInstruction *Lists::buildInsertToList(LLVMInstruction *ins) m_builder.CreateStore(size, listPtr.size); } + createListTypeUpdate(listPtr, valueArg.second, type); m_builder.CreateBr(nextBlock); m_builder.SetInsertPoint(nextBlock); @@ -222,20 +233,33 @@ LLVMInstruction *Lists::buildListReplace(LLVMInstruction *ins) Compiler::StaticType listType = ins->targetType; // Range check - llvm::Value *min = llvm::ConstantFP::get(llvmCtx, llvm::APFloat(0.0)); - llvm::Value *size = m_utils.getListSize(listPtr); - size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy()); - llvm::Value *index = m_utils.castValue(indexArg.second, indexArg.first); - llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size)); + llvm::Value *index = m_utils.castValue(indexArg.second, Compiler::StaticType::Number, LLVMBuildUtils::NumberType::Int); + llvm::Value *inRange = createIndexRangeCheck(listPtr, index, "listReplace.indexInRange"); + llvm::BasicBlock *replaceBlock = llvm::BasicBlock::Create(llvmCtx, "", function); llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(llvmCtx, "", function); m_builder.CreateCondBr(inRange, replaceBlock, nextBlock); // Replace m_builder.SetInsertPoint(replaceBlock); - index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty()); + llvm::Value *itemPtr = m_utils.getListItem(listPtr, index); - m_utils.createValueStore(valueArg.second, itemPtr, listType, type); + llvm::Value *typePtr = m_utils.getValueTypePtr(itemPtr); + llvm::Value *loadedType = m_builder.CreateLoad(m_builder.getInt32Ty(), typePtr); + llvm::Value *typeVar = createListTypeVar(listPtr, loadedType); + + // TODO: Add integer support for lists + llvm::Value *isIntVar = m_utils.addAlloca(m_builder.getInt1Ty()); + llvm::Value *intVar = m_utils.addAlloca(m_builder.getInt64Ty()); + + createListTypeAssumption(listPtr, typeVar, ins->targetType); + m_utils.createValueStore(itemPtr, typeVar, isIntVar, intVar, valueArg.second, listType, type); + + // Value store may change type, make sure to update it + loadedType = m_builder.CreateLoad(m_builder.getInt32Ty(), typeVar); + m_builder.CreateStore(loadedType, typePtr); + + createListTypeUpdate(listPtr, valueArg.second, type); m_builder.CreateBr(nextBlock); m_builder.SetInsertPoint(nextBlock); @@ -255,6 +279,9 @@ LLVMInstruction *Lists::buildGetListContents(LLVMInstruction *ins) LLVMInstruction *Lists::buildGetListItem(LLVMInstruction *ins) { + llvm::LLVMContext &llvmCtx = m_utils.llvmCtx(); + llvm::Function *function = m_utils.function(); + // Return empty string for empty lists if (ins->targetType == Compiler::StaticType::Void) { LLVMConstantRegister nullReg(Compiler::StaticType::String, ""); @@ -266,17 +293,43 @@ LLVMInstruction *Lists::buildGetListItem(LLVMInstruction *ins) const auto &arg = ins->args[0]; LLVMListPtr &listPtr = m_utils.listPtr(ins->targetList); - llvm::Value *min = llvm::ConstantFP::get(m_utils.llvmCtx(), llvm::APFloat(0.0)); - llvm::Value *size = m_utils.getListSize(listPtr); - size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy()); - llvm::Value *index = m_utils.castValue(arg.second, arg.first); - llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size)); + // Range check + llvm::Value *index = m_utils.castValue(arg.second, Compiler::StaticType::Number, LLVMBuildUtils::NumberType::Int); + llvm::Value *inRange = createIndexRangeCheck(listPtr, index, "getListItem.indexInRange"); - LLVMConstantRegister nullReg(Compiler::StaticType::String, ""); - llvm::Value *null = m_utils.createValue(static_cast(&nullReg)); + llvm::BasicBlock *inRangeBlock = llvm::BasicBlock::Create(llvmCtx, "getListItem.inRange", function); + llvm::BasicBlock *outOfRangeBlock = llvm::BasicBlock::Create(llvmCtx, "getListItem.outOfRange", function); + llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(llvmCtx, "getListItem.next", function); + m_builder.CreateCondBr(inRange, inRangeBlock, outOfRangeBlock); - index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty()); - ins->functionReturnReg->value = m_builder.CreateSelect(inRange, m_utils.getListItem(listPtr, index), null); + // In range + m_builder.SetInsertPoint(inRangeBlock); + llvm::Value *itemPtr = m_utils.getListItem(listPtr, index); + llvm::Value *itemType = m_builder.CreateLoad(m_builder.getInt32Ty(), m_utils.getValueTypePtr(itemPtr)); + m_builder.CreateBr(nextBlock); + + // Out of range + m_builder.SetInsertPoint(outOfRangeBlock); + LLVMConstantRegister emptyStringReg(Compiler::StaticType::String, ""); + llvm::Value *emptyString = m_utils.createValue(static_cast(&emptyStringReg)); + llvm::Value *stringType = m_builder.getInt32(static_cast(ValueType::String)); + m_builder.CreateBr(nextBlock); + + m_builder.SetInsertPoint(nextBlock); + + // Result + llvm::PHINode *result = m_builder.CreatePHI(itemPtr->getType(), 2, "getListItem.result"); + result->addIncoming(itemPtr, inRangeBlock); + result->addIncoming(emptyString, outOfRangeBlock); + + llvm::PHINode *itemTypeResult = m_builder.CreatePHI(m_builder.getInt32Ty(), 2, "getListItem.itemType"); + itemTypeResult->addIncoming(itemType, inRangeBlock); + itemTypeResult->addIncoming(stringType, outOfRangeBlock); + + llvm::Value *typeVar = createListTypeVar(listPtr, itemTypeResult); + ins->functionReturnReg->value = result; + ins->functionReturnReg->typeVar = typeVar; + createListTypeAssumption(listPtr, typeVar, ins->targetType, inRange); return ins->next; } @@ -287,6 +340,8 @@ LLVMInstruction *Lists::buildGetListSize(LLVMInstruction *ins) const LLVMListPtr &listPtr = m_utils.listPtr(ins->targetList); llvm::Value *size = m_utils.getListSize(listPtr); ins->functionReturnReg->value = m_builder.CreateUIToFP(size, m_builder.getDoubleTy()); + ins->functionReturnReg->isInt = m_builder.getInt1(true); + ins->functionReturnReg->intValue = size; return ins->next; } @@ -316,3 +371,172 @@ LLVMInstruction *Lists::buildListContainsItem(LLVMInstruction *ins) return ins->next; } + +llvm::Value *Lists::createIndexRangeCheck(const LLVMListPtr &listPtr, llvm::Value *index, const std::string &name, bool includeSize) +{ + llvm::Function *expectIntrinsic = llvm::Intrinsic::getDeclaration(m_utils.module(), llvm::Intrinsic::expect, m_builder.getInt1Ty()); + + llvm::Value *min = llvm::ConstantInt::get(m_builder.getInt64Ty(), 0, true); + llvm::Value *size = m_utils.getListSize(listPtr); + llvm::Value *sizeCheck = includeSize ? m_builder.CreateICmpSLE(index, size) : m_builder.CreateICmpSLT(index, size); + llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateICmpSGE(index, min), sizeCheck, name); + + // Tell the optimizer that indices in range are more common + return m_builder.CreateCall(expectIntrinsic, { inRange, m_builder.getInt1(true) }); +} + +void Lists::createListTypeUpdate(const LLVMListPtr &listPtr, const LLVMRegister *newValue, Compiler::StaticType newValueType) +{ + if (listPtr.hasNumber && listPtr.hasBool && listPtr.hasString) { + // Get the new type + llvm::Value *newType; + + if (newValue->isRawValue) + newType = m_builder.getInt32(static_cast(m_utils.mapType(newValueType))); + else { + llvm::Value *typeField = m_builder.CreateStructGEP(m_utils.compilerCtx()->valueDataType(), newValue->value, 1); + newType = m_builder.CreateLoad(m_builder.getInt32Ty(), typeField); + } + + bool staticHasNumber = (newValueType & Compiler::StaticType::Number) == Compiler::StaticType::Number; + bool staticHasBool = (newValueType & Compiler::StaticType::Bool) == Compiler::StaticType::Bool; + bool staticHasString = (newValueType & Compiler::StaticType::String) == Compiler::StaticType::String; + + llvm::Value *isNumber; + + if (staticHasNumber) + isNumber = m_builder.CreateICmpEQ(newType, m_builder.getInt32(static_cast(ValueType::Number))); + else + isNumber = m_builder.getInt1(false); + + llvm::Value *isBool; + + if (staticHasBool) + isBool = m_builder.CreateICmpEQ(newType, m_builder.getInt32(static_cast(ValueType::Bool))); + else + isBool = m_builder.getInt1(false); + + llvm::Value *isString; + + if (staticHasString) + isString = m_builder.CreateICmpEQ(newType, m_builder.getInt32(static_cast(ValueType::String))); + else + isString = m_builder.getInt1(false); + + // Update flags + llvm::Value *previous = m_builder.CreateLoad(m_builder.getInt1Ty(), listPtr.hasNumber); + m_builder.CreateStore(m_builder.CreateOr(previous, isNumber), listPtr.hasNumber); + + previous = m_builder.CreateLoad(m_builder.getInt1Ty(), listPtr.hasBool); + m_builder.CreateStore(m_builder.CreateOr(previous, isBool), listPtr.hasBool); + + previous = m_builder.CreateLoad(m_builder.getInt1Ty(), listPtr.hasString); + m_builder.CreateStore(m_builder.CreateOr(previous, isString), listPtr.hasString); + } +} + +llvm::Value *Lists::createListTypeVar(const LLVMListPtr &listPtr, llvm::Value *type) +{ + llvm::Value *typeVar = m_utils.addAlloca(m_builder.getInt32Ty()); + m_builder.CreateStore(type, typeVar); + return typeVar; +} + +void Lists::createListTypeAssumption(const LLVMListPtr &listPtr, llvm::Value *typeVar, Compiler::StaticType staticType, llvm::Value *inRange) +{ + if (listPtr.hasNumber && listPtr.hasBool && listPtr.hasString) { + llvm::Function *assumeIntrinsic = llvm::Intrinsic::getDeclaration(m_utils.module(), llvm::Intrinsic::assume); + + // Load the compile time list type information + bool staticHasNumber = (staticType & Compiler::StaticType::Number) == Compiler::StaticType::Number; + bool staticHasBool = (staticType & Compiler::StaticType::Bool) == Compiler::StaticType::Bool; + bool staticHasString = (staticType & Compiler::StaticType::String) == Compiler::StaticType::String; + + // Load the runtime list type information + llvm::Value *hasNumber; + + if (staticHasNumber) + hasNumber = m_builder.CreateLoad(m_builder.getInt1Ty(), listPtr.hasNumber); + else + hasNumber = m_builder.getInt1(false); + + llvm::Value *hasBool; + + if (staticHasBool) + hasBool = m_builder.CreateLoad(m_builder.getInt1Ty(), listPtr.hasBool); + else + hasBool = m_builder.getInt1(false); + + llvm::Value *hasString; + + if (staticHasString) + hasString = m_builder.CreateLoad(m_builder.getInt1Ty(), listPtr.hasString); + else + hasString = m_builder.getInt1(false); + + llvm::Value *type = m_builder.CreateLoad(m_builder.getInt32Ty(), typeVar); + + if (!inRange) + inRange = m_builder.getInt1(true); + + llvm::Value *numberType = m_builder.getInt32(static_cast(ValueType::Number)); + llvm::Value *boolType = m_builder.getInt32(static_cast(ValueType::Bool)); + llvm::Value *stringType = m_builder.getInt32(static_cast(ValueType::String)); + + // Create assumptions + llvm::BasicBlock *outOfRangeBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "outOfRange", m_utils.function()); + llvm::BasicBlock *inRangeBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "inRange", m_utils.function()); + llvm::BasicBlock *afterBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "afterAssume", m_utils.function()); + + m_builder.CreateCondBr(inRange, inRangeBlock, outOfRangeBlock); + + // In-range assumptions + m_builder.SetInsertPoint(inRangeBlock); + + auto assume = [&](llvm::Value *cond) { m_builder.CreateCall(assumeIntrinsic, cond); }; + + llvm::Value *notNumber = m_builder.CreateNot(hasNumber); + llvm::Value *notBool = m_builder.CreateNot(hasBool); + llvm::Value *notString = m_builder.CreateNot(hasString); + + // if (!hasBool && !hasString) type == Number + llvm::Value *cond1 = m_builder.CreateAnd(notBool, notString); + llvm::Value *assume1 = m_builder.CreateICmpEQ(type, numberType); + assume(m_builder.CreateSelect(cond1, assume1, m_builder.getInt1(true))); + + // else if (!hasNumber && !hasString) type == Bool + llvm::Value *cond2 = m_builder.CreateAnd(notNumber, notString); + llvm::Value *assume2 = m_builder.CreateICmpEQ(type, boolType); + assume(m_builder.CreateSelect(cond2, assume2, m_builder.getInt1(true))); + + // else if (!hasNumber && !hasBool) type == String + llvm::Value *cond3 = m_builder.CreateAnd(notNumber, notBool); + llvm::Value *assume3 = m_builder.CreateICmpEQ(type, stringType); + assume(m_builder.CreateSelect(cond3, assume3, m_builder.getInt1(true))); + + // else if (!hasBool) type == Number || type == String + llvm::Value *cond4 = notBool; + llvm::Value *assume4 = m_builder.CreateOr(m_builder.CreateICmpEQ(type, numberType), m_builder.CreateICmpEQ(type, stringType)); + assume(m_builder.CreateSelect(cond4, assume4, m_builder.getInt1(true))); + + // else if (!hasNumber) type == Bool || type == String + llvm::Value *cond5 = notNumber; + llvm::Value *assume5 = m_builder.CreateOr(m_builder.CreateICmpEQ(type, boolType), m_builder.CreateICmpEQ(type, stringType)); + assume(m_builder.CreateSelect(cond5, assume5, m_builder.getInt1(true))); + + // else if (!hasString) type == Number || type == Bool + llvm::Value *cond6 = notString; + llvm::Value *assume6 = m_builder.CreateOr(m_builder.CreateICmpEQ(type, numberType), m_builder.CreateICmpEQ(type, boolType)); + assume(m_builder.CreateSelect(cond6, assume6, m_builder.getInt1(true))); + + m_builder.CreateBr(afterBlock); + + // Out-of-range: always string + m_builder.SetInsertPoint(outOfRangeBlock); + llvm::Value *isString = m_builder.CreateICmpEQ(type, stringType); + assume(isString); + m_builder.CreateBr(afterBlock); + + m_builder.SetInsertPoint(afterBlock); + } +} diff --git a/src/engine/internal/llvm/instructions/lists.h b/src/engine/internal/llvm/instructions/lists.h index 59aed889..06ceedee 100644 --- a/src/engine/internal/llvm/instructions/lists.h +++ b/src/engine/internal/llvm/instructions/lists.h @@ -2,9 +2,17 @@ #pragma once +#include + #include "instructiongroup.h" -namespace libscratchcpp::llvmins +namespace libscratchcpp +{ + +class LLVMListPtr; +class LLVMRegister; + +namespace llvmins { class Lists : public InstructionGroup @@ -25,6 +33,14 @@ class Lists : public InstructionGroup LLVMInstruction *buildGetListSize(LLVMInstruction *ins); LLVMInstruction *buildGetListItemIndex(LLVMInstruction *ins); LLVMInstruction *buildListContainsItem(LLVMInstruction *ins); + + llvm::Value *createIndexRangeCheck(const LLVMListPtr &listPtr, llvm::Value *index, const std::string &name, bool includeSize = false); + + void createListTypeUpdate(const LLVMListPtr &listPtr, const LLVMRegister *newValue, Compiler::StaticType newValueType); + llvm::Value *createListTypeVar(const LLVMListPtr &listPtr, llvm::Value *type); + void createListTypeAssumption(const LLVMListPtr &listPtr, llvm::Value *typeVar, Compiler::StaticType staticType, llvm::Value *inRange = nullptr); }; -} // namespace libscratchcpp::llvmins +} // namespace llvmins + +} // namespace libscratchcpp diff --git a/src/engine/internal/llvm/instructions/math.cpp b/src/engine/internal/llvm/instructions/math.cpp index 1a6c452c..aeb95cc3 100644 --- a/src/engine/internal/llvm/instructions/math.cpp +++ b/src/engine/internal/llvm/instructions/math.cpp @@ -113,9 +113,15 @@ LLVMInstruction *Math::buildAdd(LLVMInstruction *ins) assert(ins->args.size() == 2); const auto &arg1 = ins->args[0]; const auto &arg2 = ins->args[1]; - llvm::Value *num1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first)); - llvm::Value *num2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first)); - ins->functionReturnReg->value = m_builder.CreateFAdd(num1, num2); + + llvm::Value *double1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first, LLVMBuildUtils::NumberType::Double)); + llvm::Value *double2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first, LLVMBuildUtils::NumberType::Double)); + ins->functionReturnReg->value = m_builder.CreateFAdd(double1, double2); + + llvm::Value *int1 = m_utils.castValue(arg1.second, arg1.first, LLVMBuildUtils::NumberType::Int); + llvm::Value *int2 = m_utils.castValue(arg2.second, arg2.first, LLVMBuildUtils::NumberType::Int); + ins->functionReturnReg->isInt = m_builder.CreateAnd(arg1.second->isInt, arg2.second->isInt); + ins->functionReturnReg->intValue = m_builder.CreateAdd(int1, int2); return ins->next; } @@ -125,9 +131,15 @@ LLVMInstruction *Math::buildSub(LLVMInstruction *ins) assert(ins->args.size() == 2); const auto &arg1 = ins->args[0]; const auto &arg2 = ins->args[1]; - llvm::Value *num1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first)); - llvm::Value *num2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first)); - ins->functionReturnReg->value = m_builder.CreateFSub(num1, num2); + + llvm::Value *double1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first, LLVMBuildUtils::NumberType::Double)); + llvm::Value *double2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first, LLVMBuildUtils::NumberType::Double)); + ins->functionReturnReg->value = m_builder.CreateFSub(double1, double2); + + llvm::Value *int1 = m_utils.castValue(arg1.second, arg1.first, LLVMBuildUtils::NumberType::Int); + llvm::Value *int2 = m_utils.castValue(arg2.second, arg2.first, LLVMBuildUtils::NumberType::Int); + ins->functionReturnReg->isInt = m_builder.CreateAnd(arg1.second->isInt, arg2.second->isInt); + ins->functionReturnReg->intValue = m_builder.CreateSub(int1, int2); return ins->next; } @@ -137,9 +149,15 @@ LLVMInstruction *Math::buildMul(LLVMInstruction *ins) assert(ins->args.size() == 2); const auto &arg1 = ins->args[0]; const auto &arg2 = ins->args[1]; - llvm::Value *num1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first)); - llvm::Value *num2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first)); - ins->functionReturnReg->value = m_builder.CreateFMul(num1, num2); + + llvm::Value *double1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first, LLVMBuildUtils::NumberType::Double)); + llvm::Value *double2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first, LLVMBuildUtils::NumberType::Double)); + ins->functionReturnReg->value = m_builder.CreateFMul(double1, double2); + + llvm::Value *int1 = m_utils.castValue(arg1.second, arg1.first, LLVMBuildUtils::NumberType::Int); + llvm::Value *int2 = m_utils.castValue(arg2.second, arg2.first, LLVMBuildUtils::NumberType::Int); + ins->functionReturnReg->isInt = m_builder.CreateAnd(arg1.second->isInt, arg2.second->isInt); + ins->functionReturnReg->intValue = m_builder.CreateMul(int1, int2); return ins->next; } @@ -197,7 +215,10 @@ LLVMInstruction *Math::buildRandomInt(LLVMInstruction *ins) const auto &arg2 = ins->args[1]; llvm::Value *from = m_builder.CreateFPToSI(m_utils.castValue(arg1.second, arg1.first), m_builder.getInt64Ty()); llvm::Value *to = m_builder.CreateFPToSI(m_utils.castValue(arg2.second, arg2.first), m_builder.getInt64Ty()); - ins->functionReturnReg->value = m_builder.CreateCall(m_utils.functions().resolve_llvm_random_long(), { m_utils.executionContextPtr(), from, to }); + llvm::Value *intValue = m_builder.CreateCall(m_utils.functions().resolve_llvm_random_int64(), { m_utils.executionContextPtr(), from, to }); + ins->functionReturnReg->value = m_builder.CreateSIToFP(intValue, m_builder.getDoubleTy()); + ins->functionReturnReg->intValue = intValue; + ins->functionReturnReg->isInt = m_builder.getInt1(true); return ins->next; } @@ -207,13 +228,49 @@ LLVMInstruction *Math::buildMod(LLVMInstruction *ins) assert(ins->args.size() == 2); const auto &arg1 = ins->args[0]; const auto &arg2 = ins->args[1]; - // rem(a, b) / b < 0.0 ? rem(a, b) + b : rem(a, b) - llvm::Constant *zero = llvm::ConstantFP::get(m_utils.llvmCtx(), llvm::APFloat(0.0)); - llvm::Value *num1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first)); - llvm::Value *num2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first)); - llvm::Value *value = m_builder.CreateFRem(num1, num2); // rem(a, b) - llvm::Value *cond = m_builder.CreateFCmpOLT(m_builder.CreateFDiv(value, num2), zero); // rem(a, b) / b < 0.0 // rem(a, b) - ins->functionReturnReg->value = m_builder.CreateSelect(cond, m_builder.CreateFAdd(value, num2), value); + + // double: rem(a, b) / b < 0.0 ? rem(a, b) + b : rem(a, b) + llvm::Constant *doubleZero = llvm::ConstantFP::get(m_utils.llvmCtx(), llvm::APFloat(0.0)); + llvm::Value *double1 = m_utils.removeNaN(m_utils.castValue(arg1.second, arg1.first)); + llvm::Value *double2 = m_utils.removeNaN(m_utils.castValue(arg2.second, arg2.first)); + llvm::Value *doubleRem = m_builder.CreateFRem(double1, double2); // rem(a, b) + llvm::Value *doubleCond = m_builder.CreateFCmpOLT(m_builder.CreateFDiv(doubleRem, double2), doubleZero); // rem(a, b) / b < 0.0 + ins->functionReturnReg->value = m_builder.CreateSelect(doubleCond, m_builder.CreateFAdd(doubleRem, double2), doubleRem); + + // int: b == 0 ? 0 (double fallback) : ((rem(a, b) < 0) != (b < 0) ? rem(a, b) + b : rem(a, b)) + llvm::Constant *intZero = llvm::ConstantInt::get(m_builder.getInt64Ty(), 0, true); + llvm::Value *int1 = m_utils.castValue(arg1.second, arg1.first, LLVMBuildUtils::NumberType::Int); + llvm::Value *int2 = m_utils.castValue(arg2.second, arg2.first, LLVMBuildUtils::NumberType::Int); + llvm::Value *nanResult = m_builder.CreateICmpEQ(int2, intZero); + + llvm::BasicBlock *nanBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "", m_utils.function()); + llvm::BasicBlock *intBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "", m_utils.function()); + llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "", m_utils.function()); + m_builder.CreateCondBr(nanResult, nanBlock, intBlock); + + m_builder.SetInsertPoint(nanBlock); + llvm::Value *noInt = m_builder.getInt1(false); + m_builder.CreateBr(nextBlock); + + m_builder.SetInsertPoint(intBlock); + llvm::Value *isInt = m_builder.CreateAnd(arg1.second->isInt, arg2.second->isInt); + llvm::Value *intRem = m_builder.CreateSRem(int1, int2); // rem(a, b) + llvm::Value *intCond = m_builder.CreateICmpSLT(m_builder.CreateSDiv(intRem, int2), intZero); // rem(a, b) / b < 0 + llvm::Value *intResult = m_builder.CreateSelect(intCond, m_builder.CreateAdd(intRem, int2), intRem); + m_builder.CreateBr(nextBlock); + + m_builder.SetInsertPoint(nextBlock); + + llvm::PHINode *resultPhi = m_builder.CreatePHI(m_builder.getInt64Ty(), 2); + resultPhi->addIncoming(intZero, nanBlock); + resultPhi->addIncoming(intResult, intBlock); + + llvm::PHINode *isIntPhi = m_builder.CreatePHI(m_builder.getInt1Ty(), 2); + isIntPhi->addIncoming(noInt, nanBlock); + isIntPhi->addIncoming(isInt, intBlock); + + ins->functionReturnReg->intValue = resultPhi; + ins->functionReturnReg->isInt = isIntPhi; return ins->next; } @@ -225,18 +282,29 @@ LLVMInstruction *Math::buildRound(LLVMInstruction *ins) assert(ins->args.size() == 1); const auto &arg = ins->args[0]; - // x >= 0.0 ? round(x) : (x >= -0.5 ? -0.0 : floor(x + 0.5)) + + // double: x >= 0.0 ? round(x) : (x >= -0.5 ? -0.0 : floor(x + 0.5)) llvm::Constant *zero = llvm::ConstantFP::get(llvmCtx, llvm::APFloat(0.0)); llvm::Constant *negativeZero = llvm::ConstantFP::get(llvmCtx, llvm::APFloat(-0.0)); llvm::Function *roundFunc = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::round, m_builder.getDoubleTy()); llvm::Function *floorFunc = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::floor, m_builder.getDoubleTy()); - llvm::Value *num = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first)); - llvm::Value *notNegative = m_builder.CreateFCmpOGE(num, zero); // num >= 0.0 - llvm::Value *roundNum = m_builder.CreateCall(roundFunc, num); // round(num) - llvm::Value *negativeCond = m_builder.CreateFCmpOGE(num, llvm::ConstantFP::get(llvmCtx, llvm::APFloat(-0.5))); // num >= -0.5 - llvm::Value *negativeRound = m_builder.CreateCall(floorFunc, m_builder.CreateFAdd(num, llvm::ConstantFP::get(llvmCtx, llvm::APFloat(0.5)))); // floor(x + 0.5) + llvm::Value *doubleValue = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first)); + llvm::Value *notNegative = m_builder.CreateFCmpOGE(doubleValue, zero); // num >= 0.0 + llvm::Value *roundNum = m_builder.CreateCall(roundFunc, doubleValue); // round(num) + llvm::Value *negativeCond = m_builder.CreateFCmpOGE(doubleValue, llvm::ConstantFP::get(llvmCtx, llvm::APFloat(-0.5))); // num >= -0.5 + llvm::Value *negativeRound = m_builder.CreateCall(floorFunc, m_builder.CreateFAdd(doubleValue, llvm::ConstantFP::get(llvmCtx, llvm::APFloat(0.5)))); // floor(x + 0.5) ins->functionReturnReg->value = m_builder.CreateSelect(notNegative, roundNum, m_builder.CreateSelect(negativeCond, negativeZero, negativeRound)); + // int: doubleX == inf || doubleX == -inf ? doubleX : intX + llvm::Constant *posInf = llvm::ConstantFP::getInfinity(m_builder.getDoubleTy(), false); + llvm::Constant *negInf = llvm::ConstantFP::getInfinity(m_builder.getDoubleTy(), true); + llvm::Value *isInt = arg.second->isInt; + llvm::Value *intValue = arg.second->intValue; + llvm::Value *isNotInf = m_builder.CreateAnd(m_builder.CreateFCmpONE(doubleValue, posInf), m_builder.CreateFCmpONE(doubleValue, negInf)); + llvm::Value *cast = m_builder.CreateFPToSI(ins->functionReturnReg->value, m_builder.getInt64Ty()); + ins->functionReturnReg->isInt = isNotInf; + ins->functionReturnReg->intValue = m_builder.CreateSelect(isInt, intValue, cast); + return ins->next; } @@ -244,9 +312,15 @@ LLVMInstruction *Math::buildAbs(LLVMInstruction *ins) { assert(ins->args.size() == 1); const auto &arg = ins->args[0]; - llvm::Function *absFunc = llvm::Intrinsic::getDeclaration(m_utils.module(), llvm::Intrinsic::fabs, m_builder.getDoubleTy()); - llvm::Value *num = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first)); - ins->functionReturnReg->value = m_builder.CreateCall(absFunc, num); + + llvm::Function *fabsFunc = llvm::Intrinsic::getDeclaration(m_utils.module(), llvm::Intrinsic::fabs, m_builder.getDoubleTy()); + llvm::Value *doubleValue = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first)); + ins->functionReturnReg->value = m_builder.CreateCall(fabsFunc, doubleValue); + + llvm::Function *absFunc = llvm::Intrinsic::getDeclaration(m_utils.module(), llvm::Intrinsic::abs, m_builder.getInt64Ty()); + llvm::Value *intValue = arg.second->intValue; + ins->functionReturnReg->isInt = arg.second->isInt; + ins->functionReturnReg->intValue = m_builder.CreateCall(absFunc, { intValue, m_builder.getInt1(false) }); return ins->next; } @@ -255,9 +329,21 @@ LLVMInstruction *Math::buildFloor(LLVMInstruction *ins) { assert(ins->args.size() == 1); const auto &arg = ins->args[0]; + + // double: floor(doubleX) llvm::Function *floorFunc = llvm::Intrinsic::getDeclaration(m_utils.module(), llvm::Intrinsic::floor, m_builder.getDoubleTy()); - llvm::Value *num = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first)); - ins->functionReturnReg->value = m_builder.CreateCall(floorFunc, num); + llvm::Value *doubleValue = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first)); + ins->functionReturnReg->value = m_builder.CreateCall(floorFunc, doubleValue); + + // int: doubleX == inf || doubleX == -inf ? doubleX : intX + llvm::Constant *posInf = llvm::ConstantFP::getInfinity(m_builder.getDoubleTy(), false); + llvm::Constant *negInf = llvm::ConstantFP::getInfinity(m_builder.getDoubleTy(), true); + llvm::Value *isInt = arg.second->isInt; + llvm::Value *intValue = arg.second->intValue; + llvm::Value *isNotInf = m_builder.CreateAnd(m_builder.CreateFCmpONE(doubleValue, posInf), m_builder.CreateFCmpONE(doubleValue, negInf)); + llvm::Value *cast = m_builder.CreateFPToSI(ins->functionReturnReg->value, m_builder.getInt64Ty()); + ins->functionReturnReg->isInt = isNotInf; + ins->functionReturnReg->intValue = m_builder.CreateSelect(isInt, intValue, cast); return ins->next; } @@ -266,9 +352,21 @@ LLVMInstruction *Math::buildCeil(LLVMInstruction *ins) { assert(ins->args.size() == 1); const auto &arg = ins->args[0]; + + // double: ceil(doubleX) llvm::Function *ceilFunc = llvm::Intrinsic::getDeclaration(m_utils.module(), llvm::Intrinsic::ceil, m_builder.getDoubleTy()); - llvm::Value *num = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first)); - ins->functionReturnReg->value = m_builder.CreateCall(ceilFunc, num); + llvm::Value *doubleValue = m_utils.removeNaN(m_utils.castValue(arg.second, arg.first)); + ins->functionReturnReg->value = m_builder.CreateCall(ceilFunc, doubleValue); + + // int: doubleX == inf || doubleX == -inf ? doubleX : intX + llvm::Constant *posInf = llvm::ConstantFP::getInfinity(m_builder.getDoubleTy(), false); + llvm::Constant *negInf = llvm::ConstantFP::getInfinity(m_builder.getDoubleTy(), true); + llvm::Value *isInt = arg.second->isInt; + llvm::Value *intValue = arg.second->intValue; + llvm::Value *isNotInf = m_builder.CreateAnd(m_builder.CreateFCmpONE(doubleValue, posInf), m_builder.CreateFCmpONE(doubleValue, negInf)); + llvm::Value *cast = m_builder.CreateFPToSI(ins->functionReturnReg->value, m_builder.getInt64Ty()); + ins->functionReturnReg->isInt = isNotInf; + ins->functionReturnReg->intValue = m_builder.CreateSelect(isInt, intValue, cast); return ins->next; } diff --git a/src/engine/internal/llvm/instructions/procedures.cpp b/src/engine/internal/llvm/instructions/procedures.cpp index 5b6f9f75..f6022458 100644 --- a/src/engine/internal/llvm/instructions/procedures.cpp +++ b/src/engine/internal/llvm/instructions/procedures.cpp @@ -39,7 +39,7 @@ LLVMInstruction *Procedures::buildCallProcedure(LLVMInstruction *ins) assert(ins->procedurePrototype); assert(ins->args.size() == ins->procedurePrototype->argumentTypes().size()); m_utils.freeScopeHeap(); - m_utils.syncVariables(m_utils.targetVariables()); + m_utils.syncVariables(); std::string name = m_utils.scriptFunctionName(ins->procedurePrototype); llvm::FunctionType *type = m_utils.scriptFunctionType(ins->procedurePrototype); @@ -80,7 +80,7 @@ LLVMInstruction *Procedures::buildCallProcedure(LLVMInstruction *ins) m_builder.SetInsertPoint(nextBranch); } - m_utils.reloadVariables(m_utils.targetVariables()); + m_utils.reloadVariables(); m_utils.reloadLists(); return ins->next; } diff --git a/src/engine/internal/llvm/instructions/string.cpp b/src/engine/internal/llvm/instructions/string.cpp index c973196b..8cf7a0da 100644 --- a/src/engine/internal/llvm/instructions/string.cpp +++ b/src/engine/internal/llvm/instructions/string.cpp @@ -85,7 +85,7 @@ LLVMInstruction *String::buildStringChar(LLVMInstruction *ins) const auto &arg1 = ins->args[0]; const auto &arg2 = ins->args[1]; llvm::Value *str = m_utils.castValue(arg1.second, arg1.first); - llvm::Value *index = m_builder.CreateFPToSI(m_utils.castValue(arg2.second, arg2.first), m_builder.getInt64Ty()); + llvm::Value *index = m_utils.castValue(arg2.second, arg2.first, LLVMBuildUtils::NumberType::Int); llvm::PointerType *charPointerType = m_builder.getInt16Ty()->getPointerTo(); llvm::StructType *stringPtrType = m_utils.compilerCtx()->stringPtrType(); @@ -131,6 +131,8 @@ LLVMInstruction *String::buildStringLength(LLVMInstruction *ins) llvm::Value *sizeField = m_builder.CreateStructGEP(stringPtrType, str, 1); llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), sizeField); ins->functionReturnReg->value = m_builder.CreateSIToFP(size, m_builder.getDoubleTy()); + ins->functionReturnReg->intValue = size; + ins->functionReturnReg->isInt = m_builder.getInt1(true); return ins->next; } diff --git a/src/engine/internal/llvm/instructions/variables.cpp b/src/engine/internal/llvm/instructions/variables.cpp index 9298c3d5..6eec5cf7 100644 --- a/src/engine/internal/llvm/instructions/variables.cpp +++ b/src/engine/internal/llvm/instructions/variables.cpp @@ -3,6 +3,7 @@ #include "variables.h" #include "../llvminstruction.h" #include "../llvmbuildutils.h" +#include "../llvmconstantregister.h" using namespace libscratchcpp; using namespace libscratchcpp::llvmins; @@ -43,31 +44,15 @@ ProcessResult Variables::process(LLVMInstruction *ins) LLVMInstruction *Variables::buildCreateLocalVariable(LLVMInstruction *ins) { assert(ins->args.empty()); - llvm::Type *type = nullptr; + LLVMLocalVariableInfo *info = ins->localVarInfo; - switch (ins->functionReturnReg->type()) { - case Compiler::StaticType::Number: - type = m_builder.getDoubleTy(); - break; - - case Compiler::StaticType::Bool: - type = m_builder.getInt1Ty(); - break; - - case Compiler::StaticType::String: - std::cerr << "error: local variables do not support string type" << std::endl; - break; - - case Compiler::StaticType::Pointer: - std::cerr << "error: local variables do not support pointer type" << std::endl; - break; - - default: - assert(false); - break; - } + info->isInt = m_utils.addAlloca(m_builder.getInt1Ty()); + info->intValue = m_utils.addAlloca(m_builder.getInt64Ty()); + m_builder.CreateStore(m_builder.getInt1(false), info->isInt); + m_builder.CreateStore(llvm::ConstantInt::get(m_builder.getInt64Ty(), 0, true), info->isInt); - ins->functionReturnReg->value = m_utils.addAlloca(type); + LLVMConstantRegister null(ins->functionReturnReg->type(), Value()); + ins->functionReturnReg->value = m_utils.createValue(&null); return ins->next; } @@ -76,8 +61,11 @@ LLVMInstruction *Variables::buildWriteLocalVariable(LLVMInstruction *ins) assert(ins->args.size() == 2); const auto &arg1 = ins->args[0]; const auto &arg2 = ins->args[1]; - llvm::Value *converted = m_utils.castValue(arg2.second, arg2.first); - m_builder.CreateStore(converted, arg1.second->value); + LLVMLocalVariableInfo *info = ins->localVarInfo; + llvm::Value *typeVar = m_utils.addAlloca(m_builder.getInt32Ty()); + m_builder.CreateStore(m_builder.getInt32(static_cast(m_utils.mapType(arg2.first))), typeVar); + + m_utils.createValueStore(arg1.second->value, typeVar, info->isInt, info->intValue, arg2.second, arg2.first, arg2.first); return ins->next; } @@ -85,23 +73,10 @@ LLVMInstruction *Variables::buildReadLocalVariable(LLVMInstruction *ins) { assert(ins->args.size() == 1); const auto &arg = ins->args[0]; - llvm::Type *type = nullptr; - - switch (ins->functionReturnReg->type()) { - case Compiler::StaticType::Number: - type = m_builder.getDoubleTy(); - break; - - case Compiler::StaticType::Bool: - type = m_builder.getInt1Ty(); - break; - - default: - assert(false); - break; - } - - ins->functionReturnReg->value = m_builder.CreateLoad(type, arg.second->value); + LLVMLocalVariableInfo *info = ins->localVarInfo; + ins->functionReturnReg->value = m_utils.castValue(arg.second, ins->functionReturnReg->type()); + ins->functionReturnReg->isInt = m_builder.CreateLoad(m_builder.getInt1Ty(), info->isInt); + ins->functionReturnReg->intValue = m_builder.CreateLoad(m_builder.getInt64Ty(), info->intValue); return ins->next; } @@ -111,30 +86,9 @@ LLVMInstruction *Variables::buildWriteVariable(LLVMInstruction *ins) const auto &arg = ins->args[0]; Compiler::StaticType argType = m_utils.optimizeRegisterType(arg.second); LLVMVariablePtr &varPtr = m_utils.variablePtr(ins->targetVariable); - varPtr.changed = true; // TODO: Handle loops and if statements - - // Initialize stack variable on first assignment - // TODO: Use stack in the top level (outside loops and if statements) - /*if (!varPtr.onStack) { - varPtr.onStack = true; - varPtr.type = type; // don't care about unknown type on first assignment - - ValueType mappedType; - - if (type == Compiler::StaticType::String || type == Compiler::StaticType::Unknown) { - // Value functions are used for these types, so don't break them - mappedType = ValueType::Number; - } else { - auto it = std::find_if(TYPE_MAP.begin(), TYPE_MAP.end(), [type](const std::pair &pair) { return pair.second == type; }); - assert(it != TYPE_MAP.cend()); - mappedType = it->first; - } - - llvm::Value *typeField = m_builder.CreateStructGEP(m_valueDataType, varPtr.stackPtr, 1); - m_builder.CreateStore(m_builder.getInt32(static_cast(mappedType)), typeField); - }*/ - - m_utils.createValueStore(arg.second, varPtr.stackPtr, ins->targetType, argType); + + m_utils.createValueStore(varPtr.stackPtr, m_utils.getValueTypePtr(varPtr.stackPtr), varPtr.isInt, varPtr.intValue, arg.second, ins->targetType, argType); + m_builder.CreateStore(m_builder.getInt1(true), varPtr.changed); return ins->next; } @@ -143,6 +97,8 @@ LLVMInstruction *Variables::buildReadVariable(LLVMInstruction *ins) assert(ins->args.size() == 0); LLVMVariablePtr &varPtr = m_utils.variablePtr(ins->targetVariable); - ins->functionReturnReg->value = varPtr.onStack && !(ins->loopCondition && !m_utils.warp()) ? varPtr.stackPtr : varPtr.heapPtr; + ins->functionReturnReg->value = varPtr.stackPtr; + ins->functionReturnReg->isInt = m_builder.CreateLoad(m_builder.getInt1Ty(), varPtr.isInt); + ins->functionReturnReg->intValue = m_builder.CreateLoad(m_builder.getInt64Ty(), varPtr.intValue); return ins->next; } diff --git a/src/engine/internal/llvm/llvmbuildutils.cpp b/src/engine/internal/llvm/llvmbuildutils.cpp index f61efe75..c8a70426 100644 --- a/src/engine/internal/llvm/llvmbuildutils.cpp +++ b/src/engine/internal/llvm/llvmbuildutils.cpp @@ -21,6 +21,13 @@ static std::unordered_map TYPE_MAP = { { ValueType::Pointer, Compiler::StaticType::Pointer } }; +static std::unordered_map REVERSE_TYPE_MAP = { + { Compiler::StaticType::Number, ValueType::Number }, + { Compiler::StaticType::Bool, ValueType::Bool }, + { Compiler::StaticType::String, ValueType::String }, + { Compiler::StaticType::Pointer, ValueType::Pointer } +}; + LLVMBuildUtils::LLVMBuildUtils(LLVMCompilerContext *ctx, llvm::IRBuilder<> &builder, Compiler::CodeType codeType) : m_ctx(ctx), m_llvmCtx(*ctx->llvmCtx()), @@ -34,7 +41,7 @@ LLVMBuildUtils::LLVMBuildUtils(LLVMCompilerContext *ctx, llvm::IRBuilder<> &buil createListMap(); } -void LLVMBuildUtils::init(llvm::Function *function, BlockPrototype *procedurePrototype, bool warp) +void LLVMBuildUtils::init(llvm::Function *function, BlockPrototype *procedurePrototype, bool warp, const std::vector> ®s) { m_function = function; m_procedurePrototype = procedurePrototype; @@ -56,33 +63,31 @@ void LLVMBuildUtils::init(llvm::Function *function, BlockPrototype *procedurePro if (!m_warp) m_coroutine = std::make_unique(m_ctx->module(), &m_builder, m_function); + // Init registers + for (auto reg : regs) { +#ifdef LLVM_INTEGER_SUPPORT + bool isIntConst = (reg->isConst() && optimizeRegisterType(reg.get()) == Compiler::StaticType::Number && reg->constValue().toDouble() == std::floor(reg->constValue().toDouble())); + reg->isInt = m_builder.getInt1(isIntConst); +#else + reg->isInt = m_builder.getInt1(false); +#endif + reg->intValue = llvm::ConstantInt::get(m_builder.getInt64Ty(), reg->constValue().toDouble(), true); + } + // Create variable pointers for (auto &[var, varPtr] : m_variablePtrs) { llvm::Value *ptr = getVariablePtr(m_targetVariables, var); - - // Direct access varPtr.heapPtr = ptr; - // All variables are currently created on the stack and synced later (seems to be faster) - // NOTE: Strings are NOT copied, only the pointer is copied - // TODO: Restore this feature - // varPtr.stackPtr = m_builder.CreateAlloca(m_valueDataType); - varPtr.stackPtr = varPtr.heapPtr; - varPtr.onStack = false; - - // If there are no write operations outside loops, initialize the stack variable now - /*Variable *variable = var; - // TODO: Loop scope was used here, replace it with some "inside loop" flag if needed - auto it = std::find_if(m_variableInstructions.begin(), m_variableInstructions.end(), [variable](const LLVMInstruction *ins) { - return ins->type == LLVMInstruction::Type::WriteVariable && ins->targetVariable == variable && !ins->loopScope; - }); - - if (it == m_variableInstructions.end()) { - createValueCopy(ptr, varPtr.stackPtr); - varPtr.onStack = true; - } else - varPtr.onStack = false; // use heap before the first assignment - */ + // Store variables locally to enable optimizations + varPtr.stackPtr = m_builder.CreateAlloca(m_valueDataType); + varPtr.changed = m_builder.CreateAlloca(m_builder.getInt1Ty()); + + // Integer support + varPtr.isInt = m_builder.CreateAlloca(m_builder.getInt1Ty(), nullptr, var->name() + ".isInt"); + varPtr.intValue = m_builder.CreateAlloca(m_builder.getInt64Ty(), nullptr, var->name() + ".intValue"); + m_builder.CreateStore(m_builder.getInt1(false), varPtr.isInt); + m_builder.CreateStore(llvm::ConstantInt::get(m_builder.getInt64Ty(), 0, true), varPtr.isInt); } // Create list pointers @@ -101,9 +106,17 @@ void LLVMBuildUtils::init(llvm::Function *function, BlockPrototype *procedurePro llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr); m_builder.CreateStore(size, listPtr.size); + + // Store list type info locally to leave static type analysis to LLVM + listPtr.hasNumber = m_builder.CreateAlloca(m_builder.getInt1Ty(), nullptr, list->name() + ".hasNumber"); + listPtr.hasBool = m_builder.CreateAlloca(m_builder.getInt1Ty(), nullptr, list->name() + ".hasBool"); + listPtr.hasString = m_builder.CreateAlloca(m_builder.getInt1Ty(), nullptr, list->name() + ".hasString"); } } + reloadVariables(); + reloadLists(); + // Create end branch m_endBranch = llvm::BasicBlock::Create(m_llvmCtx, "end", m_function); } @@ -113,10 +126,15 @@ void LLVMBuildUtils::end(LLVMInstruction *lastInstruction, LLVMRegister *lastCon assert(m_stringHeap.size() == 1); freeScopeHeap(); + // Sync + llvm::BasicBlock *syncBranch = llvm::BasicBlock::Create(m_llvmCtx, "sync", m_function); + m_builder.CreateBr(syncBranch); + + m_builder.SetInsertPoint(syncBranch); + syncVariables(); m_builder.CreateBr(m_endBranch); m_builder.SetInsertPoint(m_endBranch); - syncVariables(m_targetVariables); // End the script function llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0); @@ -285,6 +303,12 @@ LLVMCoroutine *LLVMBuildUtils::coroutine() const return m_coroutine.get(); } +void LLVMBuildUtils::createLocalVariableInfo(CompilerLocalVariable *variable) +{ + if (m_localVariables.find(variable) == m_localVariables.cend()) + m_localVariables[variable] = LLVMLocalVariableInfo(); +} + void LLVMBuildUtils::createVariablePtr(Variable *variable) { if (m_variablePtrs.find(variable) == m_variablePtrs.cend()) @@ -297,6 +321,12 @@ void LLVMBuildUtils::createListPtr(List *list) m_listPtrs[list] = LLVMListPtr(); } +LLVMLocalVariableInfo &LLVMBuildUtils::localVariableInfo(CompilerLocalVariable *variable) +{ + assert(m_localVariables.find(variable) != m_localVariables.cend()); + return m_localVariables[variable]; +} + LLVMVariablePtr &LLVMBuildUtils::variablePtr(Variable *variable) { assert(m_variablePtrs.find(variable) != m_variablePtrs.cend()); @@ -309,34 +339,48 @@ LLVMListPtr &LLVMBuildUtils::listPtr(List *list) return m_listPtrs[list]; } -void LLVMBuildUtils::syncVariables(llvm::Value *targetVariables) +void LLVMBuildUtils::syncVariables() { // Copy stack variables to the actual variables for (auto &[var, varPtr] : m_variablePtrs) { - if (varPtr.onStack && varPtr.changed) - createValueCopy(varPtr.stackPtr, getVariablePtr(targetVariables, var)); + llvm::BasicBlock *copyBlock = llvm::BasicBlock::Create(m_llvmCtx, "syncVar", m_function); + llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "syncVar.next", m_function); + m_builder.CreateCondBr(m_builder.CreateLoad(m_builder.getInt1Ty(), varPtr.changed), copyBlock, nextBlock); + + m_builder.SetInsertPoint(copyBlock); + createValueCopy(varPtr.stackPtr, getVariablePtr(m_targetVariables, var)); + m_builder.CreateStore(m_builder.getInt1(false), varPtr.changed); + m_builder.CreateBr(nextBlock); - varPtr.changed = false; + m_builder.SetInsertPoint(nextBlock); } } -void LLVMBuildUtils::reloadVariables(llvm::Value *targetVariables) +void LLVMBuildUtils::reloadVariables() { - // Reset variables to use heap + // Load variables to stack for (auto &[var, varPtr] : m_variablePtrs) { - varPtr.onStack = false; - varPtr.changed = false; + llvm::Value *ptr = getVariablePtr(m_targetVariables, var); + createValueCopy(ptr, varPtr.stackPtr); + m_builder.CreateStore(m_builder.getInt1(false), varPtr.isInt); + m_builder.CreateStore(m_builder.getInt1(false), varPtr.changed); } } void LLVMBuildUtils::reloadLists() { - // Load list size info - if (m_warp) { - for (auto &[list, listPtr] : m_listPtrs) { + // Load list size and type info + for (auto &[list, listPtr] : m_listPtrs) { + if (listPtr.size) { llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr); m_builder.CreateStore(size, listPtr.size); } + + if (listPtr.hasNumber && listPtr.hasBool && listPtr.hasString) { + m_builder.CreateStore(m_builder.getInt1(true), listPtr.hasNumber); + m_builder.CreateStore(m_builder.getInt1(true), listPtr.hasBool); + m_builder.CreateStore(m_builder.getInt1(true), listPtr.hasString); + } } } @@ -402,6 +446,12 @@ Compiler::StaticType LLVMBuildUtils::mapType(ValueType type) return TYPE_MAP[type]; } +ValueType LLVMBuildUtils::mapType(Compiler::StaticType type) +{ + assert(REVERSE_TYPE_MAP.find(type) != REVERSE_TYPE_MAP.cend()); + return REVERSE_TYPE_MAP[type]; +} + bool LLVMBuildUtils::isSingleType(Compiler::StaticType type) { // Check if the type is a power of 2 (only one bit set) @@ -418,17 +468,17 @@ llvm::Value *LLVMBuildUtils::addAlloca(llvm::Type *type) return ret; } -llvm::Value *LLVMBuildUtils::castValue(LLVMRegister *reg, Compiler::StaticType targetType) +llvm::Value *LLVMBuildUtils::castValue(LLVMRegister *reg, Compiler::StaticType targetType, NumberType targetNumType) { if (reg->isConst()) { if (!isSingleType(targetType)) return createValue(reg); else - return castConstValue(reg->constValue(), targetType); + return castConstValue(reg->constValue(), targetType, targetNumType); } if (reg->isRawValue) - return castRawValue(reg, targetType); + return castRawValue(reg, targetType, targetNumType); assert(reg->type() != Compiler::StaticType::Void && targetType != Compiler::StaticType::Void); @@ -441,7 +491,7 @@ llvm::Value *LLVMBuildUtils::castValue(LLVMRegister *reg, Compiler::StaticType t if (isSingleType(targetType)) { // Handle multiple source type cases with runtime switch - typePtr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 1); + typePtr = getValueTypePtr(reg); loadedType = m_builder.CreateLoad(m_builder.getInt32Ty(), typePtr); mergeBlock = llvm::BasicBlock::Create(m_llvmCtx, "merge", m_function); @@ -460,8 +510,20 @@ llvm::Value *LLVMBuildUtils::castValue(LLVMRegister *reg, Compiler::StaticType t sw->addCase(m_builder.getInt32(static_cast(ValueType::Number)), numberBlock); m_builder.SetInsertPoint(numberBlock); - llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0); - llvm::Value *numberResult = m_builder.CreateLoad(m_builder.getDoubleTy(), ptr); + llvm::Value *numberResult; + + if (targetNumType == NumberType::Int) { + // double/int -> int + llvm::Value *doubleInt = m_builder.CreateFPToSI(reg->intValue, m_builder.getInt64Ty()); + numberResult = m_builder.CreateSelect(reg->isInt, reg->intValue, doubleInt); + } else { + // double/int -> double + llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0); + llvm::Value *doubleValue = m_builder.CreateLoad(m_builder.getDoubleTy(), ptr); + llvm::Value *intDouble = m_builder.CreateSIToFP(reg->intValue, m_builder.getDoubleTy()); + numberResult = m_builder.CreateSelect(reg->isInt, intDouble, doubleValue); + } + m_builder.CreateBr(mergeBlock); results.push_back({ numberBlock, numberResult }); } @@ -474,7 +536,16 @@ llvm::Value *LLVMBuildUtils::castValue(LLVMRegister *reg, Compiler::StaticType t m_builder.SetInsertPoint(boolBlock); llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0); llvm::Value *boolValue = m_builder.CreateLoad(m_builder.getInt1Ty(), ptr); - llvm::Value *boolResult = m_builder.CreateUIToFP(boolValue, m_builder.getDoubleTy()); + llvm::Value *boolResult; + + if (targetNumType == NumberType::Int) { + // bool -> int + boolResult = m_builder.CreateZExt(boolValue, m_builder.getInt64Ty()); + } else { + // bool -> double + boolResult = m_builder.CreateUIToFP(boolValue, m_builder.getDoubleTy()); + } + m_builder.CreateBr(mergeBlock); results.push_back({ boolBlock, boolResult }); } @@ -488,6 +559,10 @@ llvm::Value *LLVMBuildUtils::castValue(LLVMRegister *reg, Compiler::StaticType t llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0); llvm::Value *stringPtr = m_builder.CreateLoad(m_stringPtrType->getPointerTo(), ptr); llvm::Value *stringResult = m_builder.CreateCall(m_functions.resolve_value_stringToDouble(), stringPtr); + + if (targetNumType == NumberType::Int) + stringResult = m_builder.CreateFPToSI(stringResult, m_builder.getInt64Ty()); + m_builder.CreateBr(mergeBlock); results.push_back({ stringBlock, stringResult }); } @@ -504,7 +579,11 @@ llvm::Value *LLVMBuildUtils::castValue(LLVMRegister *reg, Compiler::StaticType t m_builder.SetInsertPoint(numberBlock); llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0); llvm::Value *numberValue = m_builder.CreateLoad(m_builder.getDoubleTy(), ptr); - llvm::Value *numberResult = m_builder.CreateFCmpONE(numberValue, llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0))); + + llvm::Value *intResult = m_builder.CreateICmpNE(reg->intValue, llvm::ConstantInt::get(m_builder.getInt64Ty(), 0, true)); + llvm::Value *doubleResult = m_builder.CreateFCmpONE(numberValue, llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0))); + llvm::Value *numberResult = m_builder.CreateSelect(reg->isInt, intResult, doubleResult); + m_builder.CreateBr(mergeBlock); results.push_back({ numberBlock, numberResult }); } @@ -545,7 +624,11 @@ llvm::Value *LLVMBuildUtils::castValue(LLVMRegister *reg, Compiler::StaticType t m_builder.SetInsertPoint(numberBlock); llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0); - llvm::Value *value = m_builder.CreateLoad(m_builder.getDoubleTy(), ptr); + llvm::Value *doubleValue = m_builder.CreateLoad(m_builder.getDoubleTy(), ptr); + + llvm::Value *intCast = m_builder.CreateSIToFP(reg->intValue, m_builder.getDoubleTy()); + llvm::Value *value = m_builder.CreateSelect(reg->isInt, intCast, doubleValue); + llvm::Value *numberResult = m_builder.CreateCall(m_functions.resolve_value_doubleToStringPtr(), value); m_builder.CreateBr(mergeBlock); results.push_back({ numberBlock, numberResult }); @@ -597,7 +680,12 @@ llvm::Value *LLVMBuildUtils::castValue(LLVMRegister *reg, Compiler::StaticType t // Create phi node to merge results m_builder.SetInsertPoint(mergeBlock); - llvm::PHINode *result = m_builder.CreatePHI(getType(targetType, false), results.size()); + llvm::Type *phiType = getType(targetType, false); + + if (targetType == Compiler::StaticType::Number && targetNumType == NumberType::Int) + phiType = m_builder.getInt64Ty(); + + llvm::PHINode *result = m_builder.CreatePHI(phiType, results.size()); for (auto &pair : results) result->addIncoming(pair.second, pair.first); @@ -634,6 +722,7 @@ llvm::Type *LLVMBuildUtils::getType(Compiler::StaticType type, bool isReturnType llvm::Value *LLVMBuildUtils::isNaN(llvm::Value *num) { + assert(num->getType() == m_builder.getDoubleTy()); return m_builder.CreateFCmpUNO(num, num); } @@ -643,30 +732,20 @@ llvm::Value *LLVMBuildUtils::removeNaN(llvm::Value *num) return m_builder.CreateSelect(isNaN(num), llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0)), num); } -void LLVMBuildUtils::createValueStore(LLVMRegister *reg, llvm::Value *destPtr, Compiler::StaticType destType, Compiler::StaticType targetType) +void LLVMBuildUtils::createValueStore( + llvm::Value *destPtr, + llvm::Value *destTypePtr, + llvm::Value *destIsIntVar, + llvm::Value *destIntVar, + LLVMRegister *reg, + Compiler::StaticType destType, + Compiler::StaticType targetType) { - llvm::Value *targetPtr = nullptr; - const bool targetTypeIsSingle = isSingleType(targetType); - - if (targetTypeIsSingle) - targetPtr = castValue(reg, targetType); - - auto it = std::find_if(TYPE_MAP.begin(), TYPE_MAP.end(), [targetType](const std::pair &pair) { return pair.second == targetType; }); - const ValueType mappedType = it == TYPE_MAP.cend() ? ValueType::Number : it->first; // unknown type can be ignored - assert(!(reg->isRawValue && it == TYPE_MAP.cend())); + assert(destIsIntVar->getType()->isPointerTy()); + assert(destIntVar->getType()->isPointerTy()); // Handle multiple type cases with runtime switch - llvm::Value *loadedTargetType = nullptr; - - if (reg->isRawValue) - loadedTargetType = m_builder.getInt32(static_cast(mappedType)); - else { - assert(!reg->isConst()); - llvm::Value *targetTypePtr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 1); - loadedTargetType = m_builder.CreateLoad(m_builder.getInt32Ty(), targetTypePtr); - } - - llvm::Value *destTypePtr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 1); + llvm::Value *loadedTargetType = loadRegisterType(reg, targetType); llvm::Value *loadedDestType = m_builder.CreateLoad(m_builder.getInt32Ty(), destTypePtr); llvm::BasicBlock *mergeBlock = llvm::BasicBlock::Create(m_llvmCtx, "merge", m_function); @@ -689,14 +768,15 @@ void LLVMBuildUtils::createValueStore(LLVMRegister *reg, llvm::Value *destPtr, C m_builder.SetInsertPoint(numberBlock); // Load number - if (!targetTypeIsSingle) { - llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0); - targetPtr = m_builder.CreateLoad(m_builder.getDoubleTy(), ptr); - } + llvm::Value *doubleValue = castValue(reg, Compiler::StaticType::Number, NumberType::Double); // Write number to number directly llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 0); - m_builder.CreateStore(targetPtr, ptr); + m_builder.CreateStore(doubleValue, ptr); + + m_builder.CreateStore(reg->isInt, destIsIntVar); + m_builder.CreateStore(reg->intValue, destIntVar); + m_builder.CreateBr(mergeBlock); } @@ -707,16 +787,12 @@ void LLVMBuildUtils::createValueStore(LLVMRegister *reg, llvm::Value *destPtr, C m_builder.SetInsertPoint(boolBlock); // Load bool - if (!targetTypeIsSingle) { - llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0); - targetPtr = m_builder.CreateLoad(m_builder.getInt1Ty(), ptr); - } + llvm::Value *targetPtr = castValue(reg, Compiler::StaticType::Bool); // Write bool to number value directly and change type llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 0); m_builder.CreateStore(targetPtr, ptr); - llvm::Value *typePtr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 1); - m_builder.CreateStore(m_builder.getInt32(static_cast(ValueType::Bool)), typePtr); + m_builder.CreateStore(m_builder.getInt32(static_cast(ValueType::Bool)), destTypePtr); m_builder.CreateBr(mergeBlock); } @@ -727,17 +803,13 @@ void LLVMBuildUtils::createValueStore(LLVMRegister *reg, llvm::Value *destPtr, C m_builder.SetInsertPoint(stringBlock); // Load string pointer - if (!targetTypeIsSingle) { - llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0); - targetPtr = m_builder.CreateLoad(m_stringPtrType->getPointerTo(), ptr); - } + llvm::Value *targetPtr = castValue(reg, Compiler::StaticType::String); // Create a new string, change type and assign llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 0); llvm::Value *destStringPtr = m_builder.CreateCall(m_functions.resolve_string_pool_new(), m_builder.getInt1(false)); - llvm::Value *typePtr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 1); - m_builder.CreateStore(m_builder.getInt32(static_cast(ValueType::String)), typePtr); + m_builder.CreateStore(m_builder.getInt32(static_cast(ValueType::String)), destTypePtr); m_builder.CreateStore(destStringPtr, ptr); m_builder.CreateCall(m_functions.resolve_string_assign(), { destStringPtr, targetPtr }); @@ -760,16 +832,16 @@ void LLVMBuildUtils::createValueStore(LLVMRegister *reg, llvm::Value *destPtr, C m_builder.SetInsertPoint(numberBlock); // Load number - if (!targetTypeIsSingle) { - llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0); - targetPtr = m_builder.CreateLoad(m_builder.getDoubleTy(), ptr); - } + llvm::Value *doubleValue = castValue(reg, Compiler::StaticType::Number, NumberType::Double); // Write number to bool value directly and change type llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 0); - llvm::Value *typePtr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 1); - m_builder.CreateStore(targetPtr, ptr); - m_builder.CreateStore(m_builder.getInt32(static_cast(ValueType::Number)), typePtr); + m_builder.CreateStore(doubleValue, ptr); + m_builder.CreateStore(m_builder.getInt32(static_cast(ValueType::Number)), destTypePtr); + + m_builder.CreateStore(reg->isInt, destIsIntVar); + m_builder.CreateStore(reg->intValue, destIntVar); + m_builder.CreateBr(mergeBlock); } @@ -780,10 +852,7 @@ void LLVMBuildUtils::createValueStore(LLVMRegister *reg, llvm::Value *destPtr, C m_builder.SetInsertPoint(boolBlock); // Load bool - if (!targetTypeIsSingle) { - llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0); - targetPtr = m_builder.CreateLoad(m_builder.getInt1Ty(), ptr); - } + llvm::Value *targetPtr = castValue(reg, Compiler::StaticType::Bool); // Write bool to bool directly llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 0); @@ -798,17 +867,13 @@ void LLVMBuildUtils::createValueStore(LLVMRegister *reg, llvm::Value *destPtr, C m_builder.SetInsertPoint(stringBlock); // Load string pointer - if (!targetTypeIsSingle) { - llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0); - targetPtr = m_builder.CreateLoad(m_stringPtrType->getPointerTo(), ptr); - } + llvm::Value *targetPtr = castValue(reg, Compiler::StaticType::String); // Create a new string, change type and assign llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 0); llvm::Value *destStringPtr = m_builder.CreateCall(m_functions.resolve_string_pool_new(), m_builder.getInt1(false)); - llvm::Value *typePtr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 1); - m_builder.CreateStore(m_builder.getInt32(static_cast(ValueType::String)), typePtr); + m_builder.CreateStore(m_builder.getInt32(static_cast(ValueType::String)), destTypePtr); m_builder.CreateStore(destStringPtr, ptr); m_builder.CreateCall(m_functions.resolve_string_assign(), { destStringPtr, targetPtr }); @@ -831,18 +896,18 @@ void LLVMBuildUtils::createValueStore(LLVMRegister *reg, llvm::Value *destPtr, C m_builder.SetInsertPoint(numberBlock); // Load number - if (!targetTypeIsSingle) { - llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0); - targetPtr = m_builder.CreateLoad(m_builder.getDoubleTy(), ptr); - } + llvm::Value *doubleValue = castValue(reg, Compiler::StaticType::Number, NumberType::Double); // Free the string, write the number and change type llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 0); llvm::Value *destStringPtr = m_builder.CreateLoad(m_stringPtrType->getPointerTo(), ptr); m_builder.CreateCall(m_functions.resolve_string_pool_free(), destStringPtr); - m_builder.CreateStore(targetPtr, ptr); - llvm::Value *typePtr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 1); - m_builder.CreateStore(m_builder.getInt32(static_cast(ValueType::Number)), typePtr); + m_builder.CreateStore(doubleValue, ptr); + m_builder.CreateStore(m_builder.getInt32(static_cast(ValueType::Number)), destTypePtr); + + m_builder.CreateStore(reg->isInt, destIsIntVar); + m_builder.CreateStore(reg->intValue, destIntVar); + m_builder.CreateBr(mergeBlock); } @@ -853,18 +918,14 @@ void LLVMBuildUtils::createValueStore(LLVMRegister *reg, llvm::Value *destPtr, C m_builder.SetInsertPoint(boolBlock); // Load bool - if (!targetTypeIsSingle) { - llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0); - targetPtr = m_builder.CreateLoad(m_builder.getInt1Ty(), ptr); - } + llvm::Value *targetPtr = castValue(reg, Compiler::StaticType::Bool); // Free the string, write the bool and change type llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 0); llvm::Value *destStringPtr = m_builder.CreateLoad(m_stringPtrType->getPointerTo(), ptr); m_builder.CreateCall(m_functions.resolve_string_pool_free(), destStringPtr); m_builder.CreateStore(targetPtr, ptr); - llvm::Value *typePtr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 1); - m_builder.CreateStore(m_builder.getInt32(static_cast(ValueType::Bool)), typePtr); + m_builder.CreateStore(m_builder.getInt32(static_cast(ValueType::Bool)), destTypePtr); m_builder.CreateBr(mergeBlock); } @@ -875,10 +936,7 @@ void LLVMBuildUtils::createValueStore(LLVMRegister *reg, llvm::Value *destPtr, C m_builder.SetInsertPoint(stringBlock); // Load string pointer - if (!targetTypeIsSingle) { - llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, reg->value, 0); - targetPtr = m_builder.CreateLoad(m_stringPtrType->getPointerTo(), ptr); - } + llvm::Value *targetPtr = castValue(reg, Compiler::StaticType::String); // Assign string directly llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, destPtr, 0); @@ -895,10 +953,23 @@ void LLVMBuildUtils::createValueStore(LLVMRegister *reg, llvm::Value *destPtr, C m_builder.SetInsertPoint(mergeBlock); } -void LLVMBuildUtils::createValueStore(LLVMRegister *reg, llvm::Value *destPtr, Compiler::StaticType targetType) +void LLVMBuildUtils::createValueStore(llvm::Value *destPtr, llvm::Value *destTypePtr, llvm::Value *destIsIntVar, llvm::Value *destIntVar, LLVMRegister *reg, Compiler::StaticType targetType) { // Same as createValueStore(), but the destination type is unknown at compile time - createValueStore(reg, destPtr, Compiler::StaticType::Unknown, targetType); + createValueStore(destPtr, destTypePtr, destIsIntVar, destIntVar, reg, Compiler::StaticType::Unknown, targetType); +} + +llvm::Value *LLVMBuildUtils::getValueTypePtr(llvm::Value *value) +{ + return m_builder.CreateStructGEP(m_valueDataType, value, 1); +} + +llvm::Value *LLVMBuildUtils::getValueTypePtr(LLVMRegister *reg) +{ + if (reg->typeVar) + return reg->typeVar; + else + return getValueTypePtr(reg->value); } llvm::Value *LLVMBuildUtils::getListSize(const LLVMListPtr &listPtr) @@ -932,10 +1003,13 @@ llvm::Value *LLVMBuildUtils::getListItemIndex(const LLVMListPtr &listPtr, Compil m_builder.CreateCondBr(cond, bodyBlock, notFoundBlock); // if (list[index] == item) + // TODO: Add integer support for lists m_builder.SetInsertPoint(bodyBlock); LLVMRegister currentItem(listType); currentItem.isRawValue = false; currentItem.value = getListItem(listPtr, m_builder.CreateLoad(m_builder.getInt64Ty(), index)); + currentItem.isInt = m_builder.getInt1(false); + currentItem.intValue = llvm::ConstantInt::get(m_builder.getInt64Ty(), 0, true); llvm::Value *cmp = createComparison(¤tItem, item, Comparison::EQ); m_builder.CreateCondBr(cmp, cmpIfBlock, cmpElseBlock); @@ -965,7 +1039,7 @@ llvm::Value *LLVMBuildUtils::createValue(LLVMRegister *reg) { if (reg->isConst()) { // Create a constant ValueData instance and store it - llvm::Constant *value = castConstValue(reg->constValue(), TYPE_MAP[reg->constValue().type()]); + llvm::Constant *value = castConstValue(reg->constValue(), TYPE_MAP[reg->constValue().type()], NumberType::Double); llvm::Value *ret = addAlloca(m_valueDataType); switch (reg->constValue().type()) { @@ -995,7 +1069,7 @@ llvm::Value *LLVMBuildUtils::createValue(LLVMRegister *reg) return ret; } else if (reg->isRawValue) { - llvm::Value *value = castRawValue(reg, reg->type()); + llvm::Value *value = castRawValue(reg, reg->type(), NumberType::Double); llvm::Value *ret = addAlloca(m_valueDataType); // Store value @@ -1010,7 +1084,7 @@ llvm::Value *LLVMBuildUtils::createValue(LLVMRegister *reg) } // Store type - llvm::Value *typeField = m_builder.CreateStructGEP(m_valueDataType, ret, 1); + llvm::Value *typeField = getValueTypePtr(ret); ValueType type = it->first; m_builder.CreateStore(m_builder.getInt32(static_cast(type)), typeField); @@ -1059,26 +1133,6 @@ llvm::Value *LLVMBuildUtils::createComparison(LLVMRegister *arg1, LLVMRegister * return m_builder.getInt1(result); } else { - // Optimize comparison of constant with number/bool - if (arg1->isConst() && arg1->constValue().isValidNumber() && (type2 == Compiler::StaticType::Number || type2 == Compiler::StaticType::Bool)) - type1 = Compiler::StaticType::Number; - - if (arg2->isConst() && arg2->constValue().isValidNumber() && (type1 == Compiler::StaticType::Number || type1 == Compiler::StaticType::Bool)) - type2 = Compiler::StaticType::Number; - - // Optimize number and bool comparison - int optNumberBool = 0; - - if (type1 == Compiler::StaticType::Number && type2 == Compiler::StaticType::Bool) { - type2 = Compiler::StaticType::Number; - optNumberBool = 2; // operand 2 was bool - } - - if (type1 == Compiler::StaticType::Bool && type2 == Compiler::StaticType::Number) { - type1 = Compiler::StaticType::Number; - optNumberBool = 1; // operand 1 was bool - } - // Optimize number and string constant comparison // TODO: GT and LT comparison can be optimized here (e. g. by checking the string constant characters and comparing with numbers and .+-e) if (type == Comparison::EQ) { @@ -1087,132 +1141,162 @@ llvm::Value *LLVMBuildUtils::createComparison(LLVMRegister *arg1, LLVMRegister * return m_builder.getInt1(false); } - if (type1 != type2 || !isSingleType(type1) || !isSingleType(type2)) { - // If the types are different or at least one of them - // is unknown, we must use value functions - llvm::Value *value1 = createValue(arg1); - llvm::Value *value2 = createValue(arg2); + // Handle multiple type cases with runtime switch + llvm::Value *loadedType1 = loadRegisterType(arg1, type1); + llvm::Value *loadedType2 = loadRegisterType(arg2, type2); - switch (type) { - case Comparison::EQ: - return m_builder.CreateCall(m_functions.resolve_value_equals(), { value1, value2 }); + llvm::BasicBlock *mergeBlock = llvm::BasicBlock::Create(m_llvmCtx, "merge", m_function); + llvm::BasicBlock *defaultBlock = llvm::BasicBlock::Create(m_llvmCtx, "default", m_function); + std::vector> results; - case Comparison::GT: - return m_builder.CreateCall(m_functions.resolve_value_greater(), { value1, value2 }); + llvm::SwitchInst *sw1 = m_builder.CreateSwitch(loadedType1, defaultBlock, 4); - case Comparison::LT: - return m_builder.CreateCall(m_functions.resolve_value_lower(), { value1, value2 }); + if ((type1 & Compiler::StaticType::Number) == Compiler::StaticType::Number) { + llvm::BasicBlock *numberBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + sw1->addCase(m_builder.getInt32(static_cast(ValueType::Number)), numberBlock); + m_builder.SetInsertPoint(numberBlock); - default: - assert(false); - return nullptr; + llvm::SwitchInst *sw2 = m_builder.CreateSwitch(loadedType2, defaultBlock, 4); + + if ((type2 & Compiler::StaticType::Number) == Compiler::StaticType::Number) { + // Number and number comparison + llvm::BasicBlock *block = llvm::BasicBlock::Create(m_llvmCtx, "numberAndNumberComparison", m_function); + sw2->addCase(m_builder.getInt32(static_cast(ValueType::Number)), block); + m_builder.SetInsertPoint(block); + + llvm::Value *result = createNumberAndNumberComparison(arg1, arg2, type); + + results.push_back({ m_builder.GetInsertBlock(), result }); + m_builder.CreateBr(mergeBlock); } - } else { - // Compare raw values - llvm::Value *value1 = castValue(arg1, type1); - llvm::Value *value2 = castValue(arg2, type2); - assert(type1 == type2); - switch (type1) { - case Compiler::StaticType::Number: { - // Compare two numbers - switch (type) { - case Comparison::EQ: { - llvm::Value *nan = m_builder.CreateAnd(isNaN(value1), isNaN(value2)); // NaN == NaN - llvm::Value *cmp = m_builder.CreateFCmpOEQ(value1, value2); - return m_builder.CreateSelect(nan, m_builder.getInt1(true), cmp); - } - - case Comparison::GT: { - llvm::Value *bothNan = m_builder.CreateAnd(isNaN(value1), isNaN(value2)); // NaN == NaN - llvm::Value *cmp = m_builder.CreateFCmpOGT(value1, value2); - llvm::Value *nan; - llvm::Value *nanCmp; - - if (optNumberBool == 1) { - nan = isNaN(value2); - nanCmp = castValue(arg1, Compiler::StaticType::Bool); - } else if (optNumberBool == 2) { - nan = isNaN(value1); - nanCmp = m_builder.CreateNot(castValue(arg2, Compiler::StaticType::Bool)); - } else { - nan = isNaN(value1); - nanCmp = m_builder.CreateFCmpUGT(value1, value2); - } - - return m_builder.CreateAnd(m_builder.CreateNot(bothNan), m_builder.CreateSelect(nan, nanCmp, cmp)); - } - - case Comparison::LT: { - llvm::Value *bothNan = m_builder.CreateAnd(isNaN(value1), isNaN(value2)); // NaN == NaN - llvm::Value *cmp = m_builder.CreateFCmpOLT(value1, value2); - llvm::Value *nan; - llvm::Value *nanCmp; - - if (optNumberBool == 1) { - nan = isNaN(value2); - nanCmp = m_builder.CreateNot(castValue(arg1, Compiler::StaticType::Bool)); - } else if (optNumberBool == 2) { - nan = isNaN(value1); - nanCmp = castValue(arg2, Compiler::StaticType::Bool); - } else { - nan = isNaN(value2); - nanCmp = m_builder.CreateFCmpULT(value1, value2); - } - - return m_builder.CreateAnd(m_builder.CreateNot(bothNan), m_builder.CreateSelect(nan, nanCmp, cmp)); - } - - default: - assert(false); - return nullptr; - } - } + if ((type2 & Compiler::StaticType::Bool) == Compiler::StaticType::Bool) { + // Number and bool comparison + llvm::BasicBlock *block = llvm::BasicBlock::Create(m_llvmCtx, "numberAndBoolComparison", m_function); + sw2->addCase(m_builder.getInt32(static_cast(ValueType::Bool)), block); + m_builder.SetInsertPoint(block); - case Compiler::StaticType::Bool: - // Compare two booleans - switch (type) { - case Comparison::EQ: - return m_builder.CreateICmpEQ(value1, value2); + llvm::Value *result = createNumberAndBoolComparison(arg1, arg2, type); - case Comparison::GT: - // value1 && !value2 - return m_builder.CreateAnd(value1, m_builder.CreateNot(value2)); + results.push_back({ m_builder.GetInsertBlock(), result }); + m_builder.CreateBr(mergeBlock); + } - case Comparison::LT: - // value2 && !value1 - return m_builder.CreateAnd(value2, m_builder.CreateNot(value1)); + if ((type2 & Compiler::StaticType::String) == Compiler::StaticType::String) { + // Number and string comparison + llvm::BasicBlock *block = llvm::BasicBlock::Create(m_llvmCtx, "numberAndStringComparison", m_function); + sw2->addCase(m_builder.getInt32(static_cast(ValueType::String)), block); + m_builder.SetInsertPoint(block); - default: - assert(false); - return nullptr; - } + llvm::Value *result = createNumberAndStringComparison(arg1, arg2, type); - case Compiler::StaticType::String: { - // Compare two strings - llvm::Value *cmpRet = m_builder.CreateCall(m_functions.resolve_string_compare_case_insensitive(), { value1, value2 }); + results.push_back({ m_builder.GetInsertBlock(), result }); + m_builder.CreateBr(mergeBlock); + } + } - switch (type) { - case Comparison::EQ: - return m_builder.CreateICmpEQ(cmpRet, m_builder.getInt32(0)); + if ((type1 & Compiler::StaticType::Bool) == Compiler::StaticType::Bool) { + llvm::BasicBlock *boolBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + sw1->addCase(m_builder.getInt32(static_cast(ValueType::Bool)), boolBlock); + m_builder.SetInsertPoint(boolBlock); - case Comparison::GT: - return m_builder.CreateICmpSGT(cmpRet, m_builder.getInt32(0)); + llvm::SwitchInst *sw2 = m_builder.CreateSwitch(loadedType2, defaultBlock, 4); - case Comparison::LT: - return m_builder.CreateICmpSLT(cmpRet, m_builder.getInt32(0)); + if ((type2 & Compiler::StaticType::Number) == Compiler::StaticType::Number) { + // Bool and number comparison + llvm::BasicBlock *block = llvm::BasicBlock::Create(m_llvmCtx, "boolAndNumberComparison", m_function); + sw2->addCase(m_builder.getInt32(static_cast(ValueType::Number)), block); + m_builder.SetInsertPoint(block); - default: - assert(false); - return nullptr; - } - } + llvm::Value *result = createNumberAndBoolComparison(arg2, arg1, swapComparisonArgs(type)); - default: - assert(false); - return nullptr; + results.push_back({ m_builder.GetInsertBlock(), result }); + m_builder.CreateBr(mergeBlock); + } + + if ((type2 & Compiler::StaticType::Bool) == Compiler::StaticType::Bool) { + // Bool and bool comparison + llvm::BasicBlock *block = llvm::BasicBlock::Create(m_llvmCtx, "boolAndBoolComparison", m_function); + sw2->addCase(m_builder.getInt32(static_cast(ValueType::Bool)), block); + m_builder.SetInsertPoint(block); + + llvm::Value *result = createBoolAndBoolComparison(arg1, arg2, type); + + results.push_back({ m_builder.GetInsertBlock(), result }); + m_builder.CreateBr(mergeBlock); + } + + if ((type2 & Compiler::StaticType::String) == Compiler::StaticType::String) { + // Bool and string comparison + llvm::BasicBlock *block = llvm::BasicBlock::Create(m_llvmCtx, "boolAndStringComparison", m_function); + sw2->addCase(m_builder.getInt32(static_cast(ValueType::String)), block); + m_builder.SetInsertPoint(block); + + llvm::Value *result = createBoolAndStringComparison(arg1, arg2, type); + + results.push_back({ m_builder.GetInsertBlock(), result }); + m_builder.CreateBr(mergeBlock); } } + + if ((type1 & Compiler::StaticType::String) == Compiler::StaticType::String) { + llvm::BasicBlock *stringBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + sw1->addCase(m_builder.getInt32(static_cast(ValueType::String)), stringBlock); + m_builder.SetInsertPoint(stringBlock); + + llvm::SwitchInst *sw2 = m_builder.CreateSwitch(loadedType2, defaultBlock, 4); + + if ((type2 & Compiler::StaticType::Number) == Compiler::StaticType::Number) { + // String and number comparison + llvm::BasicBlock *block = llvm::BasicBlock::Create(m_llvmCtx, "stringAndNumberComparison", m_function); + sw2->addCase(m_builder.getInt32(static_cast(ValueType::Number)), block); + m_builder.SetInsertPoint(block); + + llvm::Value *result = createNumberAndStringComparison(arg2, arg1, swapComparisonArgs(type)); + + results.push_back({ m_builder.GetInsertBlock(), result }); + m_builder.CreateBr(mergeBlock); + } + + if ((type2 & Compiler::StaticType::Bool) == Compiler::StaticType::Bool) { + // String and bool comparison + llvm::BasicBlock *block = llvm::BasicBlock::Create(m_llvmCtx, "stringAndBoolComparison", m_function); + sw2->addCase(m_builder.getInt32(static_cast(ValueType::Bool)), block); + m_builder.SetInsertPoint(block); + + llvm::Value *result = createBoolAndStringComparison(arg2, arg1, swapComparisonArgs(type)); + + results.push_back({ m_builder.GetInsertBlock(), result }); + m_builder.CreateBr(mergeBlock); + } + + if ((type2 & Compiler::StaticType::String) == Compiler::StaticType::String) { + // String and string comparison + llvm::BasicBlock *block = llvm::BasicBlock::Create(m_llvmCtx, "stringAndStringComparison", m_function); + sw2->addCase(m_builder.getInt32(static_cast(ValueType::String)), block); + m_builder.SetInsertPoint(block); + + llvm::Value *result = createStringAndStringComparison(arg1, arg2, type); + + results.push_back({ m_builder.GetInsertBlock(), result }); + m_builder.CreateBr(mergeBlock); + } + } + + // Default case + m_builder.SetInsertPoint(defaultBlock); + + // All possible types are covered, mark as unreachable + m_builder.CreateUnreachable(); + + // Create phi node to merge results + m_builder.SetInsertPoint(mergeBlock); + llvm::PHINode *result = m_builder.CreatePHI(m_builder.getInt1Ty(), results.size()); + + for (auto &pair : results) + result->addIncoming(pair.second, pair.first); + + return result; } } @@ -1264,9 +1348,9 @@ void LLVMBuildUtils::createSuspend() m_builder.SetInsertPoint(suspendBranch); } - syncVariables(m_targetVariables); + syncVariables(); m_coroutine->createSuspend(); - reloadVariables(m_targetVariables); + reloadVariables(); reloadLists(); if (m_warpArg) { @@ -1342,21 +1426,45 @@ void LLVMBuildUtils::createListMap() } } -llvm::Value *LLVMBuildUtils::castRawValue(LLVMRegister *reg, Compiler::StaticType targetType) +llvm::Value *LLVMBuildUtils::loadRegisterType(LLVMRegister *reg, Compiler::StaticType type) { - if (reg->type() == targetType) - return reg->value; + if (reg->isRawValue) + return m_builder.getInt32(static_cast(mapType(type))); + else { + assert(!reg->isConst()); + llvm::Value *typePtr = getValueTypePtr(reg); + return m_builder.CreateLoad(m_builder.getInt32Ty(), typePtr); + } +} + +llvm::Value *LLVMBuildUtils::castRawValue(LLVMRegister *reg, Compiler::StaticType targetType, NumberType targetNumType) +{ + if (reg->type() == targetType) { + if (targetType == Compiler::StaticType::Number && targetNumType == NumberType::Int) { + llvm::Value *cast = m_builder.CreateFPToSI(reg->value, m_builder.getInt64Ty()); + return m_builder.CreateSelect(reg->isInt, reg->intValue, cast); + } else + return reg->value; + } switch (targetType) { case Compiler::StaticType::Number: switch (reg->type()) { case Compiler::StaticType::Bool: - // Cast bool to double - return m_builder.CreateUIToFP(reg->value, m_builder.getDoubleTy()); + // Cast bool to double/int + if (targetNumType == NumberType::Int) + return m_builder.CreateZExt(reg->value, m_builder.getInt64Ty()); + else + return m_builder.CreateUIToFP(reg->value, m_builder.getDoubleTy()); case Compiler::StaticType::String: { - // Convert string to double - return m_builder.CreateCall(m_functions.resolve_value_stringToDouble(), reg->value); + // Convert string to double/int + llvm::Value *doubleValue = m_builder.CreateCall(m_functions.resolve_value_stringToDouble(), reg->value); + + if (targetNumType == NumberType::Int) + return m_builder.CreateFPToSI(doubleValue, m_builder.getInt64Ty()); + else + return doubleValue; } default: @@ -1366,9 +1474,12 @@ llvm::Value *LLVMBuildUtils::castRawValue(LLVMRegister *reg, Compiler::StaticTyp case Compiler::StaticType::Bool: switch (reg->type()) { - case Compiler::StaticType::Number: - // Cast double to bool (true if != 0) - return m_builder.CreateFCmpONE(reg->value, llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0))); + case Compiler::StaticType::Number: { + // Cast double/int to bool (true if != 0) + llvm::Value *intResult = m_builder.CreateICmpNE(reg->intValue, llvm::ConstantInt::get(m_builder.getInt64Ty(), 0, true)); + llvm::Value *doubleResult = m_builder.CreateFCmpONE(reg->value, llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0))); + return m_builder.CreateSelect(reg->isInt, intResult, doubleResult); + } case Compiler::StaticType::String: // Convert string to bool @@ -1382,8 +1493,10 @@ llvm::Value *LLVMBuildUtils::castRawValue(LLVMRegister *reg, Compiler::StaticTyp case Compiler::StaticType::String: switch (reg->type()) { case Compiler::StaticType::Number: { - // Convert double to string - llvm::Value *ptr = m_builder.CreateCall(m_functions.resolve_value_doubleToStringPtr(), reg->value); + // Convert double/int to string + llvm::Value *intCast = m_builder.CreateSIToFP(reg->intValue, m_builder.getDoubleTy()); + llvm::Value *doubleValue = m_builder.CreateSelect(reg->isInt, intCast, reg->value); + llvm::Value *ptr = m_builder.CreateCall(m_functions.resolve_value_doubleToStringPtr(), doubleValue); freeStringLater(ptr); return ptr; } @@ -1409,12 +1522,17 @@ llvm::Value *LLVMBuildUtils::castRawValue(LLVMRegister *reg, Compiler::StaticTyp } } -llvm::Constant *LLVMBuildUtils::castConstValue(const Value &value, Compiler::StaticType targetType) +llvm::Constant *LLVMBuildUtils::castConstValue(const Value &value, Compiler::StaticType targetType, NumberType targetNumType) { switch (targetType) { case Compiler::StaticType::Number: { const double nan = std::numeric_limits::quiet_NaN(); - return llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(value.isNaN() ? nan : value.toDouble())); + const double num = value.toDouble(); + + if (targetNumType == NumberType::Int) + return llvm::ConstantInt::get(m_builder.getInt64Ty(), num, true); + else + return llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(value.isNaN() ? nan : num)); } case Compiler::StaticType::Bool: @@ -1466,6 +1584,332 @@ void LLVMBuildUtils::copyStructField(llvm::Value *source, llvm::Value *target, i m_builder.CreateStore(m_builder.CreateLoad(fieldType, sourceField), targetField); } +LLVMBuildUtils::Comparison LLVMBuildUtils::swapComparisonArgs(Comparison type) +{ + switch (type) { + case Comparison::GT: + return Comparison::LT; + + case Comparison::LT: + return Comparison::GT; + + default: + return Comparison::EQ; + } +} + +llvm::Value *LLVMBuildUtils::createNumberAndNumberComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type) +{ + llvm::Value *double1 = castValue(arg1, Compiler::StaticType::Number, NumberType::Double); + llvm::Value *double2 = castValue(arg2, Compiler::StaticType::Number, NumberType::Double); + + llvm::Value *int1 = castValue(arg1, Compiler::StaticType::Number, NumberType::Int); + llvm::Value *int2 = castValue(arg2, Compiler::StaticType::Number, NumberType::Int); + + llvm::Value *isInt = m_builder.CreateAnd(arg1->isInt, arg2->isInt); + llvm::Value *bothNan = m_builder.CreateAnd(isNaN(double1), isNaN(double2)); // NaN == NaN + + switch (type) { + case Comparison::EQ: { + llvm::Value *fcmp = m_builder.CreateFCmpOEQ(double1, double2); + llvm::Value *doubleResult = m_builder.CreateOr(bothNan, fcmp); + llvm::Value *icmp = m_builder.CreateICmpEQ(int1, int2); + return m_builder.CreateSelect(isInt, icmp, doubleResult); + } + + case Comparison::GT: { + llvm::Value *fcmp = m_builder.CreateFCmpOGT(double1, double2); + llvm::Value *nan = isNaN(double1); + llvm::Value *nanCmp = m_builder.CreateFCmpUGT(double1, double2); + llvm::Value *doubleResult = m_builder.CreateAnd(m_builder.CreateNot(bothNan), m_builder.CreateSelect(nan, nanCmp, fcmp)); + llvm::Value *icmp = m_builder.CreateICmpSGT(int1, int2); + return m_builder.CreateSelect(isInt, icmp, doubleResult); + } + + case Comparison::LT: { + llvm::Value *fcmp = m_builder.CreateFCmpOLT(double1, double2); + llvm::Value *nan = isNaN(double2); + llvm::Value *nanCmp = m_builder.CreateFCmpULT(double1, double2); + llvm::Value *doubleResult = m_builder.CreateAnd(m_builder.CreateNot(bothNan), m_builder.CreateSelect(nan, nanCmp, fcmp)); + llvm::Value *icmp = m_builder.CreateICmpSLT(int1, int2); + return m_builder.CreateSelect(isInt, icmp, doubleResult); + } + + default: + assert(false); + return nullptr; + } +} + +llvm::Value *LLVMBuildUtils::createBoolAndBoolComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type) +{ + llvm::Value *value1 = castValue(arg1, Compiler::StaticType::Bool); + llvm::Value *value2 = castValue(arg2, Compiler::StaticType::Bool); + + switch (type) { + case Comparison::EQ: + return m_builder.CreateICmpEQ(value1, value2); + + case Comparison::GT: + // value1 && !value2 + return m_builder.CreateAnd(value1, m_builder.CreateNot(value2)); + + case Comparison::LT: + // value2 && !value1 + return m_builder.CreateAnd(value2, m_builder.CreateNot(value1)); + + default: + assert(false); + return nullptr; + } +} + +llvm::Value *LLVMBuildUtils::createStringAndStringComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type) +{ + llvm::Value *value1 = castValue(arg1, Compiler::StaticType::String); + llvm::Value *value2 = castValue(arg2, Compiler::StaticType::String); + + llvm::Value *cmp = m_builder.CreateCall(m_functions.resolve_string_compare_case_insensitive(), { value1, value2 }); + llvm::Value *zero = llvm::ConstantInt::get(m_builder.getInt32Ty(), 0, true); + + switch (type) { + case Comparison::EQ: + return m_builder.CreateICmpEQ(cmp, zero); + + case Comparison::GT: + return m_builder.CreateICmpSGT(cmp, zero); + + case Comparison::LT: + return m_builder.CreateICmpSLT(cmp, zero); + + default: + assert(false); + return nullptr; + } +} + +llvm::Value *LLVMBuildUtils::createNumberAndBoolComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type) +{ + llvm::Value *doubleValue1 = castValue(arg1, Compiler::StaticType::Number, NumberType::Double); + llvm::Value *intValue1 = castValue(arg1, Compiler::StaticType::Number, NumberType::Int); + + llvm::Value *boolValue2 = castValue(arg2, Compiler::StaticType::Bool); + llvm::Value *intValue2 = castValue(arg2, Compiler::StaticType::Number, NumberType::Int); + + llvm::Value *doubleValue2 = m_builder.CreateUIToFP(boolValue2, m_builder.getDoubleTy()); + llvm::Value *isInt = arg1->isInt; + + switch (type) { + case Comparison::EQ: { + llvm::Value *fcmp = m_builder.CreateFCmpOEQ(doubleValue1, doubleValue2); + llvm::Value *icmp = m_builder.CreateICmpEQ(intValue1, intValue2); + return m_builder.CreateSelect(isInt, icmp, fcmp); + } + + case Comparison::GT: { + llvm::Value *fcmp = m_builder.CreateFCmpOGT(doubleValue1, doubleValue2); + llvm::Value *nan = isNaN(doubleValue1); + llvm::Value *nanCmp = m_builder.CreateFCmpUGT(doubleValue1, doubleValue2); + llvm::Value *doubleResult = m_builder.CreateSelect(nan, nanCmp, fcmp); + llvm::Value *icmp = m_builder.CreateICmpSGT(intValue1, intValue2); + return m_builder.CreateSelect(isInt, icmp, doubleResult); + } + + case Comparison::LT: { + llvm::Value *fcmp = m_builder.CreateFCmpOLT(doubleValue1, doubleValue2); + llvm::Value *icmp = m_builder.CreateICmpSLT(intValue1, intValue2); + return m_builder.CreateSelect(isInt, icmp, fcmp); + } + + default: + assert(false); + return nullptr; + } +} + +llvm::Value *LLVMBuildUtils::createNumberAndStringComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type) +{ + llvm::Value *value1 = castValue(arg1, Compiler::StaticType::Number, NumberType::Double); + llvm::Value *value2 = castValue(arg2, Compiler::StaticType::String); + + // If the number is NaN, skip the string to double conversion + llvm::Value *nan = m_builder.CreateAnd(m_builder.CreateNot(arg1->isInt), isNaN(value1)); + + llvm::BasicBlock *nanBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + llvm::BasicBlock *stringCastBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + m_builder.CreateCondBr(nan, nanBlock, stringCastBlock); + + m_builder.SetInsertPoint(nanBlock); + m_builder.CreateBr(nextBlock); + + m_builder.SetInsertPoint(stringCastBlock); + llvm::Value *okPtr = addAlloca(m_builder.getInt1Ty()); + llvm::Value *doubleValue = m_builder.CreateCall(m_functions.resolve_value_stringToDoubleWithCheck(), { value2, okPtr }); + llvm::Value *ok = m_builder.CreateLoad(m_builder.getInt1Ty(), okPtr); + m_builder.CreateBr(nextBlock); + + m_builder.SetInsertPoint(nextBlock); + + llvm::PHINode *doubleValuePhi = m_builder.CreatePHI(m_builder.getDoubleTy(), 2); + doubleValuePhi->addIncoming(llvm::ConstantFP::get(m_builder.getDoubleTy(), 0.0), nanBlock); + doubleValuePhi->addIncoming(doubleValue, stringCastBlock); + + llvm::PHINode *okPhi = m_builder.CreatePHI(m_builder.getInt1Ty(), 2); + okPhi->addIncoming(m_builder.getInt1(false), nanBlock); + okPhi->addIncoming(ok, stringCastBlock); + + // If both arguments are valid numbers, compare them as numbers, otherwise as strings + llvm::BasicBlock *numberBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + llvm::BasicBlock *stringBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + m_builder.CreateCondBr(okPhi, numberBlock, stringBlock); + + // Number comparison + m_builder.SetInsertPoint(numberBlock); + llvm::Value *numberCmp; + + switch (type) { + case Comparison::EQ: + numberCmp = m_builder.CreateFCmpOEQ(value1, doubleValuePhi); + break; + + case Comparison::GT: + numberCmp = m_builder.CreateFCmpOGT(value1, doubleValuePhi); + break; + + case Comparison::LT: + numberCmp = m_builder.CreateFCmpOLT(value1, doubleValuePhi); + break; + + default: + assert(false); + return nullptr; + } + + m_builder.CreateBr(nextBlock); + + // String comparison + m_builder.SetInsertPoint(stringBlock); + llvm::Value *stringValue = m_builder.CreateCall(m_functions.resolve_value_doubleToStringPtr(), { value1 }); + llvm::Value *cmp = m_builder.CreateCall(m_functions.resolve_string_compare_case_insensitive(), { stringValue, value2 }); + m_builder.CreateCall(m_functions.resolve_string_pool_free(), { stringValue }); // free the string immediately + + llvm::Value *zero = llvm::ConstantInt::get(m_builder.getInt32Ty(), 0, true); + llvm::Value *stringCmp; + + switch (type) { + case Comparison::EQ: + stringCmp = m_builder.CreateICmpEQ(cmp, zero); + break; + + case Comparison::GT: + stringCmp = m_builder.CreateICmpSGT(cmp, zero); + break; + + case Comparison::LT: + stringCmp = m_builder.CreateICmpSLT(cmp, zero); + break; + + default: + assert(false); + return nullptr; + } + + m_builder.CreateBr(nextBlock); + + // Merge the results + m_builder.SetInsertPoint(nextBlock); + + llvm::PHINode *result = m_builder.CreatePHI(m_builder.getInt1Ty(), 2); + result->addIncoming(numberCmp, numberBlock); + result->addIncoming(stringCmp, stringBlock); + + return result; +} + +llvm::Value *LLVMBuildUtils::createBoolAndStringComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type) +{ + llvm::Value *value1 = castValue(arg1, Compiler::StaticType::Bool); + llvm::Value *value2 = castValue(arg2, Compiler::StaticType::String); + + // NOTE: Bools are always valid numbers + + // Convert the string to double + llvm::Value *okPtr = addAlloca(m_builder.getInt1Ty()); + llvm::Value *doubleValue2 = m_builder.CreateCall(m_functions.resolve_value_stringToDoubleWithCheck(), { value2, okPtr }); + llvm::Value *ok = m_builder.CreateLoad(m_builder.getInt1Ty(), okPtr); + + // If the string is a valid number, compare the arguments as numbers, otherwise as strings + llvm::BasicBlock *numberBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + llvm::BasicBlock *stringBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + m_builder.CreateCondBr(ok, numberBlock, stringBlock); + + // Number comparison + m_builder.SetInsertPoint(numberBlock); + llvm::Value *doubleValue1 = m_builder.CreateUIToFP(value1, m_builder.getDoubleTy()); + llvm::Value *numberCmp; + + switch (type) { + case Comparison::EQ: + numberCmp = m_builder.CreateFCmpOEQ(doubleValue1, doubleValue2); + break; + + case Comparison::GT: + numberCmp = m_builder.CreateFCmpOGT(doubleValue1, doubleValue2); + break; + + case Comparison::LT: + numberCmp = m_builder.CreateFCmpOLT(doubleValue1, doubleValue2); + break; + + default: + assert(false); + return nullptr; + } + + m_builder.CreateBr(nextBlock); + + // String comparison + m_builder.SetInsertPoint(stringBlock); + llvm::Value *stringValue = m_builder.CreateCall(m_functions.resolve_value_boolToStringPtr(), { value1 }); + llvm::Value *cmp = m_builder.CreateCall(m_functions.resolve_string_compare_case_insensitive(), { stringValue, value2 }); + // NOTE: Do not free the string! + + llvm::Value *zero = llvm::ConstantInt::get(m_builder.getInt32Ty(), 0, true); + llvm::Value *stringCmp; + + switch (type) { + case Comparison::EQ: + stringCmp = m_builder.CreateICmpEQ(cmp, zero); + break; + + case Comparison::GT: + stringCmp = m_builder.CreateICmpSGT(cmp, zero); + break; + + case Comparison::LT: + stringCmp = m_builder.CreateICmpSLT(cmp, zero); + break; + + default: + assert(false); + return nullptr; + } + + m_builder.CreateBr(nextBlock); + + // Merge the results + m_builder.SetInsertPoint(nextBlock); + + llvm::PHINode *result = m_builder.CreatePHI(m_builder.getInt1Ty(), 2); + result->addIncoming(numberCmp, numberBlock); + result->addIncoming(stringCmp, stringBlock); + + return result; +} + llvm::Value *LLVMBuildUtils::getVariablePtr(llvm::Value *targetVariables, Variable *variable) { if (!m_target->isStage() && variable->target() == m_target) { diff --git a/src/engine/internal/llvm/llvmbuildutils.h b/src/engine/internal/llvm/llvmbuildutils.h index 371f177c..e6192ec9 100644 --- a/src/engine/internal/llvm/llvmbuildutils.h +++ b/src/engine/internal/llvm/llvmbuildutils.h @@ -5,6 +5,7 @@ #include #include "llvmfunctions.h" +#include "llvmlocalvariableinfo.h" #include "llvmvariableptr.h" #include "llvmlistptr.h" #include "llvmcoroutine.h" @@ -27,9 +28,15 @@ class LLVMBuildUtils LT }; + enum class NumberType + { + Int, + Double + }; + LLVMBuildUtils(LLVMCompilerContext *ctx, llvm::IRBuilder<> &builder, Compiler::CodeType codeType); - void init(llvm::Function *function, BlockPrototype *procedurePrototype, bool warp); + void init(llvm::Function *function, BlockPrototype *procedurePrototype, bool warp, const std::vector> ®s); void end(LLVMInstruction *lastInstruction, LLVMRegister *lastConstant); LLVMCompilerContext *compilerCtx() const; @@ -55,14 +62,16 @@ class LLVMBuildUtils LLVMCoroutine *coroutine() const; + void createLocalVariableInfo(CompilerLocalVariable *variable); void createVariablePtr(Variable *variable); void createListPtr(List *list); + LLVMLocalVariableInfo &localVariableInfo(CompilerLocalVariable *variable); LLVMVariablePtr &variablePtr(Variable *variable); LLVMListPtr &listPtr(List *list); - void syncVariables(llvm::Value *targetVariables); - void reloadVariables(llvm::Value *targetVariables); + void syncVariables(); + void reloadVariables(); void reloadLists(); void pushScopeLevel(); @@ -76,16 +85,28 @@ class LLVMBuildUtils static Compiler::StaticType optimizeRegisterType(const LLVMRegister *reg); static Compiler::StaticType mapType(ValueType type); + static ValueType mapType(Compiler::StaticType type); static bool isSingleType(Compiler::StaticType type); llvm::Value *addAlloca(llvm::Type *type); - llvm::Value *castValue(LLVMRegister *reg, Compiler::StaticType targetType); + llvm::Value *castValue(LLVMRegister *reg, Compiler::StaticType targetType, NumberType targetNumType = NumberType::Double); llvm::Type *getType(Compiler::StaticType type, bool isReturnType); llvm::Value *isNaN(llvm::Value *num); llvm::Value *removeNaN(llvm::Value *num); - void createValueStore(LLVMRegister *reg, llvm::Value *destPtr, Compiler::StaticType destType, Compiler::StaticType targetType); - void createValueStore(LLVMRegister *reg, llvm::Value *destPtr, Compiler::StaticType targetType); + void createValueStore( + llvm::Value *destPtr, + llvm::Value *destTypePtr, + llvm::Value *destIsIntVar, + llvm::Value *destIntVar, + LLVMRegister *reg, + Compiler::StaticType destType, + Compiler::StaticType targetType); + + void createValueStore(llvm::Value *destPtr, llvm::Value *destTypePtr, llvm::Value *destIsIntVar, llvm::Value *destIntVar, LLVMRegister *reg, Compiler::StaticType targetType); + + llvm::Value *getValueTypePtr(llvm::Value *value); + llvm::Value *getValueTypePtr(LLVMRegister *reg); llvm::Value *getListSize(const LLVMListPtr &listPtr); llvm::Value *getListItem(const LLVMListPtr &listPtr, llvm::Value *index); @@ -102,12 +123,24 @@ class LLVMBuildUtils void createVariableMap(); void createListMap(); - llvm::Value *castRawValue(LLVMRegister *reg, Compiler::StaticType targetType); - llvm::Constant *castConstValue(const Value &value, Compiler::StaticType targetType); + llvm::Value *loadRegisterType(LLVMRegister *reg, Compiler::StaticType type); + + llvm::Value *castRawValue(LLVMRegister *reg, Compiler::StaticType targetType, NumberType targetNumType); + llvm::Constant *castConstValue(const Value &value, Compiler::StaticType targetType, NumberType targetNumType); void createValueCopy(llvm::Value *source, llvm::Value *target); void copyStructField(llvm::Value *source, llvm::Value *target, int index, llvm::StructType *structType, llvm::Type *fieldType); + Comparison swapComparisonArgs(Comparison type); + + llvm::Value *createNumberAndNumberComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type); + llvm::Value *createBoolAndBoolComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type); + llvm::Value *createStringAndStringComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type); + + llvm::Value *createNumberAndBoolComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type); + llvm::Value *createNumberAndStringComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type); + llvm::Value *createBoolAndStringComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type); + llvm::Value *getVariablePtr(llvm::Value *targetVariables, Variable *variable); llvm::Value *getListPtr(llvm::Value *targetLists, List *list); llvm::Value *getListDataPtr(const LLVMListPtr &listPtr); @@ -136,6 +169,8 @@ class LLVMBuildUtils std::unique_ptr m_coroutine; + std::unordered_map m_localVariables; + std::unordered_map m_targetVariableMap; std::unordered_map m_variablePtrs; diff --git a/src/engine/internal/llvm/llvmcodeanalyzer.cpp b/src/engine/internal/llvm/llvmcodeanalyzer.cpp index dfc3ccc3..f41b54eb 100644 --- a/src/engine/internal/llvm/llvmcodeanalyzer.cpp +++ b/src/engine/internal/llvm/llvmcodeanalyzer.cpp @@ -7,8 +7,8 @@ using namespace libscratchcpp; -static const std::unordered_set - BEGIN_LOOP_INSTRUCTIONS = { LLVMInstruction::Type::BeginRepeatLoop, LLVMInstruction::Type::BeginWhileLoop, LLVMInstruction::Type::BeginRepeatUntilLoop }; +// NOTE: The loop condition in repeat until and while loops is considered a part of the loop body +static const std::unordered_set BEGIN_LOOP_INSTRUCTIONS = { LLVMInstruction::Type::BeginRepeatLoop, LLVMInstruction::Type::BeginLoopCondition }; static const std::unordered_set LIST_WRITE_INSTRUCTIONS = { LLVMInstruction::Type::AppendToList, LLVMInstruction::Type::InsertToList, LLVMInstruction::Type::ListReplace }; @@ -58,13 +58,13 @@ void LLVMCodeAnalyzer::analyzeScript(const LLVMInstructionList &script) const if (primaryBranch && primaryBranch->elseBranch) { // The previous variable types can be ignored in if/else statements overrideVariableTypes(primaryBranch, previousBranch); - mergeListTypes(primaryBranch, previousBranch); + mergeListTypes(primaryBranch, previousBranch, false); mergeVariableTypes(primaryBranch->elseBranch.get(), previousBranch); - mergeListTypes(primaryBranch->elseBranch.get(), previousBranch); + mergeListTypes(primaryBranch->elseBranch.get(), previousBranch, true); } else { mergeVariableTypes(primaryBranch, previousBranch); - mergeListTypes(primaryBranch, previousBranch); + mergeListTypes(primaryBranch, previousBranch, true); } // Remove the branch @@ -102,6 +102,29 @@ void LLVMCodeAnalyzer::analyzeScript(const LLVMInstructionList &script) const // Store the type in the return register // NOTE: Get list item returns empty string if index is out of range ins->functionReturnReg->setType(ins->targetType | Compiler::StaticType::String); + } else if (isProcedureCall(ins)) { + // Variables/lists may change in procedures + for (auto &[var, type] : currentBranch->variableTypes) { + if (type != Compiler::StaticType::Unknown) { + type = Compiler::StaticType::Unknown; + + if (typeAssignedInstructions.find(ins) == typeAssignedInstructions.cend()) + currentBranch->typeChanges = true; + } + } + + for (auto &[list, type] : currentBranch->listTypes) { + if (type != Compiler::StaticType::Unknown) { + type = Compiler::StaticType::Unknown; + + if (typeAssignedInstructions.find(ins) == typeAssignedInstructions.cend()) { + typeAssignedInstructions.insert(ins); + currentBranch->typeChanges = true; + } + } + } + + typeAssignedInstructions.insert(ins); } ins = ins->next; @@ -173,7 +196,7 @@ void LLVMCodeAnalyzer::mergeVariableTypes(Branch *branch, Branch *previousBranch auto it = previousBranch->variableTypes.find(var); if (it == previousBranch->variableTypes.cend()) - previousBranch->variableTypes[var] = type; + previousBranch->variableTypes[var] = Compiler::StaticType::Unknown; else it->second |= type; } @@ -185,13 +208,13 @@ void LLVMCodeAnalyzer::overrideVariableTypes(Branch *branch, Branch *previousBra previousBranch->variableTypes[var] = type; } -void LLVMCodeAnalyzer::mergeListTypes(Branch *branch, Branch *previousBranch) const +void LLVMCodeAnalyzer::mergeListTypes(Branch *branch, Branch *previousBranch, bool firstUnknown) const { for (const auto &[list, type] : branch->listTypes) { auto it = previousBranch->listTypes.find(list); if (it == previousBranch->listTypes.cend()) - previousBranch->listTypes[list] = type; + previousBranch->listTypes[list] = firstUnknown ? Compiler::StaticType::Unknown : type; else it->second |= type; } @@ -247,6 +270,11 @@ bool LLVMCodeAnalyzer::isListClear(const LLVMInstruction *ins) const return (ins->type == LLVMInstruction::Type::ClearList); } +bool LLVMCodeAnalyzer::isProcedureCall(const LLVMInstruction *ins) const +{ + return (ins->type == LLVMInstruction::Type::CallProcedure); +} + Compiler::StaticType LLVMCodeAnalyzer::writeType(LLVMInstruction *ins) const { assert(ins); diff --git a/src/engine/internal/llvm/llvmcodeanalyzer.h b/src/engine/internal/llvm/llvmcodeanalyzer.h index 65c73c5d..4c9b6a4a 100644 --- a/src/engine/internal/llvm/llvmcodeanalyzer.h +++ b/src/engine/internal/llvm/llvmcodeanalyzer.h @@ -32,7 +32,7 @@ class LLVMCodeAnalyzer void mergeVariableTypes(Branch *branch, Branch *previousBranch) const; void overrideVariableTypes(Branch *branch, Branch *previousBranch) const; - void mergeListTypes(Branch *branch, Branch *previousBranch) const; + void mergeListTypes(Branch *branch, Branch *previousBranch, bool firstUnknown) const; bool isLoopStart(const LLVMInstruction *ins) const; bool isLoopEnd(const LLVMInstruction *ins) const; @@ -47,6 +47,8 @@ class LLVMCodeAnalyzer bool isListWrite(const LLVMInstruction *ins) const; bool isListClear(const LLVMInstruction *ins) const; + bool isProcedureCall(const LLVMInstruction *ins) const; + Compiler::StaticType writeType(LLVMInstruction *ins) const; }; diff --git a/src/engine/internal/llvm/llvmcodebuilder.cpp b/src/engine/internal/llvm/llvmcodebuilder.cpp index 3cba9270..07dcf003 100644 --- a/src/engine/internal/llvm/llvmcodebuilder.cpp +++ b/src/engine/internal/llvm/llvmcodebuilder.cpp @@ -55,7 +55,7 @@ std::shared_ptr LLVMCodeBuilder::build() if (m_warp) { #ifdef ENABLE_CODE_ANALYZER // Analyze the script (type analysis, optimizations, etc.) - // NOTE: Do this only for non-warp scripts + // NOTE: Do this only for warp scripts m_codeAnalyzer.analyzeScript(m_instructions); #endif } @@ -80,7 +80,7 @@ std::shared_ptr LLVMCodeBuilder::build() llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_llvmCtx, "entry", m_function); m_builder.SetInsertPoint(entry); - m_utils.init(m_function, m_procedurePrototype, m_warp); + m_utils.init(m_function, m_procedurePrototype, m_warp, m_regs); // Build recorded instructions LLVMInstruction *ins = m_instructions.first(); @@ -158,7 +158,17 @@ CompilerValue *LLVMCodeBuilder::addLoopIndex() CompilerValue *LLVMCodeBuilder::addLocalVariableValue(CompilerLocalVariable *variable) { - return createOp(LLVMInstruction::Type::ReadLocalVariable, variable->type(), variable->type(), { variable->ptr() }); + auto ins = std::make_shared(LLVMInstruction::Type::ReadLocalVariable, m_loopCondition); + ins->localVarInfo = &m_utils.localVariableInfo(variable); + + ins->args.push_back({ variable->type(), dynamic_cast(variable->ptr()) }); + + auto ret = std::make_shared(variable->type()); + ret->isRawValue = false; + ins->functionReturnReg = ret.get(); + + m_instructions.addInstruction(ins); + return addReg(ret, ins); } CompilerValue *LLVMCodeBuilder::addVariableValue(Variable *variable) @@ -416,12 +426,15 @@ CompilerLocalVariable *LLVMCodeBuilder::createLocalVariable(Compiler::StaticType CompilerValue *ptr = createOp(LLVMInstruction::Type::CreateLocalVariable, type); auto var = std::make_shared(ptr); m_localVars.push_back(var); + m_utils.createLocalVariableInfo(var.get()); + m_instructions.last()->localVarInfo = &m_utils.localVariableInfo(var.get()); return var.get(); } void LLVMCodeBuilder::createLocalVariableWrite(CompilerLocalVariable *variable, CompilerValue *value) { createOp(LLVMInstruction::Type::WriteLocalVariable, Compiler::StaticType::Void, variable->type(), { variable->ptr(), value }); + m_instructions.last()->localVarInfo = &m_utils.localVariableInfo(variable); } void LLVMCodeBuilder::createVariableWrite(Variable *variable, CompilerValue *value) @@ -552,6 +565,12 @@ void LLVMCodeBuilder::createStop() m_instructions.addInstruction(ins); } +void LLVMCodeBuilder::createStopWithoutSync() +{ + auto ins = std::make_shared(LLVMInstruction::Type::StopWithoutSync, m_loopCondition); + m_instructions.addInstruction(ins); +} + void LLVMCodeBuilder::createProcedureCall(BlockPrototype *prototype, const Compiler::Args &args) { assert(prototype); diff --git a/src/engine/internal/llvm/llvmcodebuilder.h b/src/engine/internal/llvm/llvmcodebuilder.h index 37c060c3..7fd9b928 100644 --- a/src/engine/internal/llvm/llvmcodebuilder.h +++ b/src/engine/internal/llvm/llvmcodebuilder.h @@ -112,6 +112,7 @@ class LLVMCodeBuilder : public ICodeBuilder void yield() override; void createStop() override; + void createStopWithoutSync() override; void createProcedureCall(BlockPrototype *prototype, const Compiler::Args &args) override; diff --git a/src/engine/internal/llvm/llvmcompilercontext.cpp b/src/engine/internal/llvm/llvmcompilercontext.cpp index f879f930..5ec8c6b4 100644 --- a/src/engine/internal/llvm/llvmcompilercontext.cpp +++ b/src/engine/internal/llvm/llvmcompilercontext.cpp @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 #include +#include +#include #include #include @@ -70,20 +72,7 @@ void LLVMCompilerContext::initJit() #endif // Optimize - llvm::PassBuilder passBuilder; - llvm::LoopAnalysisManager loopAnalysisManager; - llvm::FunctionAnalysisManager functionAnalysisManager; - llvm::CGSCCAnalysisManager cGSCCAnalysisManager; - llvm::ModuleAnalysisManager moduleAnalysisManager; - - passBuilder.registerModuleAnalyses(moduleAnalysisManager); - passBuilder.registerCGSCCAnalyses(cGSCCAnalysisManager); - passBuilder.registerFunctionAnalyses(functionAnalysisManager); - passBuilder.registerLoopAnalyses(loopAnalysisManager); - passBuilder.crossRegisterProxies(loopAnalysisManager, functionAnalysisManager, cGSCCAnalysisManager, moduleAnalysisManager); - - llvm::ModulePassManager modulePassManager = passBuilder.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O3); - modulePassManager.run(*m_module, moduleAnalysisManager); + optimize(llvm::OptimizationLevel::O3); const auto &functions = m_module->getFunctionList(); std::vector lookupNames; @@ -151,6 +140,71 @@ void LLVMCompilerContext::initTarget() llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); llvm::InitializeNativeTargetAsmParser(); + + createTargetMachine(); + + m_module->setDataLayout(m_targetMachine->createDataLayout()); +} + +void LLVMCompilerContext::createTargetMachine() +{ + std::string error; + std::string targetTriple = llvm::sys::getDefaultTargetTriple(); + m_module->setTargetTriple(targetTriple); + + const llvm::Target *target = llvm::TargetRegistry::lookupTarget(targetTriple, error); + + if (!target) { + llvm::errs() << error; + return; + } + + llvm::TargetOptions opt; + const char *cpu = "generic"; + const char *features = ""; + + m_targetMachine = std::unique_ptr(target->createTargetMachine(targetTriple, cpu, features, opt, llvm::Reloc::PIC_)); +} + +void LLVMCompilerContext::optimize(llvm::OptimizationLevel optLevel) +{ + llvm::PassBuilder passBuilder(m_targetMachine.get()); + llvm::LoopAnalysisManager loopAnalysisManager; + llvm::FunctionAnalysisManager functionAnalysisManager; + llvm::CGSCCAnalysisManager cGSCCAnalysisManager; + llvm::ModuleAnalysisManager moduleAnalysisManager; + + passBuilder.registerModuleAnalyses(moduleAnalysisManager); + passBuilder.registerCGSCCAnalyses(cGSCCAnalysisManager); + passBuilder.registerFunctionAnalyses(functionAnalysisManager); + passBuilder.registerLoopAnalyses(loopAnalysisManager); + passBuilder.crossRegisterProxies(loopAnalysisManager, functionAnalysisManager, cGSCCAnalysisManager, moduleAnalysisManager); + + llvm::ModulePassManager modulePassManager = passBuilder.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O3); + + std::string pipeline; + + if (optLevel == llvm::OptimizationLevel::O0) + pipeline = "default"; + else if (optLevel == llvm::OptimizationLevel::O1) + pipeline = "default"; + else if (optLevel == llvm::OptimizationLevel::O2) + pipeline = "default"; + else if (optLevel == llvm::OptimizationLevel::O3) + pipeline = "default"; + else if (optLevel == llvm::OptimizationLevel::Os) + pipeline = "default"; + else if (optLevel == llvm::OptimizationLevel::Oz) + pipeline = "default"; + else + assert(false); + + if (passBuilder.parsePassPipeline(modulePassManager, pipeline)) { + llvm::errs() << "Failed to parse pipeline\n"; + return; + } + + modulePassManager.run(*m_module, moduleAnalysisManager); } llvm::Function *LLVMCompilerContext::createCoroResumeFunction() diff --git a/src/engine/internal/llvm/llvmcompilercontext.h b/src/engine/internal/llvm/llvmcompilercontext.h index 637e4961..ce606980 100644 --- a/src/engine/internal/llvm/llvmcompilercontext.h +++ b/src/engine/internal/llvm/llvmcompilercontext.h @@ -4,6 +4,7 @@ #include #include +#include #include @@ -52,6 +53,8 @@ class LLVMCompilerContext : public CompilerContext using DestroyCoroFuncType = void (*)(void *); void initTarget(); + void createTargetMachine(); + void optimize(llvm::OptimizationLevel optLevel); llvm::Function *createCoroResumeFunction(); llvm::Function *createCoroDestroyFunction(); @@ -62,6 +65,7 @@ class LLVMCompilerContext : public CompilerContext std::unique_ptr m_module; llvm::LLVMContext *m_llvmCtxPtr = nullptr; llvm::Module *m_modulePtr = nullptr; + std::unique_ptr m_targetMachine; llvm::Expected> m_jit; bool m_jitInitialized = false; diff --git a/src/engine/internal/llvm/llvmcoroutine.cpp b/src/engine/internal/llvm/llvmcoroutine.cpp index f4b268fb..02c4ecb1 100644 --- a/src/engine/internal/llvm/llvmcoroutine.cpp +++ b/src/engine/internal/llvm/llvmcoroutine.cpp @@ -112,7 +112,7 @@ llvm::Value *LLVMCoroutine::createResume(llvm::Module *module, llvm::IRBuilder<> llvm::Value *ret = builder->CreateAlloca(builder->getInt1Ty()); llvm::Value *done = builder->CreateCall(coroDone, { coroHandle }); - done = builder->CreateCall(coroDone, { coroHandle }); + builder->CreateStore(done, ret); llvm::BasicBlock *destroyBranch = llvm::BasicBlock::Create(ctx, "", function); llvm::BasicBlock *resumeBranch = llvm::BasicBlock::Create(ctx, "", function); diff --git a/src/engine/internal/llvm/llvmfunctions.cpp b/src/engine/internal/llvm/llvmfunctions.cpp index 899254aa..36ebcc83 100644 --- a/src/engine/internal/llvm/llvmfunctions.cpp +++ b/src/engine/internal/llvm/llvmfunctions.cpp @@ -21,7 +21,7 @@ extern "C" return value_doubleIsInt(from) && value_doubleIsInt(to) ? ctx->rng()->randint(from, to) : ctx->rng()->randintDouble(from, to); } - double llvm_random_long(ExecutionContext *ctx, long from, long to) + int64_t llvm_random_int64(ExecutionContext *ctx, int64_t from, int64_t to) { return ctx->rng()->randint(from, to); } @@ -132,6 +132,15 @@ llvm::FunctionCallee LLVMFunctions::resolve_value_stringToDouble() return callee; } +llvm::FunctionCallee LLVMFunctions::resolve_value_stringToDoubleWithCheck() +{ + llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(*m_ctx->llvmCtx()), 0); + llvm::FunctionCallee callee = resolveFunction("value_stringToDoubleWithCheck", llvm::FunctionType::get(m_builder->getDoubleTy(), { pointerType, m_builder->getInt1Ty()->getPointerTo() }, false)); + llvm::Function *func = llvm::cast(callee.getCallee()); + func->addFnAttr(llvm::Attribute::ReadOnly); + return callee; +} + llvm::FunctionCallee LLVMFunctions::resolve_value_stringToBool() { llvm::FunctionCallee callee = resolveFunction("value_stringToBool", llvm::FunctionType::get(m_builder->getInt1Ty(), llvm::PointerType::get(llvm::Type::getInt8Ty(*m_ctx->llvmCtx()), 0), false)); @@ -240,10 +249,10 @@ llvm::FunctionCallee LLVMFunctions::resolve_llvm_random_double() return resolveFunction("llvm_random_double", llvm::FunctionType::get(m_builder->getDoubleTy(), { pointerType, m_builder->getDoubleTy(), m_builder->getDoubleTy() }, false)); } -llvm::FunctionCallee LLVMFunctions::resolve_llvm_random_long() +llvm::FunctionCallee LLVMFunctions::resolve_llvm_random_int64() { llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(*m_ctx->llvmCtx()), 0); - return resolveFunction("llvm_random_long", llvm::FunctionType::get(m_builder->getDoubleTy(), { pointerType, m_builder->getInt64Ty(), m_builder->getInt64Ty() }, false)); + return resolveFunction("llvm_random_int64", llvm::FunctionType::get(m_builder->getInt64Ty(), { pointerType, m_builder->getInt64Ty(), m_builder->getInt64Ty() }, false)); } llvm::FunctionCallee LLVMFunctions::resolve_llvm_random_bool() diff --git a/src/engine/internal/llvm/llvmfunctions.h b/src/engine/internal/llvm/llvmfunctions.h index 2a257321..13968175 100644 --- a/src/engine/internal/llvm/llvmfunctions.h +++ b/src/engine/internal/llvm/llvmfunctions.h @@ -30,6 +30,7 @@ class LLVMFunctions llvm::FunctionCallee resolve_value_doubleToStringPtr(); llvm::FunctionCallee resolve_value_boolToStringPtr(); llvm::FunctionCallee resolve_value_stringToDouble(); + llvm::FunctionCallee resolve_value_stringToDoubleWithCheck(); llvm::FunctionCallee resolve_value_stringToBool(); llvm::FunctionCallee resolve_value_equals(); llvm::FunctionCallee resolve_value_greater(); @@ -44,7 +45,7 @@ class LLVMFunctions llvm::FunctionCallee resolve_list_to_string(); llvm::FunctionCallee resolve_llvm_random(); llvm::FunctionCallee resolve_llvm_random_double(); - llvm::FunctionCallee resolve_llvm_random_long(); + llvm::FunctionCallee resolve_llvm_random_int64(); llvm::FunctionCallee resolve_llvm_random_bool(); llvm::FunctionCallee resolve_string_pool_new(); llvm::FunctionCallee resolve_string_pool_free(); diff --git a/src/engine/internal/llvm/llvminstruction.h b/src/engine/internal/llvm/llvminstruction.h index e6a5212e..558258aa 100644 --- a/src/engine/internal/llvm/llvminstruction.h +++ b/src/engine/internal/llvm/llvminstruction.h @@ -9,6 +9,7 @@ namespace libscratchcpp { +class LLVMLocalVariableInfo; class BlockPrototype; struct LLVMInstruction @@ -76,6 +77,7 @@ struct LLVMInstruction BeginLoopCondition, EndLoop, Stop, + StopWithoutSync, CallProcedure, ProcedureArg }; @@ -92,6 +94,7 @@ struct LLVMInstruction LLVMRegister *functionReturnReg = nullptr; bool functionTargetArg = false; // whether to add target ptr to function parameters bool functionCtxArg = false; // whether to add execution context ptr to function parameters + LLVMLocalVariableInfo *localVarInfo = nullptr; // for local variables Variable *targetVariable = nullptr; // for variables List *targetList = nullptr; // for lists Compiler::StaticType targetType = Compiler::StaticType::Unknown; // variable or list type (before read/write operation) diff --git a/src/engine/internal/llvm/llvmlistptr.h b/src/engine/internal/llvm/llvmlistptr.h index 194737ff..12f1ced7 100644 --- a/src/engine/internal/llvm/llvmlistptr.h +++ b/src/engine/internal/llvm/llvmlistptr.h @@ -23,6 +23,10 @@ struct LLVMListPtr llvm::Value *sizePtr = nullptr; llvm::Value *allocatedSizePtr = nullptr; llvm::Value *size = nullptr; + + llvm::Value *hasNumber = nullptr; + llvm::Value *hasBool = nullptr; + llvm::Value *hasString = nullptr; }; } // namespace libscratchcpp diff --git a/src/engine/internal/llvm/llvmlocalvariableinfo.h b/src/engine/internal/llvm/llvmlocalvariableinfo.h new file mode 100644 index 00000000..d0f0d465 --- /dev/null +++ b/src/engine/internal/llvm/llvmlocalvariableinfo.h @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace llvm +{ + +class Value; + +} + +namespace libscratchcpp +{ + +struct LLVMLocalVariableInfo +{ + llvm::Value *isInt = nullptr; + llvm::Value *intValue = nullptr; +}; + +} // namespace libscratchcpp diff --git a/src/engine/internal/llvm/llvmregister.h b/src/engine/internal/llvm/llvmregister.h index 3bd90cc8..d4e2eb78 100644 --- a/src/engine/internal/llvm/llvmregister.h +++ b/src/engine/internal/llvm/llvmregister.h @@ -31,7 +31,13 @@ struct LLVMRegister : public virtual CompilerValue } llvm::Value *value = nullptr; + llvm::Value *typeVar = nullptr; + + llvm::Value *isInt = nullptr; + llvm::Value *intValue = nullptr; + bool isRawValue = false; + std::shared_ptr instruction; }; diff --git a/src/engine/internal/llvm/llvmvariableptr.h b/src/engine/internal/llvm/llvmvariableptr.h index 2e5b3d01..0bf064ca 100644 --- a/src/engine/internal/llvm/llvmvariableptr.h +++ b/src/engine/internal/llvm/llvmvariableptr.h @@ -18,10 +18,12 @@ class LLVMInstruction; struct LLVMVariablePtr { - llvm::Value *stackPtr = nullptr; llvm::Value *heapPtr = nullptr; - bool onStack = false; - bool changed = false; + llvm::Value *stackPtr = nullptr; + llvm::Value *changed = nullptr; + + llvm::Value *isInt = nullptr; + llvm::Value *intValue = nullptr; }; } // namespace libscratchcpp diff --git a/src/scratch/value_functions.cpp b/src/scratch/value_functions.cpp index faaf1276..64a08300 100644 --- a/src/scratch/value_functions.cpp +++ b/src/scratch/value_functions.cpp @@ -501,6 +501,22 @@ extern "C" return value_stringToDoubleImpl(s->data, s->size); } + /*! + * Converts the given string to double. + * \param[out] ok Whether the conversion was successful. + */ + double value_stringToDoubleWithCheck(const StringPtr *s, bool *ok) + { + *ok = true; + + if (string_compare_case_sensitive(s, &INFINITY_STR) == 0) + return std::numeric_limits::infinity(); + else if (string_compare_case_sensitive(s, &NEGATIVE_INFINITY_STR) == 0) + return -std::numeric_limits::infinity(); + + return value_stringToDoubleImpl(s->data, s->size, ok); + } + /*! Converts the given string to boolean. */ bool value_stringToBool(const StringPtr *s) { diff --git a/test/blocks/control_blocks_test.cpp b/test/blocks/control_blocks_test.cpp index e126241f..4364edc7 100644 --- a/test/blocks/control_blocks_test.cpp +++ b/test/blocks/control_blocks_test.cpp @@ -1064,45 +1064,96 @@ TEST_F(ControlBlocksTest, CreateCloneOfStage) } } -TEST_F(ControlBlocksTest, DeleteThisClone) +TEST_F(ControlBlocksTest, DeleteThisClone_Clone) { - Sprite sprite; - sprite.setEngine(&m_engineMock); + auto sprite = std::make_shared(); + sprite->setEngine(&m_engineMock); + + auto var = std::make_shared("", ""); + sprite->addVariable(var); std::shared_ptr clone; EXPECT_CALL(m_engineMock, cloneLimit()).WillRepeatedly(Return(-1)); EXPECT_CALL(m_engineMock, initClone(_)).WillOnce(SaveArg<0>(&clone)); - EXPECT_CALL(m_engineMock, moveDrawableBehindOther(_, &sprite)); + EXPECT_CALL(m_engineMock, moveDrawableBehindOther(_, sprite.get())); EXPECT_CALL(m_engineMock, requestRedraw()); - sprite.clone(); + sprite->clone(); ASSERT_TRUE(clone); - ScriptBuilder builder(m_extension.get(), m_engine, clone); + ScriptBuilder builder(m_extension.get(), m_engine, sprite); builder.addBlock("control_delete_this_clone"); auto block = builder.currentBlock(); - Compiler compiler(&m_engineMock, clone.get()); + builder.addBlock("test_set_var"); + builder.addEntityField("VARIABLE", var); + builder.addValueInput("VALUE", true); + builder.currentBlock(); + + Compiler compiler(&m_engineMock, sprite.get()); auto code = compiler.compile(block); - Script script(clone.get(), block, &m_engineMock); + Script script(sprite.get(), block, &m_engineMock); script.setCode(code); Thread thread(clone.get(), &m_engineMock, &script); EXPECT_CALL(m_engineMock, stopTarget(clone.get(), nullptr)); EXPECT_CALL(m_engineMock, deinitClone(clone)); thread.run(); + + // The script should stop (variable value is false) + ASSERT_FALSE(clone->variableAt(0)->value().toBool()); } -TEST_F(ControlBlocksTest, DeleteThisCloneStage) +TEST_F(ControlBlocksTest, DeleteThisClone_NotCloneSprite) +{ + auto sprite = std::make_shared(); + sprite->setEngine(&m_engineMock); + + auto var = std::make_shared("", ""); + sprite->addVariable(var); + + ScriptBuilder builder(m_extension.get(), m_engine, sprite); + + builder.addBlock("control_delete_this_clone"); + auto block = builder.currentBlock(); + + builder.addBlock("test_set_var"); + builder.addEntityField("VARIABLE", var); + builder.addValueInput("VALUE", true); + builder.currentBlock(); + + Compiler compiler(&m_engineMock, sprite.get()); + auto code = compiler.compile(block); + Script script(sprite.get(), block, &m_engineMock); + script.setCode(code); + Thread thread(sprite.get(), &m_engineMock, &script); + + EXPECT_CALL(m_engineMock, stopTarget).Times(0); + EXPECT_CALL(m_engineMock, deinitClone).Times(0); + thread.run(); + + // The script should NOT stop (variable value is true) + ASSERT_TRUE(var->value().toBool()); +} + +TEST_F(ControlBlocksTest, DeleteThisClone_Stage) { auto target = std::make_shared(); target->setEngine(&m_engineMock); + auto var = std::make_shared("", ""); + target->addVariable(var); + ScriptBuilder builder(m_extension.get(), m_engine, target); builder.addBlock("control_delete_this_clone"); auto block = builder.currentBlock(); + builder.addBlock("test_set_var"); + builder.addEntityField("VARIABLE", var); + builder.addValueInput("VALUE", true); + builder.currentBlock(); + Compiler compiler(&m_engineMock, target.get()); auto code = compiler.compile(block); Script script(target.get(), block, &m_engineMock); @@ -1112,4 +1163,7 @@ TEST_F(ControlBlocksTest, DeleteThisCloneStage) EXPECT_CALL(m_engineMock, stopTarget).Times(0); EXPECT_CALL(m_engineMock, deinitClone).Times(0); thread.run(); + + // The script should NOT stop (variable value is true) + ASSERT_TRUE(var->value().toBool()); } diff --git a/test/compiler/compiler_test.cpp b/test/compiler/compiler_test.cpp index 55d8d161..9cdbf44f 100644 --- a/test/compiler/compiler_test.cpp +++ b/test/compiler/compiler_test.cpp @@ -1673,6 +1673,20 @@ TEST_F(CompilerTest, CreateStop) compile(m_compiler.get(), block.get()); } +TEST_F(CompilerTest, CreateStopWithoutSync) +{ + + auto block = std::make_shared("", ""); + + block->setCompileFunction([](Compiler *compiler) -> CompilerValue * { + EXPECT_CALL(*m_builder, createStopWithoutSync()); + compiler->createStopWithoutSync(); + return nullptr; + }); + + compile(m_compiler.get(), block.get()); +} + TEST_F(CompilerTest, CreateProcedureCall) { diff --git a/test/llvm/code_analyzer/list_type_analysis.cpp b/test/llvm/code_analyzer/list_type_analysis.cpp index 021f0a0e..05e76b13 100644 --- a/test/llvm/code_analyzer/list_type_analysis.cpp +++ b/test/llvm/code_analyzer/list_type_analysis.cpp @@ -182,6 +182,39 @@ TEST(LLVMCodeAnalyzer_ListTypeAnalysis, ClearListOperation) ASSERT_EQ(appendList2->targetType, Compiler::StaticType::Void); } +TEST(LLVMCodeAnalyzer_ListTypeAnalysis, ProcedureCall) +{ + LLVMCodeAnalyzer analyzer; + LLVMInstructionList list; + List targetList("", ""); + + auto clearList = std::make_shared(LLVMInstruction::Type::ClearList, false); + clearList->targetList = &targetList; + list.addInstruction(clearList); + + auto appendList1 = std::make_shared(LLVMInstruction::Type::AppendToList, false); + LLVMConstantRegister value1(Compiler::StaticType::String, "hello"); + appendList1->targetList = &targetList; + appendList1->args.push_back({ Compiler::StaticType::Unknown, &value1 }); + list.addInstruction(appendList1); + + auto procCall = std::make_shared(LLVMInstruction::Type::CallProcedure, false); + list.addInstruction(procCall); + + auto appendList2 = std::make_shared(LLVMInstruction::Type::AppendToList, false); + LLVMConstantRegister value2(Compiler::StaticType::Number, 5.2); + appendList2->targetList = &targetList; + appendList2->args.push_back({ Compiler::StaticType::Unknown, &value2 }); + list.addInstruction(appendList2); + + analyzer.analyzeScript(list); + + ASSERT_EQ(appendList1->targetType, Compiler::StaticType::Void); + + // Type unknown due to procedure call + ASSERT_EQ(appendList2->targetType, Compiler::StaticType::Unknown); +} + TEST(LLVMCodeAnalyzer_ListTypeAnalysis, MixedWriteOperationsSameType_AfterClear) { LLVMCodeAnalyzer analyzer; @@ -318,6 +351,167 @@ TEST(LLVMCodeAnalyzer_ListTypeAnalysis, LoopSingleWrite_AfterClear) ASSERT_EQ(appendList->targetType, Compiler::StaticType::Number); } +TEST(LLVMCodeAnalyzer_ListTypeAnalysis, ClearAndWriteInIfStatement_IfBranch) +{ + LLVMCodeAnalyzer analyzer; + LLVMInstructionList list; + List targetList("", ""); + + auto ifStart = std::make_shared(LLVMInstruction::Type::BeginIf, false); + list.addInstruction(ifStart); + + auto clearList = std::make_shared(LLVMInstruction::Type::ClearList, false); + clearList->targetList = &targetList; + list.addInstruction(clearList); + + auto appendList1 = std::make_shared(LLVMInstruction::Type::AppendToList, false); + LLVMConstantRegister value1(Compiler::StaticType::Number, 1.25); + appendList1->targetList = &targetList; + appendList1->args.push_back({ Compiler::StaticType::Unknown, &value1 }); + list.addInstruction(appendList1); + + auto ifEnd = std::make_shared(LLVMInstruction::Type::EndIf, false); + list.addInstruction(ifEnd); + + auto appendList2 = std::make_shared(LLVMInstruction::Type::AppendToList, false); + LLVMConstantRegister value2(Compiler::StaticType::Bool, true); + appendList2->targetList = &targetList; + appendList2->args.push_back({ Compiler::StaticType::Unknown, &value2 }); + list.addInstruction(appendList2); + + analyzer.analyzeScript(list); + + ASSERT_EQ(appendList1->targetType, Compiler::StaticType::Void); + + // The type is Unknown because the if statement might not run at all + ASSERT_EQ(appendList2->targetType, Compiler::StaticType::Unknown); +} + +TEST(LLVMCodeAnalyzer_ListTypeAnalysis, ClearAndWriteInIfStatement_ElseBranch) +{ + LLVMCodeAnalyzer analyzer; + LLVMInstructionList list; + List targetList("", ""); + + auto ifStart = std::make_shared(LLVMInstruction::Type::BeginIf, false); + list.addInstruction(ifStart); + + auto elseStart = std::make_shared(LLVMInstruction::Type::BeginElse, false); + list.addInstruction(elseStart); + + auto clearList = std::make_shared(LLVMInstruction::Type::ClearList, false); + clearList->targetList = &targetList; + list.addInstruction(clearList); + + auto appendList1 = std::make_shared(LLVMInstruction::Type::AppendToList, false); + LLVMConstantRegister value1(Compiler::StaticType::Number, 1.25); + appendList1->targetList = &targetList; + appendList1->args.push_back({ Compiler::StaticType::Unknown, &value1 }); + list.addInstruction(appendList1); + + auto ifEnd = std::make_shared(LLVMInstruction::Type::EndIf, false); + list.addInstruction(ifEnd); + + auto appendList2 = std::make_shared(LLVMInstruction::Type::AppendToList, false); + LLVMConstantRegister value2(Compiler::StaticType::Bool, true); + appendList2->targetList = &targetList; + appendList2->args.push_back({ Compiler::StaticType::Unknown, &value2 }); + list.addInstruction(appendList2); + + analyzer.analyzeScript(list); + + ASSERT_EQ(appendList1->targetType, Compiler::StaticType::Void); + + // The type is Unknown because the if statement might not run at all + ASSERT_EQ(appendList2->targetType, Compiler::StaticType::Unknown); +} + +TEST(LLVMCodeAnalyzer_ListTypeAnalysis, ClearAndWriteInIfElse) +{ + LLVMCodeAnalyzer analyzer; + LLVMInstructionList list; + List targetList("", ""); + + auto ifStart = std::make_shared(LLVMInstruction::Type::BeginIf, false); + list.addInstruction(ifStart); + + auto clearList1 = std::make_shared(LLVMInstruction::Type::ClearList, false); + clearList1->targetList = &targetList; + list.addInstruction(clearList1); + + auto appendList1 = std::make_shared(LLVMInstruction::Type::AppendToList, false); + LLVMConstantRegister value1(Compiler::StaticType::Number, 1.25); + appendList1->targetList = &targetList; + appendList1->args.push_back({ Compiler::StaticType::Unknown, &value1 }); + list.addInstruction(appendList1); + + auto elseStart = std::make_shared(LLVMInstruction::Type::BeginElse, false); + list.addInstruction(elseStart); + + auto clearList2 = std::make_shared(LLVMInstruction::Type::ClearList, false); + clearList2->targetList = &targetList; + list.addInstruction(clearList2); + + auto appendList2 = std::make_shared(LLVMInstruction::Type::AppendToList, false); + LLVMConstantRegister value2(Compiler::StaticType::String, "hello"); + appendList2->targetList = &targetList; + appendList2->args.push_back({ Compiler::StaticType::Unknown, &value2 }); + list.addInstruction(appendList2); + + auto ifEnd = std::make_shared(LLVMInstruction::Type::EndIf, false); + list.addInstruction(ifEnd); + + auto appendList3 = std::make_shared(LLVMInstruction::Type::AppendToList, false); + LLVMConstantRegister value3(Compiler::StaticType::Bool, true); + appendList3->targetList = &targetList; + appendList3->args.push_back({ Compiler::StaticType::Unknown, &value3 }); + list.addInstruction(appendList3); + + analyzer.analyzeScript(list); + + ASSERT_EQ(appendList1->targetType, Compiler::StaticType::Void); + ASSERT_EQ(appendList2->targetType, Compiler::StaticType::Void); + + // The type is Number | String because any of the branches may run + ASSERT_EQ(appendList3->targetType, Compiler::StaticType::Number | Compiler::StaticType::String); +} + +TEST(LLVMCodeAnalyzer_ListTypeAnalysis, ClearAndWriteInLoop) +{ + LLVMCodeAnalyzer analyzer; + LLVMInstructionList list; + List targetList("", ""); + + auto loopStart = std::make_shared(LLVMInstruction::Type::BeginRepeatLoop, false); + list.addInstruction(loopStart); + + auto clearList = std::make_shared(LLVMInstruction::Type::ClearList, false); + clearList->targetList = &targetList; + list.addInstruction(clearList); + + auto appendList1 = std::make_shared(LLVMInstruction::Type::AppendToList, false); + LLVMConstantRegister value1(Compiler::StaticType::Number, 1.25); + appendList1->targetList = &targetList; + appendList1->args.push_back({ Compiler::StaticType::Unknown, &value1 }); + list.addInstruction(appendList1); + + auto loopEnd = std::make_shared(LLVMInstruction::Type::EndLoop, false); + list.addInstruction(loopEnd); + + auto appendList2 = std::make_shared(LLVMInstruction::Type::AppendToList, false); + LLVMConstantRegister value2(Compiler::StaticType::Bool, true); + appendList2->targetList = &targetList; + appendList2->args.push_back({ Compiler::StaticType::Unknown, &value2 }); + list.addInstruction(appendList2); + + analyzer.analyzeScript(list); + + ASSERT_EQ(appendList1->targetType, Compiler::StaticType::Void); + + // The type is Unknown because the loop might not run at all + ASSERT_EQ(appendList2->targetType, Compiler::StaticType::Unknown); +} + TEST(LLVMCodeAnalyzer_ListTypeAnalysis, ClearAfterWriteInLoop) { LLVMCodeAnalyzer analyzer; @@ -335,7 +529,7 @@ TEST(LLVMCodeAnalyzer_ListTypeAnalysis, ClearAfterWriteInLoop) appendList->args.push_back({ Compiler::StaticType::Unknown, &value }); list.addInstruction(appendList); - auto loopStart = std::make_shared(LLVMInstruction::Type::BeginWhileLoop, false); + auto loopStart = std::make_shared(LLVMInstruction::Type::BeginRepeatLoop, false); list.addInstruction(loopStart); auto appendList1 = std::make_shared(LLVMInstruction::Type::AppendToList, false); @@ -373,25 +567,38 @@ TEST(LLVMCodeAnalyzer_ListTypeAnalysis, WhileLoop) appendList->args.push_back({ Compiler::StaticType::Unknown, &value }); list.addInstruction(appendList); + auto loopCond = std::make_shared(LLVMInstruction::Type::BeginLoopCondition, false); + list.addInstruction(loopCond); + + // Read an item in loop condition + auto getItem = std::make_shared(LLVMInstruction::Type::GetListItem, false); + LLVMConstantRegister index(Compiler::StaticType::Number, 0); + getItem->targetList = &targetList; + getItem->args.push_back({ Compiler::StaticType::Number, &index }); + list.addInstruction(getItem); + + LLVMRegister sourceValue(Compiler::StaticType::Unknown); + sourceValue.isRawValue = false; + sourceValue.instruction = getItem; + getItem->functionReturnReg = &sourceValue; + auto loopStart = std::make_shared(LLVMInstruction::Type::BeginWhileLoop, false); list.addInstruction(loopStart); + // Change the type in the loop auto appendList1 = std::make_shared(LLVMInstruction::Type::AppendToList, false); LLVMConstantRegister value1(Compiler::StaticType::Number, 5); appendList1->targetList = &targetList; appendList1->args.push_back({ Compiler::StaticType::Unknown, &value1 }); list.addInstruction(appendList1); - auto clearList2 = std::make_shared(LLVMInstruction::Type::ClearList, false); - clearList2->targetList = &targetList; - list.addInstruction(clearList2); - auto loopEnd = std::make_shared(LLVMInstruction::Type::EndLoop, false); list.addInstruction(loopEnd); analyzer.analyzeScript(list); - ASSERT_EQ(appendList1->targetType, Compiler::StaticType::Bool); + ASSERT_EQ(getItem->targetType, Compiler::StaticType::Number | Compiler::StaticType::Bool); + ASSERT_EQ(appendList1->targetType, Compiler::StaticType::Number | Compiler::StaticType::Bool); } TEST(LLVMCodeAnalyzer_ListTypeAnalysis, RepeatUntilLoop) @@ -411,25 +618,38 @@ TEST(LLVMCodeAnalyzer_ListTypeAnalysis, RepeatUntilLoop) appendList->args.push_back({ Compiler::StaticType::Unknown, &value }); list.addInstruction(appendList); + auto loopCond = std::make_shared(LLVMInstruction::Type::BeginLoopCondition, false); + list.addInstruction(loopCond); + + // Read an item in loop condition + auto getItem = std::make_shared(LLVMInstruction::Type::GetListItem, false); + LLVMConstantRegister index(Compiler::StaticType::Number, 0); + getItem->targetList = &targetList; + getItem->args.push_back({ Compiler::StaticType::Number, &index }); + list.addInstruction(getItem); + + LLVMRegister sourceValue(Compiler::StaticType::Unknown); + sourceValue.isRawValue = false; + sourceValue.instruction = getItem; + getItem->functionReturnReg = &sourceValue; + auto loopStart = std::make_shared(LLVMInstruction::Type::BeginRepeatUntilLoop, false); list.addInstruction(loopStart); + // Change the type in the loop auto appendList1 = std::make_shared(LLVMInstruction::Type::AppendToList, false); LLVMConstantRegister value1(Compiler::StaticType::Number, 5); appendList1->targetList = &targetList; appendList1->args.push_back({ Compiler::StaticType::Unknown, &value1 }); list.addInstruction(appendList1); - auto clearList2 = std::make_shared(LLVMInstruction::Type::ClearList, false); - clearList2->targetList = &targetList; - list.addInstruction(clearList2); - auto loopEnd = std::make_shared(LLVMInstruction::Type::EndLoop, false); list.addInstruction(loopEnd); analyzer.analyzeScript(list); - ASSERT_EQ(appendList1->targetType, Compiler::StaticType::Bool); + ASSERT_EQ(getItem->targetType, Compiler::StaticType::Number | Compiler::StaticType::Bool); + ASSERT_EQ(appendList1->targetType, Compiler::StaticType::Number | Compiler::StaticType::Bool); } TEST(LLVMCodeAnalyzer_ListTypeAnalysis, LoopMultipleWrites_UnknownType) diff --git a/test/llvm/code_analyzer/variable_type_analysis.cpp b/test/llvm/code_analyzer/variable_type_analysis.cpp index 7e7e1124..a09c4eb0 100644 --- a/test/llvm/code_analyzer/variable_type_analysis.cpp +++ b/test/llvm/code_analyzer/variable_type_analysis.cpp @@ -144,9 +144,23 @@ TEST(LLVMCodeAnalyzer_VariableTypeAnalysis, WhileLoop) setVar->args.push_back({ Compiler::StaticType::Unknown, &value }); list.addInstruction(setVar); + auto loopCond = std::make_shared(LLVMInstruction::Type::BeginLoopCondition, false); + list.addInstruction(loopCond); + + // Read the variable in loop condition + auto readVar = std::make_shared(LLVMInstruction::Type::ReadVariable, true); + readVar->targetVariable = &var; + list.addInstruction(readVar); + + LLVMRegister varValue(Compiler::StaticType::Unknown); + varValue.isRawValue = false; + varValue.instruction = readVar; + readVar->functionReturnReg = &varValue; + auto loopStart = std::make_shared(LLVMInstruction::Type::BeginWhileLoop, false); list.addInstruction(loopStart); + // Change the type in the loop auto setVar1 = std::make_shared(LLVMInstruction::Type::WriteVariable, false); LLVMConstantRegister value1(Compiler::StaticType::Number, 5); setVar1->targetVariable = &var; @@ -158,6 +172,7 @@ TEST(LLVMCodeAnalyzer_VariableTypeAnalysis, WhileLoop) analyzer.analyzeScript(list); + ASSERT_EQ(readVar->targetType, Compiler::StaticType::Number | Compiler::StaticType::Bool); ASSERT_EQ(setVar1->targetType, Compiler::StaticType::Number | Compiler::StaticType::Bool); } @@ -174,9 +189,23 @@ TEST(LLVMCodeAnalyzer_VariableTypeAnalysis, RepeatUntilLoop) setVar->args.push_back({ Compiler::StaticType::Unknown, &value }); list.addInstruction(setVar); + auto loopCond = std::make_shared(LLVMInstruction::Type::BeginLoopCondition, false); + list.addInstruction(loopCond); + + // Read the variable in loop condition + auto readVar = std::make_shared(LLVMInstruction::Type::ReadVariable, true); + readVar->targetVariable = &var; + list.addInstruction(readVar); + + LLVMRegister varValue(Compiler::StaticType::Unknown); + varValue.isRawValue = false; + varValue.instruction = readVar; + readVar->functionReturnReg = &varValue; + auto loopStart = std::make_shared(LLVMInstruction::Type::BeginRepeatUntilLoop, false); list.addInstruction(loopStart); + // Change the type in the loop auto setVar1 = std::make_shared(LLVMInstruction::Type::WriteVariable, false); LLVMConstantRegister value1(Compiler::StaticType::Number, 5); setVar1->targetVariable = &var; @@ -188,9 +217,45 @@ TEST(LLVMCodeAnalyzer_VariableTypeAnalysis, RepeatUntilLoop) analyzer.analyzeScript(list); + ASSERT_EQ(readVar->targetType, Compiler::StaticType::Number | Compiler::StaticType::Bool); ASSERT_EQ(setVar1->targetType, Compiler::StaticType::Number | Compiler::StaticType::Bool); } +TEST(LLVMCodeAnalyzer_VariableTypeAnalysis, ProcedureCallInLoop) +{ + LLVMCodeAnalyzer analyzer; + LLVMInstructionList list; + Variable var("", ""); + + auto setVar1 = std::make_shared(LLVMInstruction::Type::WriteVariable, false); + LLVMConstantRegister value1(Compiler::StaticType::Number, 1.25); + setVar1->targetVariable = &var; + setVar1->args.push_back({ Compiler::StaticType::Unknown, &value1 }); + list.addInstruction(setVar1); + + auto loopStart = std::make_shared(LLVMInstruction::Type::BeginRepeatLoop, false); + list.addInstruction(loopStart); + + auto setVar2 = std::make_shared(LLVMInstruction::Type::WriteVariable, false); + LLVMConstantRegister value2(Compiler::StaticType::Number, 5); + setVar2->targetVariable = &var; + setVar2->args.push_back({ Compiler::StaticType::Unknown, &value2 }); + list.addInstruction(setVar2); + + auto procCall = std::make_shared(LLVMInstruction::Type::CallProcedure, false); + list.addInstruction(procCall); + + auto loopEnd = std::make_shared(LLVMInstruction::Type::EndLoop, false); + list.addInstruction(loopEnd); + + analyzer.analyzeScript(list); + + ASSERT_EQ(setVar1->targetType, Compiler::StaticType::Unknown); + + // Type unknown due to procedure call + ASSERT_EQ(setVar2->targetType, Compiler::StaticType::Unknown); +} + TEST(LLVMCodeAnalyzer_VariableTypeAnalysis, LoopMultipleWrites_UnknownType) { LLVMCodeAnalyzer analyzer; @@ -594,6 +659,147 @@ TEST(LLVMCodeAnalyzer_VariableTypeAnalysis, WriteBeforeIfElse) ASSERT_EQ(setVar3->targetType, Compiler::StaticType::Number | Compiler::StaticType::String); } +TEST(LLVMCodeAnalyzer_VariableTypeAnalysis, WriteInIfStatement_IfBranch) +{ + LLVMCodeAnalyzer analyzer; + LLVMInstructionList list; + Variable var("", ""); + + auto ifStart = std::make_shared(LLVMInstruction::Type::BeginIf, false); + list.addInstruction(ifStart); + + auto setVarInIfStatement = std::make_shared(LLVMInstruction::Type::WriteVariable, false); + LLVMConstantRegister valueInIfStatement(Compiler::StaticType::Number, 42); + setVarInIfStatement->targetVariable = &var; + setVarInIfStatement->args.push_back({ Compiler::StaticType::Unknown, &valueInIfStatement }); + list.addInstruction(setVarInIfStatement); + + auto ifEnd = std::make_shared(LLVMInstruction::Type::EndIf, false); + list.addInstruction(ifEnd); + + auto setVar = std::make_shared(LLVMInstruction::Type::WriteVariable, false); + LLVMConstantRegister value(Compiler::StaticType::Bool, false); + setVar->targetVariable = &var; + setVar->args.push_back({ Compiler::StaticType::Unknown, &value }); + list.addInstruction(setVar); + + analyzer.analyzeScript(list); + + ASSERT_EQ(setVarInIfStatement->targetType, Compiler::StaticType::Unknown); + + // The type is Unknown because the if statement might not run at all + ASSERT_EQ(setVar->targetType, Compiler::StaticType::Unknown); +} + +TEST(LLVMCodeAnalyzer_VariableTypeAnalysis, WriteInIfStatement_ElseBranch) +{ + LLVMCodeAnalyzer analyzer; + LLVMInstructionList list; + Variable var("", ""); + + auto ifStart = std::make_shared(LLVMInstruction::Type::BeginIf, false); + list.addInstruction(ifStart); + + auto elseStart = std::make_shared(LLVMInstruction::Type::BeginElse, false); + list.addInstruction(elseStart); + + auto setVarInIfStatement = std::make_shared(LLVMInstruction::Type::WriteVariable, false); + LLVMConstantRegister valueInIfStatement(Compiler::StaticType::Number, 42); + setVarInIfStatement->targetVariable = &var; + setVarInIfStatement->args.push_back({ Compiler::StaticType::Unknown, &valueInIfStatement }); + list.addInstruction(setVarInIfStatement); + + auto ifEnd = std::make_shared(LLVMInstruction::Type::EndIf, false); + list.addInstruction(ifEnd); + + auto setVar = std::make_shared(LLVMInstruction::Type::WriteVariable, false); + LLVMConstantRegister value(Compiler::StaticType::Bool, false); + setVar->targetVariable = &var; + setVar->args.push_back({ Compiler::StaticType::Unknown, &value }); + list.addInstruction(setVar); + + analyzer.analyzeScript(list); + + ASSERT_EQ(setVarInIfStatement->targetType, Compiler::StaticType::Unknown); + + // The type is Unknown because the if statement might not run at all + ASSERT_EQ(setVar->targetType, Compiler::StaticType::Unknown); +} + +TEST(LLVMCodeAnalyzer_VariableTypeAnalysis, WriteInIfElse) +{ + LLVMCodeAnalyzer analyzer; + LLVMInstructionList list; + Variable var("", ""); + + auto ifStart = std::make_shared(LLVMInstruction::Type::BeginIf, false); + list.addInstruction(ifStart); + + auto setVar1 = std::make_shared(LLVMInstruction::Type::WriteVariable, false); + LLVMConstantRegister value1(Compiler::StaticType::Number, 42); + setVar1->targetVariable = &var; + setVar1->args.push_back({ Compiler::StaticType::Unknown, &value1 }); + list.addInstruction(setVar1); + + auto elseStart = std::make_shared(LLVMInstruction::Type::BeginElse, false); + list.addInstruction(elseStart); + + auto setVar2 = std::make_shared(LLVMInstruction::Type::WriteVariable, false); + LLVMConstantRegister value2(Compiler::StaticType::String, "test"); + setVar2->targetVariable = &var; + setVar2->args.push_back({ Compiler::StaticType::Unknown, &value2 }); + list.addInstruction(setVar2); + + auto ifEnd = std::make_shared(LLVMInstruction::Type::EndIf, false); + list.addInstruction(ifEnd); + + auto setVar3 = std::make_shared(LLVMInstruction::Type::WriteVariable, false); + LLVMConstantRegister value3(Compiler::StaticType::Bool, true); + setVar3->targetVariable = &var; + setVar3->args.push_back({ Compiler::StaticType::Unknown, &value3 }); + list.addInstruction(setVar3); + + analyzer.analyzeScript(list); + + ASSERT_EQ(setVar1->targetType, Compiler::StaticType::Unknown); + ASSERT_EQ(setVar2->targetType, Compiler::StaticType::Unknown); + + // The type is Number | String because any of the branches may run + ASSERT_EQ(setVar3->targetType, Compiler::StaticType::Number | Compiler::StaticType::String); +} + +TEST(LLVMCodeAnalyzer_VariableTypeAnalysis, WriteInLoop) +{ + LLVMCodeAnalyzer analyzer; + LLVMInstructionList list; + Variable var("", ""); + + auto loopStart = std::make_shared(LLVMInstruction::Type::BeginRepeatLoop, false); + list.addInstruction(loopStart); + + auto setVarInLoop = std::make_shared(LLVMInstruction::Type::WriteVariable, false); + LLVMConstantRegister valueInLoop(Compiler::StaticType::Number, 42); + setVarInLoop->targetVariable = &var; + setVarInLoop->args.push_back({ Compiler::StaticType::Unknown, &valueInLoop }); + list.addInstruction(setVarInLoop); + + auto loopEnd = std::make_shared(LLVMInstruction::Type::EndLoop, false); + list.addInstruction(loopEnd); + + auto setVar = std::make_shared(LLVMInstruction::Type::WriteVariable, false); + LLVMConstantRegister value(Compiler::StaticType::Bool, false); + setVar->targetVariable = &var; + setVar->args.push_back({ Compiler::StaticType::Unknown, &value }); + list.addInstruction(setVar); + + analyzer.analyzeScript(list); + + ASSERT_EQ(setVarInLoop->targetType, Compiler::StaticType::Unknown); + + // The type is Unknown because the loop might not run at all + ASSERT_EQ(setVar->targetType, Compiler::StaticType::Unknown); +} + TEST(LLVMCodeAnalyzer_VariableTypeAnalysis, ComplexNestedControlFlow) { LLVMCodeAnalyzer analyzer; diff --git a/test/llvm/operators/math/mod_test.cpp b/test/llvm/operators/math/mod_test.cpp index 93190c86..e8c9e38b 100644 --- a/test/llvm/operators/math/mod_test.cpp +++ b/test/llvm/operators/math/mod_test.cpp @@ -159,6 +159,26 @@ TEST_F(LLVMModTest, NegativeFiveModZero_Const) ASSERT_NUM_OP2(m_utils, LLVMTestUtils::OpType::Mod, true, -5, 0); } +TEST_F(LLVMModTest, PositiveDecimalModZero) +{ + ASSERT_NUM_OP2(m_utils, LLVMTestUtils::OpType::Mod, false, 5.8, 0); +} + +TEST_F(LLVMModTest, PositiveDecimalModZero_Const) +{ + ASSERT_NUM_OP2(m_utils, LLVMTestUtils::OpType::Mod, true, 5.8, 0); +} + +TEST_F(LLVMModTest, NegativeDecimalModZero) +{ + ASSERT_NUM_OP2(m_utils, LLVMTestUtils::OpType::Mod, false, -5.8, 0); +} + +TEST_F(LLVMModTest, NegativeDecimalModZero_Const) +{ + ASSERT_NUM_OP2(m_utils, LLVMTestUtils::OpType::Mod, true, -5.8, 0); +} + TEST_F(LLVMModTest, NegativeDecimalModInfinity) { ASSERT_NUM_OP2(m_utils, LLVMTestUtils::OpType::Mod, false, -2.5, "Infinity"); diff --git a/test/llvm/operators/math/round_test.cpp b/test/llvm/operators/math/round_test.cpp index e5bb7b45..f1e3d4cd 100644 --- a/test/llvm/operators/math/round_test.cpp +++ b/test/llvm/operators/math/round_test.cpp @@ -19,6 +19,16 @@ TEST_F(LLVMRoundTest, FourPointZero_Const) ASSERT_EQ(m_utils.getOpResult(LLVMTestUtils::OpType::Round, true, 4.0).toDouble(), 4.0); } +TEST_F(LLVMRoundTest, NegativeFourPointZero) +{ + ASSERT_EQ(m_utils.getOpResult(LLVMTestUtils::OpType::Round, false, -4.0).toDouble(), -4.0); +} + +TEST_F(LLVMRoundTest, NegativeFourPointZero_Const) +{ + ASSERT_EQ(m_utils.getOpResult(LLVMTestUtils::OpType::Round, true, -4.0).toDouble(), -4.0); +} + TEST_F(LLVMRoundTest, ThreePointTwo) { ASSERT_EQ(m_utils.getOpResult(LLVMTestUtils::OpType::Round, false, 3.2).toDouble(), 3.0); diff --git a/test/mocks/codebuildermock.h b/test/mocks/codebuildermock.h index 6ff635c1..1d9225a2 100644 --- a/test/mocks/codebuildermock.h +++ b/test/mocks/codebuildermock.h @@ -88,6 +88,7 @@ class CodeBuilderMock : public ICodeBuilder MOCK_METHOD(void, yield, (), (override)); MOCK_METHOD(void, createStop, (), (override)); + MOCK_METHOD(void, createStopWithoutSync, (), (override)); MOCK_METHOD(void, createProcedureCall, (BlockPrototype *, const Compiler::Args &), (override)); }; diff --git a/test/scratch_classes/value_test.cpp b/test/scratch_classes/value_test.cpp index 12c550f3..6851fb89 100644 --- a/test/scratch_classes/value_test.cpp +++ b/test/scratch_classes/value_test.cpp @@ -3387,6 +3387,57 @@ TEST(ValueTest, StringToDouble) ASSERT_EQ(value_stringToDouble(string_pool_new_assign("0b-1")), 0.0); } +TEST(ValueTest, StringToDoubleWithCheck) +{ + bool ok; + + ASSERT_EQ(value_stringToDoubleWithCheck(string_pool_new_assign("2147483647"), &ok), 2147483647.0); + ASSERT_TRUE(ok); + + ASSERT_EQ(value_stringToDoubleWithCheck(string_pool_new_assign("-255.625"), &ok), -255.625); + ASSERT_TRUE(ok); + + ASSERT_EQ(value_stringToDoubleWithCheck(string_pool_new_assign("0"), &ok), 0.0); + ASSERT_TRUE(ok); + + ASSERT_EQ(value_stringToDoubleWithCheck(string_pool_new_assign("-0"), &ok), -0.0); + ASSERT_TRUE(ok); + + ASSERT_EQ(value_stringToDoubleWithCheck(string_pool_new_assign("+.15"), &ok), 0.15); + ASSERT_TRUE(ok); + + ASSERT_EQ(value_stringToDoubleWithCheck(string_pool_new_assign("0+5"), &ok), 0.0); + ASSERT_FALSE(ok); + + ASSERT_EQ(value_stringToDoubleWithCheck(string_pool_new_assign("0-5"), &ok), 0.0); + ASSERT_FALSE(ok); + + ASSERT_EQ(value_stringToDoubleWithCheck(string_pool_new_assign("1 2 3"), &ok), 0.0); + ASSERT_FALSE(ok); + + ASSERT_EQ(value_stringToDoubleWithCheck(string_pool_new_assign("false"), &ok), 0.0); + ASSERT_FALSE(ok); + + ASSERT_EQ(value_stringToDoubleWithCheck(string_pool_new_assign("true"), &ok), 0.0); + ASSERT_FALSE(ok); + + double result = value_stringToDoubleWithCheck(string_pool_new_assign("Infinity"), &ok); + ASSERT_GT(result, 0); + ASSERT_TRUE(std::isinf(result)); + ASSERT_TRUE(ok); + + result = value_stringToDoubleWithCheck(string_pool_new_assign("-Infinity"), &ok); + ASSERT_LT(result, 0); + ASSERT_TRUE(std::isinf(result)); + ASSERT_TRUE(ok); + + ASSERT_EQ(value_stringToDoubleWithCheck(string_pool_new_assign("NaN"), &ok), 0.0); + ASSERT_FALSE(ok); + + ASSERT_EQ(value_stringToDoubleWithCheck(string_pool_new_assign("something"), &ok), 0.0); + ASSERT_FALSE(ok); +} + TEST(ValueTest, StringToBool) { ASSERT_TRUE(value_stringToBool(string_pool_new_assign("2147483647")));