1
+ #include " KaleidoscopeJIT.h"
1
2
#include " llvm/ADT/APFloat.h"
2
3
#include " llvm/ADT/STLExtras.h"
3
4
#include " llvm/IR/BasicBlock.h"
24
25
#include < cstdio>
25
26
#include < cstdlib>
26
27
#include < unordered_map>
27
- #include " ./KaleidoscopeJIT.h"
28
28
29
29
using namespace llvm ;
30
30
31
+ using namespace llvm ::orc;
32
+
31
33
#define IRGEN true
32
34
33
35
// ---Lexer---
@@ -119,6 +121,9 @@ static std::unique_ptr<Module> TheModule; // to hold blocks, definitions? (TODO)
119
121
static std::unique_ptr<IRBuilder<>> Builder; // for creating instructions, constants, etc
120
122
static std::unique_ptr<legacy::FunctionPassManager> TheFPM; // Function pass manager
121
123
static std::unordered_map<std::string, Value *> Symbols; // Maps names inside function context to LLVM "values"
124
+ static std::unique_ptr<KaleidoscopeJIT> TheJIT; // JIT engine for Kaleidoscope
125
+ // Prototypes will be codegened in _each_ module, again and again? TODO: check
126
+ static ExitOnError ExitOnErr;
122
127
123
128
124
129
@@ -658,7 +663,7 @@ static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
658
663
659
664
if (auto E = ParseExpression ()) {
660
665
661
- auto Proto = std::make_unique<PrototypeAST>(" " , std::vector<std::string>());
666
+ auto Proto = std::make_unique<PrototypeAST>(" __anon_expr " , std::vector<std::string>());
662
667
663
668
// fprintf(stderr, "debug: toplevelexpr\n");
664
669
return std::make_unique<FunctionAST>(std::move (Proto), std::move (E));
@@ -667,8 +672,53 @@ static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
667
672
return nullptr ;
668
673
}
669
674
675
+ static std::unordered_map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos; // Function Name -> PrototypeAST Node map
676
+
670
677
// -- Code Generator --
671
678
679
+ static void InitializeModuleAndPassManager () {
680
+ TheContext = std::make_unique<LLVMContext>();
681
+ TheModule = std::make_unique<Module>(" kaleidoscope" , *TheContext);
682
+ TheModule->setDataLayout (TheJIT->getDataLayout ());
683
+
684
+ Builder = std::make_unique<IRBuilder<>>(*TheContext);
685
+
686
+ // Why .get? Ahh- I want to pass a pointer. What about uniqueness?
687
+ TheFPM = std::make_unique<legacy::FunctionPassManager>(TheModule.get ());
688
+
689
+ // Peephole optimizations
690
+ TheFPM->add (createInstructionCombiningPass ());
691
+
692
+ // ?
693
+ TheFPM->add (createReassociatePass ());
694
+
695
+ // Global value numbering-> common subexpression elimination. Global is actually per-function
696
+ TheFPM->add (createGVNPass ());
697
+
698
+ // Dead code elimination pass;
699
+ TheFPM->add (createCFGSimplificationPass ());
700
+
701
+ // Run initalizers for all passes added to pass manager
702
+ TheFPM->doInitialization ();
703
+ }
704
+
705
+ Function *getOrCreateFunction (const std::string& Name) {
706
+ // Check whether declaration is present in current module
707
+ if (auto *F = TheModule->getFunction (Name)) {
708
+ // Hypothesis: When each function is created in a new module, this will never happen
709
+ return F;
710
+ }
711
+
712
+ // Check whether this function has been declared previously
713
+ auto F_itr = FunctionProtos.find (Name);
714
+ if (F_itr != FunctionProtos.end ()) {
715
+ // If yes, codegen declaration to _this module_.
716
+ return F_itr->second ->codegen ();
717
+ }
718
+
719
+ return nullptr ;
720
+ }
721
+
672
722
// Create a new constant of type "double"
673
723
Value* NumExprAST::codegen () {
674
724
return ConstantFP::get (Builder->getDoubleTy (), Val);
@@ -688,7 +738,7 @@ Value* VariableExprAST::codegen() {
688
738
Value* CallExprAST::codegen () {
689
739
690
740
// Obtain function with name `Callee` from Module
691
- Function *func = TheModule-> getFunction (Callee);
741
+ Function *func = getOrCreateFunction (Callee);
692
742
if (!func) {
693
743
return LogErrorV ((std::string (" undefined function: " ) + Callee).c_str ());
694
744
}
@@ -776,11 +826,13 @@ Function* FunctionAST::codegen() {
776
826
777
827
// TODO: why are we doing this? This codegen method will never be called
778
828
// for an extern function, right? Why else do I need to check?
779
- Function *func = TheModule-> getFunction ( Proto->GetName () );
829
+ const std::string& func_name = Proto->GetName ();
780
830
781
- if (!func) {
782
- func = Proto->codegen ();
783
- }
831
+ // Make global FunctionProto map the owner of function prototype node
832
+ // This ensures that declaration can be codegened in different modules
833
+ FunctionProtos[func_name] = std::move (Proto);
834
+
835
+ Function *func = getOrCreateFunction (func_name);
784
836
785
837
if (!func) {
786
838
return nullptr ;
@@ -832,6 +884,13 @@ static void HandleDefinition() {
832
884
#if IRGEN
833
885
if (Function *func = def->codegen ()) {
834
886
func->print (errs ());
887
+
888
+ ExitOnErr (TheJIT->addModule (
889
+ ThreadSafeModule (std::move (TheModule), std::move (TheContext))
890
+ ));
891
+
892
+ InitializeModuleAndPassManager ();
893
+
835
894
fprintf (stderr, " \n " );
836
895
fprintf (stderr, " Read a function definition\n " );
837
896
}
@@ -844,7 +903,7 @@ static void HandleDefinition() {
844
903
845
904
846
905
static void HandleExtern () {
847
- if (const auto extn = ParseExtern ()) {
906
+ if (auto extn = ParseExtern ()) {
848
907
849
908
#if DEBUGPARSE
850
909
LispPrintVisitor lvt;
@@ -856,6 +915,7 @@ static void HandleExtern() {
856
915
func->print (errs ());
857
916
fprintf (stderr, " \n " );
858
917
fprintf (stderr, " Read an extern\n " );
918
+ FunctionProtos[extn->GetName ()] = std::move (extn);
859
919
}
860
920
#endif
861
921
@@ -878,7 +938,30 @@ static void HandleTopLevelExpression() {
878
938
fprintf (stderr, " \n " );
879
939
fprintf (stderr, " Parsed a top level expression\n " );
880
940
881
- func->eraseFromParent ();
941
+ // TODO: how do I know which functions to call? In this case, I have the
942
+ // tutorial for reference. What if I don't know what does what?
943
+ auto RT = TheJIT->getMainJITDylib ().createResourceTracker ();
944
+
945
+ // TODO: wasn't the context supposed to be unique for the
946
+ // program? If this context is now owned by the JIT, then
947
+ // will each top level expression (and even each function)
948
+ // be created in a new context?
949
+ auto TSM = ThreadSafeModule (std::move (TheModule), std::move (TheContext));
950
+
951
+ ExitOnErr (TheJIT->addModule (std::move (TSM), RT));
952
+
953
+ // Now, the next function will be placed in a new Module?
954
+ InitializeModuleAndPassManager ();
955
+
956
+ auto ExprSymbol = ExitOnErr (TheJIT->lookup (" __anon_expr" ));
957
+
958
+ // TODO: why do I need intptr_t
959
+ double (*Fn)() = (double (*)())(intptr_t )ExprSymbol.getAddress ();
960
+
961
+ fprintf (stderr, " Evaluated to %lf\n " , Fn ());
962
+
963
+ ExitOnErr (RT->remove ());
964
+
882
965
}
883
966
#endif
884
967
@@ -887,28 +970,6 @@ static void HandleTopLevelExpression() {
887
970
}
888
971
}
889
972
890
- static void InitializeModuleAndPassManager () {
891
- TheContext = std::make_unique<LLVMContext>();
892
- TheModule = std::make_unique<Module>(" kaleidoscope" , *TheContext);
893
- // Why .get? Ahh- I want to pass a pointer. What about uniqueness?
894
- TheFPM = std::make_unique<legacy::FunctionPassManager>(TheModule.get ());
895
- Builder = std::make_unique<IRBuilder<>>(*TheContext);
896
-
897
- // Peephole optimizations
898
- TheFPM->add (createInstructionCombiningPass ());
899
-
900
- // ?
901
- TheFPM->add (createReassociatePass ());
902
-
903
- // Global value numbering-> common subexpression elimination. Global is actually per-function
904
- TheFPM->add (createGVNPass ());
905
-
906
- // Dead code elimination pass;
907
- TheFPM->add (createCFGSimplificationPass ());
908
-
909
- // Run initalizers for all passes added to pass manager
910
- TheFPM->doInitialization ();
911
- }
912
973
913
974
// top = definition | expression | external | ;
914
975
static void MainLoop () {
@@ -964,31 +1025,37 @@ int oldmain(void) {
964
1025
}
965
1026
966
1027
int main (void ) {
967
- BinopPrecedence[' >' ] = 10 ;
968
- BinopPrecedence[' <' ] = 10 ;
969
- BinopPrecedence[' +' ] = 20 ;
970
- BinopPrecedence[' -' ] = 20 ;
971
- BinopPrecedence[' *' ] = 40 ;
972
- BinopPrecedence[' /' ] = 40 ;
973
1028
974
- fprintf (stderr, " ready>" );
975
- getNextToken ();
1029
+ InitializeNativeTarget ();
1030
+ InitializeNativeTargetAsmParser ();
1031
+ InitializeNativeTargetAsmPrinter ();
1032
+
1033
+ BinopPrecedence[' >' ] = 10 ;
1034
+ BinopPrecedence[' <' ] = 10 ;
1035
+ BinopPrecedence[' +' ] = 20 ;
1036
+ BinopPrecedence[' -' ] = 20 ;
1037
+ BinopPrecedence[' *' ] = 40 ;
1038
+ BinopPrecedence[' /' ] = 40 ;
1039
+
1040
+ fprintf (stderr, " ready>" );
1041
+ getNextToken ();
976
1042
977
1043
#if IRGEN
978
- InitializeModuleAndPassManager ();
1044
+ TheJIT = ExitOnErr (KaleidoscopeJIT::Create ());
1045
+ InitializeModuleAndPassManager ();
979
1046
#endif
980
1047
981
- MainLoop ();
982
- // oldmain();
1048
+ MainLoop ();
1049
+ // oldmain();
983
1050
984
1051
#if IRGEN
985
- verifyModule (*TheModule, &errs ());
1052
+ verifyModule (*TheModule, &errs ());
986
1053
987
- TheModule->print (errs (), nullptr );
988
- fprintf (stderr, " \n " );
1054
+ TheModule->print (errs (), nullptr );
1055
+ fprintf (stderr, " \n " );
989
1056
#endif
990
1057
991
- return 0 ;
1058
+ return 0 ;
992
1059
}
993
1060
994
1061
0 commit comments