[Machine Unlearning Algorithms in Python] Reproducing Selective Forgetting, a CVPR 2020 Algorithm
1 Algorithm Principle

Golatkar, A., Achille, A., & Soatto, S. (2020). Eternal sunshine of the spotless net: Selective forgetting in deep networks. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 9304–9312.

The algorithm proposed in this paper is abbreviated SF (Selective Forgetting). The name captures its core goal: to selectively remove the information contributed by a specific subset of the training data from the weights of a deep neural network, without harming performance on the remaining data. The algorithm builds on a quadratic approximation of the loss and on the stability of gradient descent, and proceeds in the following steps:
1 Compute the Hessian matrices

- Compute the Hessian of the loss over the full dataset $\mathcal{D}$: $A = \nabla^2 L_{\mathcal{D}}(w)$.
- Compute the Hessian of the loss over the retained dataset $\mathcal{D}_r$: $B = \nabla^2 L_{\mathcal{D}_r}(w)$.
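For a network of realistic size the full Hessian ($p \times p$ for $p$ parameters) is intractable, which is why the implementation below works with approximations. As a minimal sketch of what $A$ denotes, the exact Hessian of a tiny least-squares loss can be computed with torch.autograd.functional.hessian; X, y, and loss_fn here are illustrative toy objects, not part of the paper's setup:

import torch

# Toy regression problem: 20 samples, 3 parameters.
X = torch.randn(20, 3)
y = torch.randn(20)

def loss_fn(w):
    # Mean squared error; its Hessian is the constant matrix (2/20) * X^T X.
    return ((X @ w - y) ** 2).mean()

w = torch.zeros(3)
A = torch.autograd.functional.hessian(loss_fn, w)  # shape (3, 3)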
2 Compute the gradient directions

- For the full dataset $\mathcal{D}$, compute $d = A^{-1} \nabla_w L_{\mathcal{D}}(w)$.
- For the retained dataset $\mathcal{D}_r$, compute $d_r = B^{-1} \nabla_w L_{\mathcal{D}_r}(w)$.
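Each of these is a Newton direction: the gradient preconditioned by the inverse Hessian. In practice one solves a linear system rather than forming the inverse explicitly. A sketch continuing the toy example from step 1 (A, w, and loss_fn as defined there):

grad = torch.autograd.functional.jacobian(loss_fn, w)  # gradient of L at w, shape (3,)
# Solve A d = grad instead of computing A^{-1} explicitly (cheaper and more stable)
d = torch.linalg.solve(A, grad)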
3 Construct the forgetting function

- Following the definition of the forgetting function, construct $h(w)$:

$h(w) = w + e^{-Bt} e^{At} d - e^{-Bt} (d - d_r) - d_r.$

Intuitively, $h$ maps the weights reached after time $t$ of training on $\mathcal{D}$ onto the weights that would have been reached by training on $\mathcal{D}_r$ from the same initialization; as $t \to \infty$, $e^{-Bt}$ vanishes and $h(w)$ reduces to the Newton step $w - d_r$ onto the minimum of the retained loss.
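On a toy quadratic problem $h(w)$ can be evaluated exactly with matrix exponentials. A minimal sketch with random symmetric positive definite matrices standing in for the two Hessians (all values illustrative; a small $t$ keeps $e^{At}$ numerically tame):

import torch

p, t = 3, 0.1
M1, M2 = torch.randn(p, p), torch.randn(p, p)
A = M1 @ M1.T + p * torch.eye(p)  # stand-in for the Hessian over D
B = M2 @ M2.T + p * torch.eye(p)  # stand-in for the Hessian over D_r
w = torch.randn(p)
d = torch.randn(p)    # stand-in for A^{-1} grad L_D(w)
d_r = torch.randn(p)  # stand-in for B^{-1} grad L_Dr(w)

exp_At = torch.linalg.matrix_exp(A * t)
exp_negBt = torch.linalg.matrix_exp(-B * t)
h_w = w + exp_negBt @ exp_At @ d - exp_negBt @ (d - d_r) - d_r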
4 Add noise

- Sample Gaussian noise $n \sim \mathcal{N}(0, \Sigma)$ with covariance $\Sigma = \sqrt{\lambda \sigma_h^2}\, B^{-1/2}$.
- Add the noise to the output of the forgetting function to obtain the final scrubbed weights:

$S(w) = h(w) + n.$
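Sampling $n \sim \mathcal{N}(0, \Sigma)$ reduces to transforming standard Gaussian noise: since $\Sigma = (\lambda \sigma_h^2)^{1/2} B^{-1/2}$, the sample $n = (\lambda \sigma_h^2)^{1/4} B^{-1/4} \varepsilon$ with $\varepsilon \sim \mathcal{N}(0, I)$ has exactly this covariance. A sketch continuing the toy example from step 3 (B, h_w, and p as defined there; the matrix power is formed by eigendecomposition):

lambda_param, sigma_h = 0.1, 0.1
eigvals, eigvecs = torch.linalg.eigh(B)  # valid since B is symmetric positive definite
B_inv_quarter = eigvecs @ torch.diag(eigvals ** -0.25) @ eigvecs.T
eps = torch.randn(p)
n = (lambda_param * sigma_h ** 2) ** 0.25 * (B_inv_quarter @ eps)
S_w = h_w + n  # scrubbed weights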
5 Update the weights

- Overwrite the network weights with $S(w)$, so that the network behaves as if it had never seen the forgotten dataset $\mathcal{D}_f$.
2 Code Implementation

Utility functions
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset, TensorDataset
from torch.amp import autocast, GradScaler
import numpy as np
import matplotlib.pyplot as plt
import os
import warnings
import random
from copy import deepcopy
random.seed(2024)
torch.manual_seed(2024)
np.random.seed(2024)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
warnings.filterwarnings("ignore")
MODEL_NAMES = "MLP"

# Select the compute device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A three-layer fully connected network
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
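A quick sanity check of the network's input and output shapes (a dummy batch, purely illustrative):

x = torch.randn(2, 1, 28, 28)  # two fake MNIST images
assert MLP()(x).shape == (2, 10)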
# Load MNIST and build loaders for the full, retained, and forgotten subsets.
# generate_subset_by_ratio is a project helper (defined elsewhere in models/Base)
# that extracts the samples of the given classes, optionally subsampled by ratio.
def load_MNIST_data(batch_size, forgotten_classes, ratio):
    transform = transforms.Compose([transforms.ToTensor()])
    train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    forgotten_train_data, _ = generate_subset_by_ratio(train_data, forgotten_classes, ratio)
    retain_train_data, _ = generate_subset_by_ratio(train_data, [i for i in range(10) if i not in forgotten_classes])
    forgotten_train_loader = DataLoader(forgotten_train_data, batch_size=batch_size, shuffle=True)
    retain_train_loader = DataLoader(retain_train_data, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader, retain_train_loader, forgotten_train_loader
# worker_init_fn seeds each DataLoader worker for reproducibility
def worker_init_fn(worker_id):
    random.seed(2024 + worker_id)
    np.random.seed(2024 + worker_id)

# Augmentation / normalization pipelines for 32x32 RGB inputs (CIFAR-style)
def get_transforms():
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalize to [-1, 1]
    ])
    test_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalize to [-1, 1]
    ])
    return train_transform, test_transform
# Model training function (one epoch), with optional mixed precision
def train_model(model, train_loader, criterion, optimizer, scheduler=None, use_fp16=False):
    scaler = GradScaler("cuda")  # gradient scaler for mixed-precision training
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        # Forward pass (autocast is a no-op when use_fp16 is False)
        with autocast(device_type="cuda", enabled=use_fp16):
            outputs = model(images)
            loss = criterion(outputs, labels)
        # Backward pass and optimizer step
        optimizer.zero_grad()
        if use_fp16:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        running_loss += loss.item()
    if scheduler is not None:
        # Step the learning-rate schedule once per epoch
        scheduler.step()
    print(f"Loss: {running_loss / len(train_loader):.4f}")
# Model evaluation (accuracy on retained and forgotten classes)
def test_model(model, test_loader, forgotten_classes=[0]):
    """Evaluate the model: overall, forgotten-class, and retained-class accuracy.

    :param model: model to evaluate
    :param test_loader: test data loader
    :param forgotten_classes: list of class labels to be forgotten
    :return: overall_accuracy, forgotten_accuracy, retained_accuracy
    """
    model.eval()
    correct = 0
    total = 0
    forgotten_correct = 0
    forgotten_total = 0
    retained_correct = 0
    retained_total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            # Overall accuracy
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # Accuracy on the forgotten classes
            mask_forgotten = torch.isin(labels, torch.tensor(forgotten_classes, device=device))
            forgotten_total += mask_forgotten.sum().item()
            forgotten_correct += (predicted[mask_forgotten] == labels[mask_forgotten]).sum().item()
            # Accuracy on the retained classes (all other classes)
            mask_retained = ~mask_forgotten
            retained_total += mask_retained.sum().item()
            retained_correct += (predicted[mask_retained] == labels[mask_retained]).sum().item()
    overall_accuracy = correct / total
    forgotten_accuracy = forgotten_correct / forgotten_total if forgotten_total > 0 else 0
    retained_accuracy = retained_correct / retained_total if retained_total > 0 else 0
    return round(overall_accuracy, 4), round(forgotten_accuracy, 4), round(retained_accuracy, 4)
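The main program below also imports init_model from models.Base, which is not reproduced in this post. As a hypothetical sketch only, assuming it simply builds the requested architecture and pretrains it with the utilities above (the actual project version may differ):

def init_model(model_name, train_loader, epochs=5):
    # Hypothetical stand-in for models.Base.init_model: build and pretrain a model.
    if model_name == "MLP":
        model = MLP().to(device)
    elif model_name == "ResNet18":
        model = models.resnet18(num_classes=100).to(device)  # CIFAR-100 setting
    else:
        raise ValueError(f"unknown model: {model_name}")
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    for _ in range(epochs):
        train_model(model, train_loader, criterion, optimizer)
    return model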
Main function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from models.Base import test_model, load_MNIST_data, load_CIFAR100_data, init_model

# Machine-unlearning class
class OptimalQuadraticForgetter:
    def __init__(self, model, lambda_param, sigma_h, forget_threshold):
        self.model = model
        self.lambda_param = lambda_param
        self.sigma_h = sigma_h
        self.forget_threshold = forget_threshold  # forgetting threshold (kept for the interface; unused below)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.hessian_D = None   # diagonal Hessian approximation over the full dataset
        self.hessian_Dr = None  # diagonal Hessian approximation over the retained dataset
        self.grad_D = None      # gradient over the full dataset
        self.grad_Dr = None     # gradient over the retained dataset

    # Accumulate per-parameter gradients and a diagonal Hessian approximation.
    # The exact Hessian is intractable for a full network, so its diagonal is
    # approximated by the empirical Fisher information (mean squared gradient),
    # in the spirit of the Fisher-based variant discussed in the paper.
    def compute_hessian_and_grad(self, dataloader_D, dataloader_Dr):
        self.model.eval()
        self.hessian_D, self.grad_D = self._accumulate(dataloader_D)
        self.hessian_Dr, self.grad_Dr = self._accumulate(dataloader_Dr)

    def _accumulate(self, dataloader):
        hessian = {name: torch.zeros_like(p) for name, p in self.model.named_parameters() if p.requires_grad}
        grads = {name: torch.zeros_like(p) for name, p in self.model.named_parameters() if p.requires_grad}
        num_batches = 0
        for data, target in dataloader:
            data, target = data.to(self.device), target.to(self.device)
            self.model.zero_grad()
            output = self.model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    grads[name] += param.grad
                    hessian[name] += param.grad ** 2  # empirical Fisher diagonal
            num_batches += 1
        # Average over batches (per-batch losses are already means over samples)
        for name in grads:
            grads[name] /= num_batches
            hessian[name] /= num_batches
        return hessian, grads

    # Selective forgetting: shift each parameter by a (diagonal) Newton step
    # toward the retained-data solution, then add the scrubbing noise.
    def scrub_weights(self):
        self.model.train()
        eps = 1e-8  # damping term to keep the elementwise division stable
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in self.hessian_Dr and param.requires_grad:
                    hessian_Dr = self.hessian_Dr[name]
                    grad_D = self.grad_D[name]
                    grad_Dr = self.grad_Dr[name]
                    # Forgetting function h(w): with a diagonal Hessian the Newton
                    # step B^{-1}(grad_Dr - grad_D) becomes an elementwise division
                    h_w = param.data + (grad_Dr - grad_D) / (hessian_Dr + eps)
                    # Isotropic scrubbing noise with scale (lambda * sigma_h^2)^(1/4);
                    # the B^(-1/4) factor from the paper is dropped here for simplicity
                    noise_std = (self.lambda_param * self.sigma_h ** 2) ** 0.25
                    noise = torch.normal(mean=0.0, std=noise_std, size=param.data.shape, device=self.device)
                    param.data = h_w + noise  # overwrite the weights
        return self.model

# Top-level function: perform selective forgetting
def optimal_quadratic_forgetting(model, dataloader_D, dataloader_Dr, lambda_param, sigma_h, forget_threshold):
    # Create the forgetter
    forgetter = OptimalQuadraticForgetter(model, lambda_param, sigma_h, forget_threshold)
    # Estimate Hessians (diagonal approximation) and gradients
    forgetter.compute_hessian_and_grad(dataloader_D, dataloader_Dr)
    # Scrub the weights
    model_after = forgetter.scrub_weights()
    return model_after

def main():
    # Hyperparameters
    batch_size = 256
    forgotten_classes = [0]
    ratio = 1
    model_name = "ResNet18"
    # Load data
    if model_name == "MLP":
        train_loader, test_loader, retain_loader, forget_loader = load_MNIST_data(batch_size, forgotten_classes, ratio)
    elif model_name == "ResNet18":
        train_loader, test_loader, retain_loader, forget_loader = load_CIFAR100_data(batch_size, forgotten_classes, ratio)
    model_before = init_model(model_name, train_loader)
    # Accuracy of the trained model before unlearning
    overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(model_before, test_loader, forgotten_classes)
    # Run the forgetting procedure
    print("Running optimal_quadratic_forgetting...")
    model_after = optimal_quadratic_forgetting(model_before, train_loader, retain_loader, lambda_param=0.1, sigma_h=0.1, forget_threshold=1e-05)
    # Evaluate the model after unlearning
    overall_acc_after, forgotten_acc_after, retained_acc_after = test_model(model_after, test_loader, forgotten_classes)
    # Report the accuracy change on forgotten and retained classes
    print(f"Forgotten-class accuracy before unlearning: {100 * forgotten_acc_before:.2f}%")
    print(f"Forgotten-class accuracy after unlearning: {100 * forgotten_acc_after:.2f}%")
    print(f"Retained-class accuracy before unlearning: {100 * retained_acc_before:.2f}%")
    print(f"Retained-class accuracy after unlearning: {100 * retained_acc_after:.2f}%")

if __name__ == "__main__":
    main()
3 Summary

Noise perturbation is one of the core mechanisms of forgetting, but the injected noise can also hurt overall model performance, especially on the data that must be retained, so its magnitude needs careful tuning: noise that is too large degrades the model, while noise that is too small may fail to remove the target data's information.
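One simple way to navigate this trade-off is to sweep the noise scale and keep the strongest scrubbing whose retained accuracy stays acceptable. A hedged sketch reusing the functions defined above (the candidate values and the 2-point tolerance are illustrative choices, not from the paper):

# Sweep sigma_h; keep the largest noise whose retained accuracy stays close to baseline.
best = None
for sigma_h in [0.01, 0.05, 0.1, 0.5]:
    candidate = optimal_quadratic_forgetting(
        deepcopy(model_before), train_loader, retain_loader,
        lambda_param=0.1, sigma_h=sigma_h, forget_threshold=1e-05)
    _, forgotten_acc, retained_acc = test_model(candidate, test_loader, forgotten_classes=[0])
    # Require retained accuracy within 2 points of the pre-unlearning model
    if retained_acc >= retained_acc_before - 0.02:
        best = (sigma_h, forgotten_acc, retained_acc)
print("selected:", best)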