Bug when training fedmgda+ on the fashion_mnist dataset #59

Open
szzzhy opened this issue Mar 7, 2024 · 2 comments

Comments

szzzhy commented Mar 7, 2024

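First of all, many thanks to the maintainer for providing such a handy federated learning framework!

I ran into an error while training on the fashion_mnist task with the fedmgda+ algorithm. The details are as follows:

[Image: 捕获]

The fedmgda+ algorithm was copied and pasted from the maintainer's tutorial without any changes. The data is partitioned so that each client holds only one class, as shown: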

Uploading 捕获.PNG…

The other parameter settings are:
option_batch_size_10 = {'learning_rate': 0.01, 'num_steps': 1, 'num_rounds': 500, 'gpu': 1, 'batch_size': 10,
'proportion':0.1, 'seed': 0}

From a series of earlier tests, NaN values appear right after the normalization call (gi.normalize()), so the normalization is probably dividing by zero.
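The suspected failure mode can be reproduced in isolation. Below is a minimal sketch, assuming a plain PyTorch tensor stands in for the framework's gradient object; `normalize_naive` is a hypothetical helper written for illustration, not the framework's gi.normalize(), and the all-zero gradient is an assumed input rather than something observed in the run.

```python
import torch

def normalize_naive(g: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: divide by the L2 norm with no guard against a zero norm.
    return g / torch.norm(g)

# Assumed input: a gradient that is exactly zero in every entry.
zero_grad = torch.zeros(3)

# The norm is 0, so every element becomes 0 / 0, which is NaN in floating point.
out = normalize_naive(zero_grad)
has_nan = bool(torch.isnan(out).any())
print(has_nan)
```

Once a single NaN enters an aggregated update, it tends to spread to every parameter it is combined with, which would be consistent with the run failing abruptly.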

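I hope the bug can be fixed soon. I am in the middle of algorithm experiments, so this is fairly urgent.

Thanks again to the maintainer!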

Author

szzzhy commented Mar 7, 2024

The following data distribution setting also produced the same error:
task_config_dirichlet = {'benchmark': femnist,
'partitioner': {
'name': 'DirichletPartitioner',
'para': {
'num_clients': 100, 'alpha': 0.1}}
}

It appeared somewhere after round 300.
image

Owner

WwZzz commented Mar 8, 2024

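I tested this and found that it is caused by a division by zero. So that normal training is not affected, you can replace the normalization at gi.normalize() with the following: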

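        for i in range(len(grads)):
            # Accumulate the squared L2 norm over all parameters of the i-th gradient
            gi_norm = 0.0
            for p in grads[i].parameters():
                gi_norm += (p**2).sum()
            # The small constant keeps the division finite when the norm is zero
            grads[i] = grads[i]/(torch.sqrt(gi_norm) + 1e-8)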

image

After this change, the first setting mentioned above ran for 500 rounds on my side without errors.
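For anyone who wants to check the guarded normalization outside the framework, here is a minimal standalone sketch. It assumes each gradient can be represented as a list of plain PyTorch tensors; `normalize_with_eps` is a hypothetical helper name, and the `1e-8` default simply mirrors the constant used in the snippet above.

```python
import torch

def normalize_with_eps(tensors, eps=1e-8):
    """Scale a non-empty list of tensors by their joint L2 norm plus eps."""
    # Squared L2 norm summed over every tensor in the list.
    sq_sum = sum((t ** 2).sum() for t in tensors)
    # Adding eps keeps the divisor strictly positive even for an all-zero input.
    scale = torch.sqrt(sq_sum) + eps
    return [t / scale for t in tensors]

# An all-zero gradient stays finite (all zeros) instead of turning into NaN.
zero_out = normalize_with_eps([torch.zeros(2, 2), torch.zeros(3)])

# A non-zero gradient is scaled to (almost exactly) unit norm.
unit_out = normalize_with_eps([torch.tensor([3.0]), torch.tensor([4.0])])
```

With this guard an all-zero gradient maps to zeros, so it contributes nothing to the aggregate rather than poisoning it, while non-zero gradients are changed only negligibly.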
