免费做电子相册的网站,it网站建设,湖北营销型网站建设,网站建设推销员话术在pytorch中停止梯度流的若干办法#xff0c;避免不必要模块的参数更新2020/4/11 FesianXu前言在现在的深度模型软件框架中#xff0c;如TensorFlow和PyTorch等等#xff0c;都是实现了自动求导机制的。在深度学习中#xff0c;有时候我们需要对某些模块的梯度流进行精确地… 在pytorch中停止梯度流的若干办法避免不必要模块的参数更新2020/4/11 FesianXu前言在现在的深度模型软件框架中如TensorFlow和PyTorch等等都是实现了自动求导机制的。在深度学习中有时候我们需要对某些模块的梯度流进行精确地控制包括是否允许某个模块的参数更新更新地幅度多少是否每个模块更新地幅度都是一样的。这些问题非常常见但是在实践中却很容易出错我们在这篇文章中尝试对第一个子问题也就是如果精确控制某些模型是否允许其参数更新进行总结。如有谬误请联系指出转载请注明出处。本文实验平台pytorch 1.4.0, ubuntu 18.04, python 3.6 联系方式e-mail: FesianXugmail.comQQ: 973926198github: https://github.com/FesianXu知乎专栏: 计算机视觉/计算机图形理论与应用微信公众号为什么我们要控制梯度流为什么我们要控制梯度流这个答案有很多个但是都可以归结为避免不需要更新的模型模块被参数更新。我们在深度模型训练过程中很可能存在多个loss比如GAN对抗生成网络存在G_loss和D_loss通常来说我们通过D_loss只希望更新判别器(Discriminator)而生成网络(Generator)并不需要也不能被更新生成网络只在通过G_loss学习的情况下才能被更新。这个时候如果我们不控制梯度流那么我们在训练D_loss的时候我们的前端网络Generator和CNN难免也会被一起训练这个是我们不期望发生的。Fig 1.1 典型的GAN结构由生成器和判别器组成。多个loss的协调只是其中一种情况还有一种情况是我们在进行模型迁移的过程中经常采用某些已经预训练好了的特征提取网络比如VGG, ResNet之类的在适用到具体的业务数据集时候特别是小数据集的时候我们可能会希望这些前端的特征提取器不要更新而只是更新末端的分类器(因为数据集很小的情况下如果贸然更新特征提取器很可能出现不期望的严重过拟合这个时候的合适做法应该是更新分类器优先)这个时候我们也可以考虑停止特征提取器的梯度流。这些情况还有很多我们在实践中发现精确控制某些模块的梯度流是非常重要的。笔者在本文中打算讨论的是对某些模块的梯度流的截断而并没有讨论对某些模块梯度流的比例缩放或者说最细粒度的梯度流控制后者我们将会在后文中讨论。一般来说截断梯度流可以有几种思路停止计算某个模块的梯度在优化过程中这个模块还是会被考虑更新然而因为梯度已经被截断了因此不能被更新。设置tensor.detach()完全截断之前的梯度流设置参数的requires_grad属性单纯不计算当前设置参数的梯度不影响梯度流torch.no_grad()效果类似于设置参数的requires_grad属性在优化器中设置不更新某个模块的参数这个模块的参数在优化过程中就不会得到更新然而这个模块的梯度在反向传播时仍然可能被计算。我们后面分别按照这两大类思路进行讨论。停止计算某个模块的梯度在本大类方法中主要涉及到了tensor.detach()和requires_grad的设置这两种都无非是对某些模块某些节点变量设置了是否需要梯度的选项。tensor.detach()tensor.detach()的作用是tensor.detach()会创建一个与原来张量共享内存空间的一个新的张量不同的是这个新的张量将不会有梯度流流过这个新的张量就像是从原先的计算图中脱离(detach)出来一样对这个新的张量进行的任何操作都不会影响到原先的计算图了。因此对此新的张量进行的梯度流也不会流过原先的计算图从而起到了截断的目的。这样说可能不够清楚我们举个例子。众所周知我们的pytorch是动态计算图网络正是因为计算图的存在才能实现自动求导机制。考虑一个表达式如果用计算图表示则如Fig 2.1所示。Fig 2.1 计算图示例考虑在这个式子的基础上加上一个分支那么计算图就变成了Fig 2.2 添加了新的分支后的计算图如果我们不detach() 中间的变量z分别对pq和w进行反向传播梯度我们会有x torch.tensor(([1.0]),requires_gradTrue)y x**2z 2*yw z**3# This is the subpath# Do not use detach()p zq torch.tensor(([2.0]), requires_gradTrue)pq p*qpq.backward(retain_graphTrue)w.backward()print(x.grad)输出结果为 tensor([56.])。我们发现这个结果是吧pq和w的反向传播结果都进行了考虑的也就是新增加的分支的反向传播影响了原先主要枝干的梯度流。这个时候我们用detach()可以把p给从原先计算图中脱离出来使得其不会干扰原先的计算图的梯度流如Fig 2.3 用了detach之后的计算图那么代码就对应地修改为x torch.tensor(([1.0]),requires_gradTrue)y x**2z 2*yw z**3# detach it, so the gradient w.r.t p does not effect z!p z.detach()q torch.tensor(([2.0]), requires_gradTrue)pq p*qpq.backward(retain_graphTrue)w.backward()print(x.grad)这个时候因为分支的梯度流已经影响不到原先的计算图梯度流了因此输出为tensor([48.])。这只是个计算图的简单例子在实际模块中我们同样可以这样用举个GAN的例子代码如 def backward_D(self): # Fake # stop backprop to the generator by detaching fake_B fake_AB self.fake_B # fake_AB self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) self.pred_fake self.netD.forward(fake_AB.detach()) self.loss_D_fake self.criterionGAN(self.pred_fake, False) # Real real_AB self.real_B # GroundTruth # real_AB torch.cat((self.real_A, self.real_B), 1) self.pred_real self.netD.forward(real_AB) self.loss_D_real self.criterionGAN(self.pred_real, True) # Combined loss self.loss_D (self.loss_D_fake self.loss_D_real) * 0.5 self.loss_D.backward() def backward_G(self): # First, G(A) should fake the discriminator fake_AB self.fake_B pred_fake self.netD.forward(fake_AB) self.loss_G_GAN self.criterionGAN(pred_fake, True) # Second, G(A) B self.loss_G_L1 self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A self.loss_G self.loss_G_GAN self.loss_G_L1 self.loss_G.backward() def forward(self): self.real_A Variable(self.input_A) self.fake_B self.netG.forward(self.real_A) self.real_B Variable(self.input_B) # 先调用 forward, 再 D backward 更新D之后 再G backward 再更新G def optimize_parameters(self): self.forward() self.optimizer_D.zero_grad() self.backward_D() self.optimizer_D.step() self.optimizer_G.zero_grad() self.backward_G() self.optimizer_G.step()我们注意看第六行self.pred_fake self.netD.forward(fake_AB.detach())使得在反向传播D_loss的时候不会更新到self.netG因为fake_AB是由self.netG生成的代码如self.fake_B self.netG.forward(self.real_A)。设置requires_gradtensor.detach()是截断梯度流的一个好办法但是在设置了detach()的张量之前的所有模块梯度流都不能回流了(不包括这个张量本身这个张量已经脱离原先的计算图了)如以下代码所示x torch.randn(2, 2)x.requires_grad Truelin0 nn.Linear(2, 2)lin1 nn.Linear(2, 2)lin2 nn.Linear(2, 2)lin3 nn.Linear(2, 2)x1 lin0(x)x2 lin1(x1)x2 x2.detach() # 此处设置了detach之前的所有梯度流都不会回传了x3 lin2(x2)x4 lin3(x3)x4.sum().backward()print(lin0.weight.grad)print(lin1.weight.grad)print(lin2.weight.grad)print(lin3.weight.grad)输出为:NoneNonetensor([[-0.7784, -0.7018], [-0.4261, -0.3842]])tensor([[ 0.5509, -0.0386], [ 0.5509, -0.0386]])我们发现lin0.weight.grad和lin0.weight.grad都为None了因为通过脱离中间张量原先计算图已经和当前回传的梯度流脱离关系了。这样有时候不够理想因为我们可能存在只需要某些中间模块不计算梯度但是梯度仍然需要回传的情况在这种情况下如下图所示我们可能只需要不计算B_net的梯度但是我们又希望计算A_net和C_net的梯度这个时候怎么办呢当然通过detach()这个方法是不能用了。事实上我们可以通过设置张量的requires_grad属性来设置某个张量是否计算梯度而这个不会影响梯度回传只会影响当前的张量。修改上面的代码我们有x torch.randn(2, 2)x.requires_grad Truelin0 nn.Linear(2, 2)lin1 nn.Linear(2, 2)lin2 nn.Linear(2, 2)lin3 nn.Linear(2, 2)x1 lin0(x)x2 lin1(x1)for p in lin2.parameters(): p.requires_grad Falsex3 lin2(x2)x4 lin3(x3)x4.sum().backward()print(lin0.weight.grad)print(lin1.weight.grad)print(lin2.weight.grad)print(lin3.weight.grad)输出为:tensor([[-0.0117, 0.9976], [-0.0080, 0.6855]])tensor([[-0.0075, -0.0521], [-0.0391, -0.2708]])Nonetensor([[0.0523, 0.5429], [0.0523, 0.5429]])啊哈正是我们想要的结果只有设置了requires_gradFalse的模块没有计算梯度但是梯度流又能够回传。另外设置requires_grad经常用在对输入变量和输入的标签进行新建的时候使用如for mat,label in dataloader: mat Variable(mat, requires_gradFalse) label Variable(mat,requires_gradFalse) ...当然通过把所有前端网络都设置requires_gradFalse我们可以实现类似于detach()的效果也就是把该节点之前的所有梯度流回传截断。以VGG16为例子如果我们只需要训练其分类器而固定住其特征提取器网络的参数我们可以采用将前端网络的所有参数的requires_grad设置为False因为这个时候完全不需要梯度流的回传只需要前向计算即可。代码如model torchvision.models.vgg16(pretrainedTrue)for param in model.features.parameters(): param.requires_grad Falsetorch.no_grad()在对训练好的模型进行评估测试时我们同样不需要训练自然也不需要梯度流信息了。我们可以把所有参数的requires_grad属性设置为False事实上我们常用torch.no_grad()上下文管理器达到这个目的。即便输入的张量属性是requires_gradTrue, torch.no_grad()可以将所有的中间计算结果的该属性临时转变为False。如例子所示x torch.randn(3, requires_gradTrue)x1 (x**2)print(x.requires_grad)print(x1.requires_grad)with torch.no_grad(): x2 (x**2) print(x1.requires_grad) print(x2.requires_grad)输出为TrueTrueTrueFalse注意到只是在torch.no_grad()上下文管理器范围内计算的中间变量的属性requires_grad才会被转变为False在该管理器外面计算的并不会变化。不过和单纯手动设置requires_gradFalse不同的是在设置了torch.no_grad()之前的层是不能回传梯度的延续之前的例子如x torch.randn(2, 2)x.requires_grad Truelin0 nn.Linear(2, 2)lin1 nn.Linear(2, 2)lin2 nn.Linear(2, 2)lin3 nn.Linear(2, 2)x1 lin0(x)with torch.no_grad(): x2 lin1(x1)x3 lin2(x2)x4 lin3(x3)x4.sum().backward()print(lin0.weight.grad)print(lin1.weight.grad)print(lin2.weight.grad)print(lin3.weight.grad)输出为NoneNonetensor([[-0.0926, -0.0945], [-0.2793, -0.2851]])tensor([[-0.5216, 0.8088], [-0.5216, 0.8088]])此处如果我们打印lin1.weight.requires_grad我们会发现其为True但是其中间变量x2.requires_gradFalse。一般来说在实践中我们的torch.no_grad()通常会在测试模型的时候使用而不会选择在选择性训练某些模块时使用[1]例子如model.train()# here train the model, just skip the codesmodel.eval() # here we start to evaluate the modelwith torch.no_grad(): for each in eval_data: data, label each logit model(data) ... # here we just skip the codes注意通过设置属性requires_gradFalse的方法(包括torch.no_grad())很多时候可以避免保存中间计算的buffer从而减少对内存的需求但是这个也是视情况而定的比如如[2]的所示graph LR; input--A_net; A_net--B_net; B_net--C_net;如果我们不需要A_net的梯度我们设置所有A_net的requires_gradFalse因为后续的B_net和C_net的梯度流并不依赖于A_net因此不计算A_net的梯度流意味着不需要保存这个中间计算结果因此减少了内存。但是如果我们不需要的是B_net的梯度而需要A_net和C_net的梯度那么问题就不一样了因为A_net梯度依赖于B_net的梯度就算不计算B_net的梯度也需要保存回传过程中B_net中间计算的结果因此内存并不会被减少。但是通过tensor.detach()的方法并不会减少内存使用这一点需要注意。设置优化器的更新列表这个方法更为直接即便某个模块进行了梯度计算我只需要在优化器中指定不更新该模块的参数那么这个模块就和没有计算梯度有着同样的效果了。如以下代码所示:class model(nn.Module): def __init__(self): super().__init__() self.model_1 nn.linear(10,10) self.model_2 nn.linear(10,20) self.fc nn.linear(20,2) self.relu nn.ReLU() def foward(inputv): h self.model_1(inputv) h self.relu(h) h self.model_2(inputv) h self.relu(h) return self.fc(h)在设置优化器时我们只需要更新fc层和model_2层那么则是:curr_model model()opt_list list(curr_model.fc.parameters())list(curr_model.model_2.parameters())optimizer torch.optim.SGD(opt_list, lr1e-4)当然你也可以通过以下的方法去设置每一个层的学习率来避免不需要更新的层的更新[3]optim.SGD([ {params: model.model_1.parameters()}, {params: model.mode_2.parameters(), lr: 0}, {params: model.fc.parameters(), lr: 0} ], lr1e-2, momentum0.9)这种方法不需要更改模型本身结构也不需要添加模型的额外节点但是需要保存梯度的中间变量并且将会计算不需要计算的模块的梯度(即便最后优化的时候不考虑更新)这样浪费了内存和计算时间。Reference[1]. https://blog.csdn.net/LoseInVain/article/details/82916163[2]. https://discuss.pytorch.org/t/requires-grad-false-does-not-save-memory/21936[3]. https://pytorch.org/docs/stable/optim.html#module-torch.optim