【Python实现机器遗忘算法】复现2020年顶会CVPR算法Selective Forgetting

【Python实现机器遗忘算法】复现2020年顶会CVPR算法Selective Forgetting

请添加图片描述

1 算法原理

  • Golatkar, A., Achille, A., & Soatto, S. (2020). Eternal sunshine of the spotless net: Selective forgetting in deep networks. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 9304–9312.

  • 本文提出的算法简称为 SF(Selective Forgetting)算法,即选择性遗忘算法。这个名称来源于算法的核心目标:从深度神经网络的权重中选择性地移除特定数据子集的信息,而不影响其他数据的性能。该算法基于二次损失假设和梯度下降的稳定性,具体步骤如下:

  • 1 计算Hessian矩阵

    • 计算整个数据集 D \mathcal{D} D 的Hessian矩阵 A = ∇ 2 L D ( w ) A = \nabla^2 L_{\mathcal{D}}(w) A=2LD(w)
    • 计算保留数据集 D r \mathcal{D}_r Dr 的Hessian矩阵 B = ∇ 2 L D r ( w ) B = \nabla^2 L_{\mathcal{D}_r}(w) B=2LDr(w)
  • 2 计算梯度方向

    • 计算整个数据集 D \mathcal{D} D 的梯度方向 d = A − 1 ∇ w L D ( w ) d = A^{-1} \nabla_w L_{\mathcal{D}}(w) d=A1wLD(w)
    • 计算保留数据集 D r \mathcal{D}_r Dr 的梯度方向 d r = B − 1 ∇ w L D r ( w ) d_r = B^{-1} \nabla_w L_{\mathcal{D}_r}(w) dr=B1wLDr(w)
  • 3 构造遗忘函数

    • 根据遗忘函数的定义,构造 h ( w ) h(w) h(w)
      h ( w ) = w + e − B t e A t d + e − B t ( d − d r ) − d r . h(w) = w + e^{-Bt} e^{At} d + e^{-Bt} (d - d_r) - d_r. h(w)=w+eBteAtd+eBt(ddr)dr.
  • 4 添加噪声

    • 生成高斯噪声 n ∼ N ( 0 , Σ ) n \sim \mathcal{N}(0, \Sigma) nN(0,Σ),其中 Σ = λ σ h 2 B − 1 / 2 \Sigma = \sqrt{\lambda \sigma_h^2} B^{-1/2} Σ=λσh2 B1/2
    • 将噪声添加到遗忘函数中,得到最终的权重更新:
      S ( w ) = h ( w ) + n . S(w) = h(w) + n. S(w)=h(w)+n.
  • 5 更新权重

    • 使用 S ( w ) S(w) S(w) 更新网络权重,确保网络表现得像是从未见过 D f \mathcal{D}_f Df

2 代码实现

