@@ -167,8 +167,8 @@ void RModel::AddOperator(std::unique_ptr<ROperator> op, int order_execution) {
167167 order_execution = fOperators .size ()-1 ;
168168 }
169169
170- // storing the last usage of tensors which are input to
171- // operators (but are not inputs to the model or they are not initialized )
170+ // storing the last usage of tensors which are input to the operator
171+ // (excluding tensors which are inputs to the model or the initialized (weights) tensors )
172172 // We call this function during parsing so we don't have yet initialized the operators
173173 for (size_t index = 0 ; index<op_input_tensors.size () &&
174174 fInitializedTensors .find (UTILITY::Clean_name (std::string (op_input_tensors[index]))) == fInitializedTensors .end () &&
@@ -208,10 +208,24 @@ void RModel::AddShapeTensor(const std::string & name, const std::vector<Dim> & s
208208 fShapeTensors [tensor_name] = std::make_pair (shape_values, scalar);
209209}
210210
211+ void RModel::AddAliasTensor (const std::string & name, const std::string & origin){
212+ // add an alias tensor to origin
213+ auto tensor_name = UTILITY::Clean_name (name);
214+ auto origin_name = UTILITY::Clean_name (origin);
215+ if (fAliasTensors .count (tensor_name) != 0 ) {
216+ throw std::runtime_error (" TMVA-SOFIE: alias tensor with name " + tensor_name + " already exists \n " );
217+ }
218+ fAliasTensors [tensor_name] = origin_name;
219+ }
220+
211221bool RModel::IsShapeTensor (const std::string & tensor_name) const {
212222 return fShapeTensors .count (tensor_name) != 0 ;
213223}
214224
225+ bool RModel::IsAliasTensor (const std::string & tensor_name) const {
226+ return fAliasTensors .count (tensor_name) != 0 ;
227+ }
228+
215229const std::vector<Dim> & RModel::GetShapeTensorValues (const std::string & tensor_name) const {
216230 // if (!IsShapeTensor(tensor_name) ) return std::vector<Dim>{};
217231 return fShapeTensors .at (tensor_name).first ;
@@ -356,6 +370,11 @@ std::string RModel::AllocateIntermediateMemory(std::span<const std::string_view>
356370 fDynamicTensorInfos .find (name) != fDynamicTensorInfos .end ())
357371 continue ;
358372
373+ // case of alias tensor
374+ if (IsAliasTensor (name)) {
375+ continue ;
376+ }
377+
359378 auto tensor_size = GetTypeSize (GetTensorType (name)) * ConvertShapeToLength (GetTensorShape (name));
360379 // important fill the pair in the ordered output tensors with the string view and not the string
361380 TensorMemoryInfo tmi = {it, tensor_size};
@@ -435,9 +454,14 @@ void RModel::CheckAndFlushIntermediateMemory(std::span<const std::string_view> o
435454 chunk != fIntermediateMemoryInfo .available_stack .end (); chunk++) {
436455 if (fVerbose ) std::cout << " -- free chunk " << chunk->first << " size = " << chunk->second << std::endl;
437456 }
438- for (auto &it : op_input_tensors) {
457+ for (auto &iv : op_input_tensors) {
439458 // last occurrence of the tensor is reached => flush it from memory
440- if (fVerbose ) std::cout << " .. input tensors : " << it;
459+ if (fVerbose ) std::cout << " .. input tensors : " << iv;
460+
461+ // for alias tensors replace name with its alias
462+ std::string it{iv}; // convert view to string
463+ if (IsAliasTensor (it))
464+ it = fAliasTensors [it];
441465 if (fIntermediateTensorFrequencyLookup [it] == op_idx) {
442466 if (fVerbose ) std::cout << " flash condition is met - looping on chunks to find matching one \n " ;
443467 for (auto chunk = fIntermediateMemoryInfo .total_stack .begin ();
@@ -623,6 +647,17 @@ void RModel::Initialize(const std::map<std::string, size_t> & inputParams, bool
623647 fUseWeightFile = false ;
624648 }
625649
650+ // update fIntermediateTensorFrequencyLookup for alias tensors
651+ for (auto & it : fAliasTensors ) {
652+ if (fIntermediateTensorFrequencyLookup .find (it.first ) == fIntermediateTensorFrequencyLookup .end ()) continue ;
653+ if (fIntermediateTensorFrequencyLookup .find (it.second ) == fIntermediateTensorFrequencyLookup .end () )
654+ fIntermediateTensorFrequencyLookup [it.second ] = fIntermediateTensorFrequencyLookup [it.first ];
655+ else {
656+ // take the largest one
657+ fIntermediateTensorFrequencyLookup [it.second ] = std::max (fIntermediateTensorFrequencyLookup [it.second ],fIntermediateTensorFrequencyLookup [it.first ] );
658+ }
659+ }
660+
626661 fIsInitialized = true ;
627662}
628663
@@ -737,7 +772,8 @@ void RModel::GenerateIntermediateTensorInfo() {
737772 if (!fIntermediateTensorInfos .empty ()) {
738773 std::string tensor_declaration_block = " " ;
739774 for (auto &i : fIntermediateTensorInfos ) {
740- if (i.second .type == ETensorType::BOOL) {
775+ bool is_alias = (IsAliasTensor (i.first ));
776+ if (i.second .type == ETensorType::BOOL && !is_alias) {
741777 tensor_declaration_block += " std::vector<std::uint8_t> fTensor_" + i.first + " = std::vector<std::uint8_t>(" + std::to_string (ConvertShapeToLength (i.second .shape )) + " );\n " ;
742778 tensor_declaration_block += " std::uint8_t * tensor_" + i.first + " = fTensor_" + i.first + " .data();\n " ;
743779 continue ;
@@ -748,7 +784,7 @@ void RModel::GenerateIntermediateTensorInfo() {
748784 bool not_in_output_names =
749785 (std::find (fOutputTensorNames .begin (), fOutputTensorNames .end (), i.first ) == fOutputTensorNames .end ());
750786
751- if ((not_in_freq_map && not_in_output_names) || (!not_in_freq_map && !is_extended && not_in_output_names)) {
787+ if ((( not_in_freq_map && not_in_output_names) || (!not_in_freq_map && !is_extended && not_in_output_names) ) && !is_alias ) {
752788 size_t length = ConvertShapeToLength (i.second .shape );
753789
754790 if (i.second .type == ETensorType::FLOAT) {
@@ -767,6 +803,10 @@ void RModel::GenerateIntermediateTensorInfo() {
767803 fOtherTensorSize += 8 * length;
768804 }
769805 }
806+ if (is_alias) {
807+ tensor_declaration_block += ConvertTypeToString (i.second .type ) + " * tensor_" + i.first + " = nullptr;\n " ;
808+ }
809+
770810 }
771811
772812 if (tensor_declaration_block.length ()) {
@@ -777,19 +817,7 @@ void RModel::GenerateIntermediateTensorInfo() {
777817 if (!fDynamicTensorInfos .empty ()) {
778818 fGC += " //--- declare the dynamic tensors\n " ;
779819 for (auto &i : fDynamicTensorInfos ) {
780- if (i.second .type == ETensorType::FLOAT) {
781- // fGC += "std::vector<float> fTensor_" + i.first + ";\n";
782- fGC += " float * tensor_" + i.first + " = nullptr;\n " ;
783- } else if (i.second .type == ETensorType::DOUBLE) {
784- // fGC += "std::vector<double> fTensor_" + i.first + ";\n";
785- fGC += " double * tensor_" + i.first + " = nullptr;\n " ;
786- } else if (i.second .type == ETensorType::INT64) {
787- // fGC += "std::vector<int64_t> fTensor_" + i.first + ";\n";
788- fGC += " int64_t * tensor_" + i.first + " = nullptr;\n " ;
789- } else if (i.second .type == ETensorType::BOOL) {
790- // fGC += "std::vector<uint8_t> fTensor_" + i.first + ";\n";
791- fGC += " uint8_t * tensor_" + i.first + " = nullptr;\n " ;
792- }
820+ fGC += ConvertTypeToString (i.second .type ) + " * tensor_" + i.first + " = nullptr;\n " ;
793821 }
794822 fGC += " //--- dynamic tensors pool\n " ;
795823 fGC += " std::vector<char> fDynamicMemoryPool;\n " ;
@@ -835,9 +863,9 @@ void RModel::GenerateDynamicTensorInfo()
835863 auto op_ptr = op.get ();
836864 std::cout << " Looping on operator " << op_index << " " << typeid (*op_ptr).name () << std::endl;
837865 }
838- // check if is a dynamic tensor
866+ // check if is a dynamic tensor and not an alias tensor
839867 std::string name = std::string (it);
840- if ( fDynamicTensorInfos .find (name) != fDynamicTensorInfos .end () ) {
868+ if ( fDynamicTensorInfos .find (name) != fDynamicTensorInfos .end () && ! IsAliasTensor (name) ) {
841869 auto tensor_size = ConvertDimShapeToLength (GetDimTensorShape (name));
842870 auto type = GetTensorType (name);
843871 size_t type_size = GetTypeSize (type);
@@ -873,6 +901,7 @@ void RModel::GenerateDynamicTensorInfo()
873901 // check that all dynamic tensors are covered
874902 bool missingTensor = false ;
875903 for (auto &i : fDynamicTensorInfos ) {
904+ if (IsAliasTensor (i.first )) continue ;
876905 if (std::find (tensors.begin (), tensors.end (), std::pair<std::string,ETensorType>{i.first , i.second .type }) == tensors.end ()) {
877906 std::cout << " Dynamic tensors " << i.first << " is not in list of operator input/output " << std::endl;
878907 missingTensor = true ;
0 commit comments