diff --git a/nilmtk_contrib/disaggregate/afhmm_sac.py b/nilmtk_contrib/disaggregate/afhmm_sac.py index 237064b..56ea73e 100644 --- a/nilmtk_contrib/disaggregate/afhmm_sac.py +++ b/nilmtk_contrib/disaggregate/afhmm_sac.py @@ -256,7 +256,11 @@ def disaggregate_thread(self, test_mains,index,d): prob = cvx.Problem(expression, constraints) prob.solve(solver=cvx.SCS,verbose=False, warm_start=True) - s_ = [i.value for i in cvx_state_vectors] + s_ = [ + np.zeros((len(test_mains), self.default_num_states)) if i.value is None + else i.value + for i in cvx_state_vectors + ] prediction_dict = {} for appliance_id in range(self.num_appliances): @@ -302,4 +306,4 @@ def disaggregate_chunk(self, test_mains_list): return predictions_lst - \ No newline at end of file +