工具函数

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset,TensorDataset
from torch.amp import autocast, GradScaler  
import numpy as np
import matplotlib.pyplot as plt
import os
import warnings
import random
from copy import deepcopy
random.seed(2024)
torch.manual_seed(2024)
np.random.seed(2024)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = Falsewarnings.filterwarnings("ignore")
MODEL_NAMES = "MLP"
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义三层全连接网络
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(28 * 28, 256)self.fc2 = nn.Linear(256, 128)self.fc3 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 加载MNIST数据集
def load_MNIST_data(batch_size,forgotten_classes,ratio):transform = transforms.Compose([transforms.ToTensor()])train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)forgotten_train_data,_ = generate_subset_by_ratio(train_data, forgotten_classes,ratio)retain_train_data,_ = generate_subset_by_ratio(train_data, [i for i in range(10) if i not in forgotten_classes])forgotten_train_loader= DataLoader(forgotten_train_data, batch_size=batch_size, shuffle=True)retain_train_loader= DataLoader(retain_train_data, batch_size=batch_size, shuffle=True)return train_loader, test_loader, retain_train_loader, forgotten_train_loader# worker_init_fn 用于初始化每个 worker 的随机种子
def worker_init_fn(worker_id):random.seed(2024 + worker_id)np.random.seed(2024 + worker_id)
def get_transforms():train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 标准化为[-1, 1]])test_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 标准化为[-1, 1]])return train_transform, test_transform
# 模型训练函数
def train_model(model, train_loader, criterion, optimizer, scheduler=None,use_fp16 = False):use_fp16 = True# 使用新的初始化方式:torch.amp.GradScaler("cuda")scaler = GradScaler("cuda")  # 用于混合精度训练model.train()running_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)# 前向传播with autocast(enabled=use_fp16, device_type="cuda"):  # 更新为使用 "cuda"outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()if use_fp16:scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()else:loss.backward()optimizer.step()running_loss += loss.item()if scheduler is not None:# 更新学习率scheduler.step()print(f"Loss: {running_loss/len(train_loader):.4f}")
# 模型评估(计算保留和遗忘类别的准确率)
def test_model(model, test_loader, forgotten_classes=[0]):"""测试模型的性能,计算总准确率、遗忘类别准确率和保留类别准确率。:param model: 要测试的模型:param test_loader: 测试数据加载器:param forgotten_classes: 需要遗忘的类别列表:return: overall_accuracy, forgotten_accuracy, retained_accuracy"""model.eval()correct = 0total = 0forgotten_correct = 0forgotten_total = 0retained_correct = 0retained_total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)# 计算总的准确率total += labels.size(0)correct += (predicted == labels).sum().item()# 计算遗忘类别的准确率mask_forgotten = torch.isin(labels, torch.tensor(forgotten_classes, device=device))forgotten_total += mask_forgotten.sum().item()forgotten_correct += (predicted[mask_forgotten] == labels[mask_forgotten]).sum().item()# 计算保留类别的准确率(除遗忘类别的其他类别)mask_retained = ~mask_forgottenretained_total += mask_retained.sum().item()retained_correct += (predicted[mask_retained] == labels[mask_retained]).sum().item()overall_accuracy = correct / totalforgotten_accuracy = forgotten_correct / forgotten_total if forgotten_total > 0 else 0retained_accuracy = retained_correct / retained_total if retained_total > 0 else 0# return overall_accuracy, forgotten_accuracy, retained_accuracyreturn  round(overall_accuracy, 4), round(forgotten_accuracy, 4), round(retained_accuracy, 4)

主函数

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from models.Base import test_model, load_MNIST_data, load_CIFAR100_data, init_model# 定义机器遗忘类
class OptimalQuadraticForgetter:def __init__(self, model, lambda_param, sigma_h, forget_threshold):self.model = modelself.lambda_param = lambda_paramself.sigma_h = sigma_hself.forget_threshold = forget_threshold  # 设置遗忘阈值self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.hessian_D = None  # 整个数据集的Hessian矩阵self.hessian_Dr = None  # 保留数据集的Hessian矩阵self.grad_D = None  # 整个数据集的梯度self.grad_Dr = None  # 保留数据集的梯度# 计算Hessian矩阵和梯度def compute_hessian_and_grad(self, dataloader_D, dataloader_Dr):self.model.eval()self.hessian_D = {}self.hessian_Dr = {}self.grad_D = {}self.grad_Dr = {}# 计算整个数据集的Hessian和梯度for data, target in dataloader_D:data, target = data.to(self.device), target.to(self.device)self.model.zero_grad()output = self.model(data)loss = F.cross_entropy(output, target)loss.backward(create_graph=True)  # 需要二阶导数for name, param in self.model.named_parameters():if param.grad is not None:if name not in self.hessian_D:self.hessian_D[name] = torch.zeros_like(param.grad)self.grad_D[name] = torch.zeros_like(param.grad)self.hessian_D[name] += torch.autograd.grad(loss, param, create_graph=True)[0]self.grad_D[name] += param.grad# 计算保留数据集的Hessian和梯度for data, target in dataloader_Dr:data, target = data.to(self.device), target.to(self.device)self.model.zero_grad()output = self.model(data)loss = F.cross_entropy(output, target)loss.backward(create_graph=True)  # 需要二阶导数for name, param in self.model.named_parameters():if param.grad is not None:if name not in self.hessian_Dr:self.hessian_Dr[name] = torch.zeros_like(param.grad)self.grad_Dr[name] = torch.zeros_like(param.grad)self.hessian_Dr[name] += torch.autograd.grad(loss, param, create_graph=True)[0]self.grad_Dr[name] += param.grad# 平均化Hessian和梯度for key in self.hessian_D:self.hessian_D[key] /= len(dataloader_D.dataset)self.grad_D[key] /= len(dataloader_D.dataset)for key in self.hessian_Dr:self.hessian_Dr[key] /= len(dataloader_Dr.dataset)self.grad_Dr[key] /= len(dataloader_Dr.dataset)# 执行选择性遗忘操作:根据Hessian和梯度调整参数def scrub_weights(self):self.model.train()for name, param in self.model.named_parameters():if name in self.hessian_D and param.requires_grad:# 计算牛顿更新方向hessian_D = self.hessian_D[name]hessian_Dr = self.hessian_Dr[name]grad_D = self.grad_D[name]grad_Dr = self.grad_Dr[name]# 计算遗忘函数 h(w)h_w = param.data + torch.matmul(torch.inverse(hessian_Dr), grad_Dr - grad_D)# 添加噪声noise_std = (self.lambda_param * self.sigma_h**2)**0.25noise = torch.normal(mean=0, std=noise_std, size=param.data.shape, device=self.device)param.data = h_w + noise  # 更新权重return self.model# 全局函数:进行选择性遗忘
def optimal_quadratic_forgetting(model, dataloader_D, dataloader_Dr, lambda_param, sigma_h, forget_threshold):# 创建 Forgetter 对象forgetter = OptimalQuadraticForgetter(model, lambda_param, sigma_h, forget_threshold)# 计算Hessian矩阵和梯度forgetter.compute_hessian_and_grad(dataloader_D, dataloader_Dr)# 执行权重擦除操作model_after = forgetter.scrub_weights()return model_afterdef main():# 超参数设置batch_size = 256forgotten_classes = [0]ratio = 1model_name = "ResNet18"# 加载数据if model_name == "MLP":train_loader, test_loader, retain_loader, forget_loader = load_MNIST_data(batch_size, forgotten_classes, ratio)elif model_name == "ResNet18":train_loader, test_loader, retain_loader, forget_loader = load_CIFAR100_data(batch_size, forgotten_classes, ratio)model_before = init_model(model_name, train_loader)# 在训练之前测试初始模型准确率overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(model_before, test_loader)# 实现遗忘操作print("执行遗忘 optimal_quadratic_forgetting...")model_after = optimal_quadratic_forgetting(model_before, train_loader, retain_loader, lambda_param=0.1, sigma_h=0.1, forget_threshold=1e-05)# 测试遗忘后的模型overall_acc_after, forgotten_acc_after, retained_acc_after = test_model(model_after, test_loader)# 输出遗忘前后的准确率变化print(f"Unlearning前遗忘准确率: {100 * forgotten_acc_before:.2f}%")print(f"Unlearning后遗忘准确率: {100 * forgotten_acc_after:.2f}%")print(f"Unlearning前保留准确率: {100 * retained_acc_before:.2f}%")print(f"Unlearning后保留准确率: {100 * retained_acc_after:.2f}%")if __name__ == "__main__":main()

