Support validation set and FedEM for MF datasets #310
yxdyc wants to merge 3 commits into alibaba:master from
DavdGao left a comment:
Please see the inline comments.
| """ | ||
| Ensemble evaluation for matrix factorization model | ||
| """ | ||
| cur_data = ctx.cur_mode |
DavdGao: Please make sure cur_mode is used correctly here.
- cur_mode: the type of the routine, chosen from "train"/"test"/"val"/"finetune"
- cur_split: the chosen data split
Besides, do we still need the variable cur_data, given that such variables are all removed at the end of the routine?
yxdyc: Fixed; we should use cur_split here.
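To illustrate the distinction discussed above, here is a minimal sketch, not the project's actual code. The `ctx` object is a stand-in built with `SimpleNamespace`, and `metric_key` is a hypothetical helper. It assumes an evaluation routine (cur_mode "test") running on the validation split (cur_split "val"), so keys derived from the split and from the mode differ.

```python
from types import SimpleNamespace


def metric_key(ctx, name):
    """Build a result key from the data split, not from the routine type."""
    return f"{ctx.cur_split}_{name}"


# Hypothetical context: an evaluation routine running on the validation split.
ctx = SimpleNamespace(cur_mode="test", cur_split="val")

split_key = metric_key(ctx, "avg_loss")   # keyed by the data split
mode_key = f"{ctx.cur_mode}_avg_loss"     # keyed by the routine type
print(split_key, mode_key)
```

If the metric calculator looks results up by split, only the first key would be found in this scenario.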
| # set the eval_metrics | ||
| if ctx.num_samples == 0: | ||
| results = { | ||
| f"{cur_data}_avg_loss": ctx.get( |
DavdGao: The metric calculator uses cur_split instead. Please check whether it is correct to use cur_data (which is actually cur_mode) here.
yxdyc: Fixed, as replied above.
| } | ||
| else: | ||
| results = { | ||
| f"{ctx.cur_mode}_avg_loss": ctx.get( |
DavdGao: It's a little confusing to use ctx.cur_mode here, since we use cur_data in line 236.
| else: | ||
| self._split_n_clients_rating_vmf(ratings, num_client, split) | ||
| def _split_n_clients_rating_hmf(self, ratings: csc_matrix, num_client: int, |
DavdGao: Since the classes HMFDataset and VMFDataset already provide _split_n_clients_rating for HMF and VMF respectively, maybe we don't need _split_n_clients_rating_hmf and _split_n_clients_rating_vmf here?
yxdyc: Deleted in the new PR.
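The suggestion above amounts to letting each subclass override a single method instead of keeping per-variant helpers in the base class. A minimal sketch under stated assumptions: `ToyMFDataset`, `ToyHMF`, `ToyVMF`, and `_chunk` are hypothetical names, ID lists stand in for the rating matrix, and the sketch assumes that the horizontal variant partitions users while the vertical variant partitions items.

```python
def _chunk(ids, n):
    """Split ids into n contiguous, nearly equal chunks."""
    size, extra = divmod(len(ids), n)
    chunks, start = [], 0
    for i in range(n):
        end = start + size + (1 if i < extra else 0)
        chunks.append(list(ids[start:end]))
        start = end
    return chunks


class ToyMFDataset:
    """Base class: calls one hook and leaves the split rule to subclasses."""

    def __init__(self, user_ids, item_ids, num_client):
        self.data = self._split_n_clients_rating(user_ids, item_ids, num_client)

    def _split_n_clients_rating(self, user_ids, item_ids, num_client):
        raise NotImplementedError


class ToyHMF(ToyMFDataset):
    def _split_n_clients_rating(self, user_ids, item_ids, num_client):
        # Assumed: horizontal MF gives each client a subset of users.
        return {cid + 1: {"users": users, "items": list(item_ids)}
                for cid, users in enumerate(_chunk(user_ids, num_client))}


class ToyVMF(ToyMFDataset):
    def _split_n_clients_rating(self, user_ids, item_ids, num_client):
        # Assumed: vertical MF gives each client a subset of items.
        return {cid + 1: {"users": list(user_ids), "items": items}
                for cid, items in enumerate(_chunk(item_ids, num_client))}


hmf = ToyHMF([0, 1, 2, 3], [10, 11, 12], num_client=2)
vmf = ToyVMF([0, 1, 2, 3], [10, 11, 12], num_client=2)
```

With this layout the base class needs no `if`/`else` on the variant, which is why the `_hmf` and `_vmf` helpers become redundant.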
| } | ||
| self.data = data | ||
| def _split_n_clients_rating_vmf(self, ratings: csc_matrix, num_client: int, |
yxdyc: Deleted in the new PR.
| dtype=torch.float32).to_dense() | ||
| - return mask * pred, label, float(np.prod(pred.size())) / len(ratings) | ||
| + return mask * pred, label, torch.Tensor( |
DavdGao: Why do we convert it to a Tensor, and do we need to consider the device of the Tensor?
yxdyc: The conversion is for FLOP counting. The device does not matter, since the tensor is discarded once the FLOPs have been counted.
| if ctx.get("num_samples") == 0: | ||
| results = { | ||
| f"{ctx.cur_mode}_avg_loss": ctx.get( | ||
| "loss_batch_total_{}".format(ctx.cur_mode)), |
DavdGao: It's a little confusing that line 53 uses loss_batch_total_{ctx.cur_mode}, while line 58 uses ctx.loss_batch_total.
yxdyc: Changed line 58 to use loss_batch_total_{ctx.cur_mode} as well.
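One way to keep the two branches consistent is to build the key once and reuse it. The following is a minimal sketch, assuming a plain `dict` as a stand-in for the context object; `avg_loss_results` is a hypothetical helper, not a function from the PR.

```python
def avg_loss_results(ctx, mode):
    """Return the average-loss entry, using the same total-loss key in both branches."""
    total_key = f"loss_batch_total_{mode}"   # built once, used everywhere
    total = ctx.get(total_key, 0.0)
    num_samples = ctx.get("num_samples", 0)

    if num_samples == 0:
        # No samples seen: report the raw total instead of dividing by zero.
        return {f"{mode}_avg_loss": total}
    return {f"{mode}_avg_loss": total / num_samples}


# Stand-in contexts (plain dicts) for an empty and a non-empty evaluation.
empty = avg_loss_results({"num_samples": 0, "loss_batch_total_test": 0.0}, "test")
filled = avg_loss_results({"num_samples": 4, "loss_batch_total_test": 2.0}, "test")
print(empty, filled)
```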
| if self.cfg.federate.method.lower() in ["fedem"]: | ||
| # cache label for evaluation ensemble | ||
| ctx.get("{}_y_true".format(ctx.cur_mode)).append( |
DavdGao: The attribute y_true is a matrix here and can be very large for MF datasets. I'm not sure it's appropriate to store all the labels and probs.
yxdyc: The appended object is a sparse csr_matrix.
| """ | ||
| def _split_n_clients_rating(self, ratings: csc_matrix, num_client: int, | ||
| - test_portion: float): | ||
| + split: list): |
DavdGao: How about enabling this change for FedNetflix as well?
yxdyc: FedNetflix inherits from MovieLensData, so this change should apply to FedNetflix too.
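The reply relies on ordinary method inheritance: a subclass that does not override the split method picks up the new `split` parameter automatically. A minimal sketch of that idea follows. `ToyMovieLensData`, `ToyNetflix`, and `_split_indices` are hypothetical stand-ins, and the sketch assumes `split` is a list of train/val/test ratios, which the diff itself does not spell out.

```python
class ToyMovieLensData:
    """Stand-in base class whose split method takes a list of ratios."""

    def _split_indices(self, n, split):
        # Assumed meaning of `split`: [train, val, test] ratios summing to 1.
        if len(split) != 3 or abs(sum(split) - 1.0) > 1e-8:
            raise ValueError("split must hold three ratios that sum to 1")
        n_train = int(n * split[0])
        n_val = int(n * split[1])
        idx = list(range(n))
        return {
            "train": idx[:n_train],
            "val": idx[n_train:n_train + n_val],
            "test": idx[n_train + n_val:],  # remainder goes to the test split
        }


class ToyNetflix(ToyMovieLensData):
    """Overrides nothing, so it inherits the updated split behaviour."""


parts = ToyNetflix()._split_indices(10, [0.8, 0.1, 0.1])
print({name: len(ids) for name, ids in parts.items()})
```

Replacing a single `test_portion` float with a ratio list is what makes room for a validation split.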
yxdyc left a comment:
Modified according to the comments.
As the title says. Please double-check the modifications related to MF. Thanks @rayrayraykk @DavdGao