用百度地图 做gis网站,哈尔滨工业大学包机,深圳市住建局官网公示,非商业组织的网站风格论文阅读笔记#xff1a;Dataset Condensation with Gradient Matching1. 解决了什么问题#xff1f;(Motivation)2. 关键方法与创新点 (Key Method Innovation)2.1 核心思路的演进#xff1a;从参数匹配到梯度匹配2.2 算法实现细节 (Implementation Details)3. 实验结…
论文阅读笔记Dataset Condensation with Gradient Matching1. 解决了什么问题(Motivation)2. 关键方法与创新点 (Key Method Innovation)2.1 核心思路的演进从参数匹配到梯度匹配2.2 算法实现细节 (Implementation Details)3. 实验结果与贡献 (Experiments Contributions)4.个人思考与启发主要代码算法逻辑总结ICLR2021 github
核心思想一句话总结 本文提出了一种创新的数据集压缩方法——数据集凝缩Dataset Condensation,DC,其核心思想是通过梯度匹配Gradient Matching,将一个大型数据集T浓缩成一个极小的、信息量丰富的合成数据集S。在S上从头训练的模型其性能可以逼近在T上训练的模型从而极大地节省了存储和训练成本。 1. 解决了什么问题(Motivation)
问题现代深度学习依赖于大规模数据集导致存储成本、数据传输宽带和模型训练时间急剧增加。目标创建一个微型合成数据集S它能作为原始大型数据集T的高效替代品用于从零开始训练神经网络 2. 关键方法与创新点 (Key Method Innovation)
2.1 核心思路的演进从参数匹配到梯度匹配
参数匹配 (Parameter Matching) - 一个被否定的思路 想法直接让S训练收敛后的模型参数θS\theta_SθS与用T训练收敛后的θT\theta_TθT尽可能接近。缺陷 优化路径复杂深度网络的参数空间非凸直接走向目标θT\theta_TθT极易陷入局部最优。计算成本高需要嵌套的双层优化内循环必须将模型训练至收敛计算上不可行。 梯度匹配 (Gradient Matching) - 本文的核心创新 想法放弃匹配静态的”终点“转而匹配动态的”过程“。即确保在训练每一步模型在合成数据S上产生的梯度∇Ls∇L_s∇Ls在真实数据T上产生的梯度∇LT∇L_T∇LT方向一致。优势计算高效通过一个巧妙的近似极大提高了效率和可扩展性。优化路径清晰每一步都有明确的监督信号梯度差异引导S的优化避免了在复杂空间中盲目搜索。对齐学习动态保证了模型在S上的学习方式与T上一致结果更鲁棒。
2.2 算法实现细节 (Implementation Details)
课程学习 (Curriculum Learning) 为了让合成数据S具有泛化性算法采用了一个”课程学习“的框架。在整个凝缩过程中会周期性地重新随机初始化网络参数θ\thetaθ。这确保了S不会过拟合到某一个特定的网络初始化而是对多种随机起点都有效。 梯度匹配损失函数Gradient Matching Loss 使用**余弦距离1-Cosine Similarity**来衡量两个梯度的差异。这更关注梯度的方向而非大小与梯度下降的本质契合。按输出节点分组计算并非所有层的梯度粗暴地展平而是按输出神经元分组计算余弦距离更好地保留了网络结构信息。 重要的工程技巧Practical Tricks BatchNorm层预热与冻结由于合成数据批次极小为了避免BN层统计量不稳定每次迭代前都先用一个较大的真实数据批次来计算并”冻结“BN层的均值和方差。按类别独立匹配在计算梯度时按类别独立进行即用”猫“的合成数据区匹配”猫“的真实数据梯度。这降低了学习难度和内存消耗。 3. 实验结果与贡献 (Experiments Contributions)
性能优越在CIFAR-10, CIFAR-100, SVHN等数据集上仅用极少量合成样本如IPC1或10就能训练出性能远超当时其他数据压缩方法的模型。开创性贡献 首次提出了梯度匹配这一高效且可扩展的数据集凝缩范式为后续大量的研究如DSA, MTT, FTD等奠定了基础。成功将数据集凝缩技术应用到了大型网络上证明了其可行性。展示了其在持续学习和神经架构搜索 (NAS) 等资源受限场景下的巨大潜力。 4.个人思考与启发
”过程“比”结果”更重要这篇论文最精妙的哲学在于它揭示了在复杂优化问题中对齐“过程”梯度比直接追求“结果”参数更有效、更可行。这一思想在很多其他领域也具有启发性。理论与实践的结合论文不仅提出了一个优雅的理论框架还通过BN层处理等工程技巧解决了实际应用中的痛点。
主要代码 training # 为合成图像image_syn创建一个优化器# 我们只优化image_syn这个张量所有优化器只传入它。# 这里的优化器是SGD意味着我们会用梯度下降法来更新图像的像素值。optimizer_img torch.optim.SGD([image_syn, ], lrargs.lr_img, momentum0.5) # optimizer_img for synthetic data# 清空优化器的梯度缓存optimizer_img.zero_grad()# 定义用于计算分类损失的损失函数这里是标准的交叉熵损失。criterion nn.CrossEntropyLoss().to(args.device)print(%s training begins%get_time())# 主迭代循环开始# 这个循环是整个数据集凝缩过程的核心总共进行Iteration1次。for it in range(args.Iteration1):# 评估合成数据在特定迭代点触发 Evaluate synthetic data if it in eval_it_pool:for model_eval in model_eval_pool:# 遍历model_eval_pool中的每一个模型架构用于评估。# 这运行我们测试合成数据集在不同模型上的泛化能力。print(-------------------------\nEvaluation\nmodel_train %s, model_eval %s, iteration %d%(args.model, model_eval, it))# 设置评估时的数据增强策略if args.dsa:# 如果是DSA方法使用其特定的增强策略。args.epoch_eval_train 1000args.dc_aug_param Noneprint(DSA augmentation strategy: \n, args.dsa_strategy)print(DSA augmentation parameters: \n, args.dsa_param.__dict__)else:# 如果是DC方法调用 get_daparam 获取专为DC设计的增强参数。# 注意这些增强只在评估时使用在生成合成数据时不用。args.dc_aug_param get_daparam(args.dataset, args.model, model_eval, args.ipc) # This augmentation parameter set is only for DC method. It will be muted when args.dsa is True.print(DC augmentation parameters: \n, args.dc_aug_param)# 如果在评估时使用了任何数据增强就需要更多的训练轮数来让模型充分学习。if args.dsa or args.dc_aug_param[strategy] ! none:args.epoch_eval_train 1000 # Training with data augmentation needs more epochs.else:args.epoch_eval_train 300# --- 3.2 执行评估 ---# 创建一个空列表用于存储多次评估的准确率accs []# 为了结果的稳定性我们会用当前的合成数据训练num_eval个独立随机初始化的模型。for it_eval in range(args.num_eval):# 每一次都创建一个全新的、随机初始化的评估网络。net_eval get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model# 深拷贝当前的合成数据和标签以防止在评估函数中被意外修改。# detach()是为了确保我们只复制数据不带计算图。image_syn_eval, label_syn_eval copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modification# 调用核心评估函数 evaluate_synset。# 这个函数会# 1. 拿 image_syn_eval 从头开始训练 net_eval。# 2. 在训练结束后用训练好的 net_eval 在真实的测试集 testloader 上进行测试。# 3. 返回在测试集上的准确率 acc_test。_, acc_train, acc_test evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)accs.append(acc_test)# 打印这次评估的平均准确率和标准差。print(Evaluate %d random %s, mean %.4f std %.4f\n-------------------------%(len(accs), model_eval, np.mean(accs), np.std(accs)))# 如果这是最后一次迭代将这次评估的所有准确率结果记录到总的实验结果字典中。if it args.Iteration: # record the final resultsaccs_all_exps[model_eval] accs# 可视化并保存合成图像 visualize and save save_name os.path.join(args.save_path, vis_%s_%s_%s_%dipc_exp%d_iter%d.png%(args.method, args.dataset, args.model, args.ipc, exp, it))# 深拷贝合成图像并移到CPU上进行处理。image_syn_vis copy.deepcopy(image_syn.detach().cpu())# 对图像进行反归一化以便人眼观察# 训练时图像通常是归一化的。# 反归一化公式pixel pixel * std meanfor ch in range(channel):image_syn_vis[:, ch] image_syn_vis[:, ch] * std[ch] mean[ch]# 将像素值裁剪到[0,1]范围内防止因浮点数误差导致显示异常。image_syn_vis[image_syn_vis0] 0.0image_syn_vis[image_syn_vis1] 1.0# 使用torchvision.utils.save_image将合成图像保存为一张网格图。# nrowargs.ipc表示每行显示ipc张图像。save_image(image_syn_vis, save_name, nrowargs.ipc) # Trying normalize True/False may get better visual effects.# --- 初始化课程学习环境 --- Train synthetic data # 每次主迭代(it)开始都创建一个全新的、随机初始的网络。# 这是”课程学习“的关键确保合成数据对不同的网络初始化方法都有效而不是过拟合到某一个。net get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random modelnet.train() # 将网络设置为训练模式# 获取网络的所有可学习参数net_parameters list(net.parameters())# 为这个新网络创建一个优化器用于在内循环中更新网络参数optimizer_net torch.optim.SGD(net.parameters(), lrargs.lr_net) # optimizer_img for synthetic dataoptimizer_net.zero_grad()# 初始化平均损失用于记录和打印loss_avg 0# 在生成合成数据时不使用任何数据增强以与DC论文的设置保持一致args.dc_aug_param None # Mute the DC augmentation when learning synthetic data (in inner-loop epoch function) in oder to be consistent with DC paper.# --- 课程学习外循环Outer Loop ---# 这个循环对应论文算法中的外循环用于实现课程学习。for ol in range(args.outer_loop):# -- BatchNorm层预热与冻结一个非常重要的工程技巧 -- freeze the running mu and sigma for BatchNorm layers # Synthetic data batch, e.g. only 1 image/batch, is too small to obtain stable mu and sigma.# So, we calculate and freeze mu and sigma for BatchNorm layer with real data batch ahead.# This would make the training with BatchNorm layers easier.# 动机合成数据的批次非常小例如ipc1如果让BN层在这么小的批次上计算均值和方差结果会极其不稳定导致训练困难。# 解决方案先用一个包含多个真实样本的”大“批次来预热BN层计算出稳定的统计量然后将其冻结。BN_flag FalseBNSizePC 16 # for batch normalization 每个类别用于BN预热的样本数# 检查网络中是否存在BN层for module in net.modules():if BatchNorm in module._get_name(): #BatchNormBN_flag Trueif BN_flag:# 从每个类别中抽取BNSizePC个真实图像拼接成一个大批次。img_real torch.cat([get_images(c, BNSizePC) for c in range(num_classes)], dim0)# 确保网络在训练模式以便BN层可以更新其 running_mean 和 running_var。net.train() # for updating the mu, sigma of BatchNorm# 进行一次前向传播这个操作会自动更新BN层的统计量。output_real net(img_real) # get running mu, sigma# 将所有BN层切换到评估模式。# 在评估模式下BN层会使用已经计算好的 running_mean 和 running_var而不会再根据新的输入来更新它们。# 这就实现了“冻结”的效果。for module in net.modules():if BatchNorm in module._get_name(): #BatchNormmodule.eval() # fix mu and sigma of every BatchNorm layer# --- 核心通过梯度匹配更新合成数据 --- update synthetic data # 初始化当前外循环的总损失loss torch.tensor(0.0).to(args.device)# 按照类别独立进行梯度匹配这个是论文提出的另外一个技巧。for c in range(num_classes):# 准备真实数据和合成数据img_real get_images(c, args.batch_real)lab_real torch.ones((img_real.shape[0],), deviceargs.device, dtypetorch.long) * cimg_syn image_syn[c*args.ipc:(c1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))lab_syn torch.ones((args.ipc,), deviceargs.device, dtypetorch.long) * c# 如果使用DSA方法对真实和合成图像应用相同的可微数据增强if args.dsa:seed int(time.time() * 1000) % 100000img_real DiffAugment(img_real, args.dsa_strategy, seedseed, paramargs.dsa_param)img_syn DiffAugment(img_syn, args.dsa_strategy, seedseed, paramargs.dsa_param)# --- 计算真实梯度 gw_real ---output_real net(img_real)loss_real criterion(output_real, lab_real)# 计算损失对网络参数的梯度gw_real torch.autograd.grad(loss_real, net_parameters)# clone()和detach()是为了将梯度值复制下来并切断其与计算图的联系# 因为我们只需要它的数值作为匹配目标不希望梯度回流真实数据。gw_real list((_.detach().clone() for _ in gw_real))# -- 计算合成梯度gw_syn --output_syn net(img_syn)loss_syn criterion(output_syn, lab_syn)# 关键所在create_graphTrue# 这个参数告诉pytorch在计算gw_syn时要保留其计算图。# 这意味着gw_syn本身也成为了一个计算图中的节点它依赖于iamge_syn.# 因此后续对gw_syn的损失进行反向传播时梯度可以一直流回image_syn。gw_syn torch.autograd.grad(loss_syn, net_parameters, create_graphTrue)# 计算真实梯度和合成梯度之间的匹配损失余弦相似度loss match_loss(gw_syn, gw_real, args)# 更新合成图像optimizer_img.zero_grad() # 清空image_syn的梯度缓存loss.backward() # 反向传播计算匹配损失对image_syn对image_syn的梯度optimizer_img.step() # 根据梯度更新image_syn的像素值loss_avg loss.item() # 累加损失用于打印# 如果是最后一个外循环就不需要再更新网络了直接跳出。if ol args.outer_loop - 1:break# --- 2.3 内循环用更新后的合成数据训练网络 --- update network # 第二步现在轮到网络来适应更新后的合成数据了。image_syn_train, label_syn_train copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modificationdst_syn_train TensorDataset(image_syn_train, label_syn_train)trainloader torch.utils.data.DataLoader(dst_syn_train, batch_sizeargs.batch_train, shuffleTrue, num_workers0)# 对网络进行inner_loop次的训练更新。for il in range(args.inner_loop):epoch(train, trainloader, net, optimizer_net, criterion, args, aug True if args.dsa else False)# 记录和保存# 计算并打印平均损失loss_avg / (num_classes*args.outer_loop)if it%10 0:print(%s iter %04d, loss %.4f % (get_time(), it, loss_avg))# 如果是最后一次主迭代保存所有结果if it args.Iteration: # only record the final resultsdata_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])torch.save({data: data_save, accs_all_exps: accs_all_exps, }, os.path.join(args.save_path, res_%s_%s_%s_%dipc.pt%(args.method, args.dataset, args.model, args.ipc))) 算法逻辑总结
“你追我赶”的双重优化过程
课程学习 (Outer Loop): 每一次外循环都像是新学期开学我们找来一个“新生”一个随机初始化的 net。这个“新生”的存在是为了确保我们的“教材”合成数据 image_syn是普适的对任何基础的学生都有效。 教材编写 (Update Synthetic Data): 这是核心步骤。我们让“新生” net 分别看“官方教材”真实数据 img_real和我们正在编写的“浓缩笔记”合成数据 img_syn。我们记录下“新生”看完两种材料后的“学习心得”梯度 gw_real 和 gw_syn。我们的目标是修改“浓缩笔记” img_syn使得“新生”看完它之后产生的“学习心得” gw_syn 和看完“官方教材”产生的 gw_real 一模一样。create_graphTrue 是实现这一点的技术关键它允许我们对“学习心得”本身求导从而知道该如何修改“浓y缩笔记”的每一个字像素。 学生自习 (Update Network): “浓缩笔记” image_syn 更新完毕后我们让“新生” net 对着这本新版的笔记自习几遍inner_loop次。这会让“新生”对当前的“浓缩笔记”有更深的理解为下一轮的“教材编写”做好准备。
这个“编写教材 - 学生自习 - 换个新生再来一遍”的过程不断重复最终使得“浓缩笔记” image_syn 变得越来越精华能够高效地替代“官方教材” T。