3 总结

噪声扰动是遗忘的核心机制之一,但噪声的引入可能会对模型的整体性能产生负面影响,尤其是在需要保留的数据上。需要仔细调整。过大的噪声会导致模型性能下降,而过小的噪声则可能无法有效移除目标数据的信息。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/8476.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux(NTP配置)

后面也会持续更新,学到新东西会在其中补充。 建议按顺序食用,欢迎批评或者交流! 缺什么东西欢迎评论!我都会及时修改的! NTP环境搭建 服务端客户端192.168.111.10192.168.111.11Linux MySQL5.7 3.10.0-1160.el7.x86_…

神经网络|(四)概率论基础知识-古典概型

【1】引言 前序学习了线性回归的基础知识,了解到最小二乘法可以做线性回归分析,但为何最小二乘法如此准确,这需要从概率论的角度给出依据。 因此从本文起,需要花一段时间来回顾概率论的基础知识。 【2】古典概型 古典概型是我…

21款炫酷烟花合集

系列专栏 《Python趣味编程》《C/C趣味编程》《HTML趣味编程》《Java趣味编程》 写在前面 Python、C/C、HTML、Java等4种语言实现18款炫酷烟花的代码。 Python Python烟花① 完整代码:Python动漫烟花(完整代码) ​ Python烟花② 完整…

新项目上传gitlab

Git global setup git config --global user.name “FUFANGYU” git config --global user.email “fyfucnic.cn” Create a new repository git clone gitgit.dev.arp.cn:casDs/sawrd.git cd sawrd touch README.md git add README.md git commit -m “add README” git push…

Brightness Controller-源码记录

Brightness Controller 亮度控制 一、概述二、ddcutil 与 xrandr1. ddcutil2. xrandr 三、部分代码解析1. icons2. ui3. utilinit.py 一、概述 项目:https://github.com/SunStorm2018/Brightness.git 原理:Brightness Controlle 是我在 Ubuntu 发现上调…

机器学习-K近邻算法

