Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

您好,对于代码有两个问题,请教您一下,谢谢 #65

Open
CUITCHENSIYU opened this issue Jan 21, 2022 · 4 comments
Open

Comments

@CUITCHENSIYU
Copy link

第一个问题:
在meta.py文件中
image
第82行,相当于每次更新一次Meta的参数后,下一次任务的开始,Meta的参数又会被重新初始化,那么我们在for循环结束后更新Meta参数的意义是什么呢?
第二个问题:
在Meta.py文件中
我们发现有两个“with torch.no_grad():”,那这里面的两个操作的意义是什么呢,感觉并不参与训练,更像是在记录日志

@CUITCHENSIYU
Copy link
Author

对于第一个问题,我们可以看论文中的介绍:
image
meta网络参数在循环外被随机初始化之后,在整个迭代过程中就不在被随机初始化了。

@Isuxiz
Copy link

Isuxiz commented May 31, 2022

尝试回答第一个问题:

太长不看版本:
这份代码中参数的随机初始化是在初始化Meta类实例时初始化self.net成员变量的时候做的:self.net = Learner(config, args.imgc, args.imgsz),因此整个训练过程中确实只有一次随机取初值。此初值会每个epoch被adam优化器原地优化一次,下一个epoch开始后优化后self.net成员变量的内部参数就是上一个epoch更新好的

详细:
82行的地方vars=None的含义是不使用外部传入的参数而是使用self.net内部自带的参数,而不是重新随机初始化一份参数给self.net!看下Learner类的源码就知道了:

if vars is None:
    vars = self.vars

只是单纯的让变量var指向内部的参数列表。
而真正随机初始化参数是在初始化一个Learner类的实例的时候做的,以一个线性变化层为例:

elif name is 'linear':
    # ↓初始化了一个全1的参数并将其kaiming标准化
    w = nn.Parameter(torch.ones(*param))
    torch.nn.init.kaiming_normal_(w)
    # ↓把这个全1参数加入模型的参数列表
    self.vars.append(w)
    # ↓初始化一个全0的参数(即wx+b里的b)并把它加入模型的参数列表
    self.vars.append(nn.Parameter(torch.zeros(param[0])))

上面说过了,在初始化Meta类实例的时候已经随机初始化好了self.net的参数,每个Meta实例只会初始化这一次!所以self.net自带的参数是会随着元学习的epoch不断更新的!具体过程如下:

初始化中指定了MAML的参数优化器:self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr),调用是在每次forward的最后,跑完所有任务之后会根据总loss和(这里用的其实是每个任务的平均loss,差一个常数系数不影响梯度方向)更新self.net的参数:

loss_q = losses_q[-1] / task_num
self.meta_optim.zero_grad()
loss_q.backward()
self.meta_optim.step()

因此每次forward结束后self.net的参数已经in-place更新好了,下一个epoch时82行的vars=None的含义就是使用self.net中的参数而不另外赋给它参数——也就是使用上一个epoch结束前更新好的参数,这对应了原始MAML算法的外层while循环的第8行

@sevenHsu
Copy link

line90 和 line101的两个with torch.no_grad()代码段看来确实没实际意义,后面更新meta时也没用到这两段的loss,单纯记录了下loss和correct数量。然后就是计算train acc 使用了这两个部分的correct 数量,让大家看看从随机参数的acc到迭代1、2 ... K轮后的acc变化,再到后面每个batch初始化上一个batch更新的meta参数初始化的acc变化。

@zhaoguangxu666
Copy link

你好,有一个问题想向您请教一下。
屏幕截图 2024-11-11 221422
图片中“from naive5 import Naive5”这句代码的报错请问您是如何解决的?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants