@@ -199,3 +199,58 @@ TEST_F(TrainingModuleTest, DataExternalConstantsTest) {
199199 ASSERT_EQ (attributes.find (" b" )->second .sizes ()[0 ], 2 );
200200 ASSERT_EQ (attributes.find (" b" )->second .dim (), 2 );
201201}
202+
203+ TEST_F (TrainingModuleTest, UnloadMethodTest) {
204+ const char * ptd_path = std::getenv (" ET_MODULE_TRAIN_DATA_PATH" );
205+ Result<FileDataLoader> data_map_loader_res = FileDataLoader::from (ptd_path);
206+ ASSERT_EQ (data_map_loader_res.error (), Error::Ok);
207+
208+ auto data_map_loader =
209+ std::make_unique<torch::executor::util::FileDataLoader>(
210+ std::move (data_map_loader_res.get ()));
211+
212+ const char * pte_path = std::getenv (" ET_MODULE_TRAIN_PROGRAM_PATH" );
213+ Result<FileDataLoader> pte_loader_res = FileDataLoader::from (pte_path);
214+ ASSERT_EQ (pte_loader_res.error (), Error::Ok);
215+
216+ auto pte_loader = std::make_unique<torch::executor::util::FileDataLoader>(
217+ std::move (pte_loader_res.get ()));
218+
219+ auto mod = executorch::extension::training::TrainingModule (
220+ std::move (pte_loader),
221+ nullptr ,
222+ nullptr ,
223+ nullptr ,
224+ std::move (data_map_loader));
225+
226+ auto parameters_res = mod.named_parameters (" forward" );
227+ ASSERT_EQ (parameters_res.error (), Error::Ok);
228+ auto & parameters = parameters_res.get ();
229+
230+ ASSERT_NEAR (
231+ parameters_res.get ()
232+ .find (" linear.bias" )
233+ ->second .const_data_ptr <float >()[0 ],
234+ 0.1528 ,
235+ 0.0001 );
236+
237+ // mock training
238+ auto linear_bias_ptr =
239+ parameters.find (" linear.bias" )->second .mutable_data_ptr <float >();
240+ linear_bias_ptr[0 ] += 0.5 ;
241+ ASSERT_NEAR (
242+ parameters.find (" linear.bias" )->second .const_data_ptr <float >()[0 ],
243+ 0.6528 ,
244+ 0.0001 );
245+
246+ mod.unload_method (" forward" );
247+
248+ auto new_parameters_res = mod.named_parameters (" forward" );
249+ ASSERT_EQ (new_parameters_res.error (), Error::Ok);
250+ ASSERT_NEAR (
251+ new_parameters_res.get ()
252+ .find (" linear.bias" )
253+ ->second .const_data_ptr <float >()[0 ],
254+ 0.1528 ,
255+ 0.0001 );
256+ }
0 commit comments