文章目录 一. 数据集介绍Iris plants dataset 二. 代码三. k值的选择 一. 数据集介绍 鸢尾花数据集 鸢尾花Iris Dataset数据集是机器学习领域经典数据集,鸢尾花数据集包含了150条鸢尾花信息,每50条取自三个鸢尾花中之一:Versicolour、Setosa…

Day27-【13003】短文,线性表两种基本实现方式空间效率、时间效率比较?兼顾优点的静态链表是什么?如何融入空闲单元链表来解决问题?

文章目录 本次内容总览第四节,两种基本实现方式概览两种基本实现方式的比较元素个数n大于多少时,使用顺序表存储的空间效率才会更高?时间效率比较?*、访问操作,也就是读运算,读操作1、插入,2、删…

JavaSE第十一天——集合框架Collection

一、List接口 List接口是一个有序的集合,允许元素有重复,它继承了Collection接口,提供了许多额外的功能,比如基于索引的插入、删除和访问元素等。 常见的List接口的实现类有ArrayList、LinkedList和Vector。 List接口的实现类 …

数据结构与算法学习笔记----求组合数

数据结构与算法学习笔记----求组合数 author: 明月清了个风 first publish time: 2025.1.27 ps⭐️一组求组合数的模版题,因为数据范围的不同要用不同的方法进行求解,涉及了很多之前的东西快速幂,逆元,质数,高精度等…

kaggle社区LLM Classification Finetuning

之前有个一样的比赛,没去参加,现在弄了一个无限期的比赛出来 训练代码链接:fine_tune | Kaggle 推理代码链接:https://www.kaggle.com/code/linheshen/inference-llama-3-8b?scriptVersionId219332972 包链接:pack…

【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning

【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning 1 算法原理 论文:Graves, L., Nagisetty, V., & Ganesh, V. (2021). Amnesiac machine learning. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, 115…

51单片机开发:点阵屏显示数字

实验目标:在8x8的点阵屏上显示数字0。 点阵屏的原理图如下图所示,点阵屏的列接在P0端口,行接在74HC595扩展的DP端口上。 扩展口的使用详见:51单片机开发:IO扩展(串转并)实验-CSDN博客 要让点阵屏显示数字&#xff0…

买卖股票的最佳时机 II

hello 大家好!今天开写一个新章节,每一天一道算法题。让我们一起来学习算法思维吧! 问题分析 本题要求计算在可以多次买卖股票(但任何时候最多只能持有一股股票,也可以在同一天买卖)的情况下能获得的最大…

2024年度总结——理想的风,吹进现实

2024年悄然过去,留下了太多美好的回忆,不得不感慨一声时间过得真快啊!旧年风雪尽,新岁星河明。写下这篇博客,记录我独一无二的2024年。这一年,理想的风终于吹进现实! 如果用一句话总结这一年&am…

LosslessScaling-学习版[steam价值30元的游戏无损放大/补帧工具]

LosslessScaling 链接:https://pan.xunlei.com/s/VOHc-yZBgwBOoqtdZAv114ZTA1?pwdxiih# 解压后运行"A-绿化-解压后运行我.cmd"

CVE-2020-0796永恒之蓝2.0(漏洞复现)

目录 前言 产生原因 影响范围 漏洞复现 复现环境 复现步骤 防御措施 总结 前言 在网络安全的战场上,漏洞一直是攻防双方关注的焦点。CVE-2020-0796,这个被称为 “永恒之蓝 2.0” 的漏洞,一度引起了广泛的关注与担忧。它究竟是怎样的…

计算机网络 (61)移动IP

前言 移动IP(Mobile IP)是由Internet工程任务小组(Internet Engineering Task Force,IETF)提出的一个协议,旨在解决移动设备在不同网络间切换时的通信问题,确保移动设备可以在离开原有网络或子网…

node 爬虫开发内存处理 zp_stoken 作为案例分析

声明: 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关! 前言 主要说3种我们补环境过后如果用…

基于Python的哔哩哔哩综合热门数据分析系统的设计与实现

【Django】基于大数据的哔哩哔哩综合热门数据分析系统的设计与实现(完整系统源码开发笔记详细部署教程)✅ 目录 一、项目简介二、项目界面展示三、项目视频展示 一、项目简介 该系统涵盖登录、热门数据展示、数据分析及数据管理等功能。通过大数据处理与…

Object类(2)

大家好,今天我们继续来看看Object类中一些成员方法,这些方法在实际中有很大的用处,话不多说,来看。 注:所有类都默认继承Object类的,所以可调用Object类中的方法,如equals,也可以发生…