【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning

【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning

在这里插入图片描述

1 算法原理

论文:Graves, L., Nagisetty, V., & Ganesh, V. (2021). Amnesiac machine learning. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, 11516–11524.

Amnesiac Unlearning(遗忘性遗忘) 是一种高效且精确的算法,旨在从已经训练好的神经网络模型中删除特定数据的学习信息,而不会显著影响模型在其他数据上的性能。该算法的核心思想是通过选择性撤销与敏感数据相关的参数更新来实现数据的“遗忘”。

1. 训练阶段:记录参数更新

在模型训练过程中,记录每个批次的参数更新以及哪些批次包含敏感数据。

  • 步骤
    1. 初始化模型参数:从随机初始化的参数 θ i n i t i a l \theta_{initial} θinitial 开始训练模型。
    2. 训练模型:使用标准训练方法(如随机梯度下降)对模型进行训练,训练过程分为多个 epoch,每个 epoch 包含多个批次(batches)。
    3. 记录参数更新
      • 对于每个批次 b b b,记录该批次的参数更新 Δ θ e , b \Delta_{\theta_{e,b}} Δθe,b,其中 e e e 表示 epoch 编号, b b b 表示批次编号。
      • 同时,记录哪些批次包含敏感数据(即需要删除的数据)。可以将这些批次标记为 S B SB SB(Sensitive Batches)。
    4. 存储信息
      • 存储所有批次的参数更新 Δ θ e , b \Delta_{\theta_{e,b}} Δθe,b
      • 存储敏感数据批次的索引 S B SB SB

2. 数据删除阶段:选择性撤销参数更新

当收到数据删除请求时,撤销与敏感数据相关的参数更新。

  • 步骤

    1. 识别敏感数据批次:从存储的记录中提取包含敏感数据的批次索引 S B SB SB
    2. 撤销参数更新

    计算删除敏感数据后的模型参数 θ M \theta_{M} θM
    θ M ′ = θ M − ∑ s b ∈ S B Δ θ s b \theta_{M'} = \theta_{M} - \sum_{sb \in SB} \Delta_{\theta_{sb}} θM=θMsbSBΔθsb

    其中:

    • θ M \theta_{M} θM 是原始训练后的模型参数。
    • Δ θ s b \Delta_{\theta_{sb}} Δθsb 是敏感数据批次 s b sb sb 的参数更新。
    • 生成保护模型:使用更新后的参数 θ M ′ \theta_{M'} θM 作为新的模型参数。

3. 微调阶段(可选)

如果删除的批次较多,可能会对模型性能产生一定影响。此时可以通过少量微调来恢复模型性能。

  • 步骤
    1. 微调模型:使用删除敏感数据后的数据集对模型进行少量迭代训练。
    2. 恢复性能:通过微调,模型可以恢复在非敏感数据上的性能。

2 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from models.Base import load_MNIST_data, test_model, device, MLP, load_CIFAR100_data, init_model# AmnesiacForget类:封装撤销与敏感数据相关的参数更新
class AmnesiacForget:def __init__(self, model, all_data, epochs, learning_rate):"""初始化 AmnesiacForget 类。:param model: 需要训练的模型。:param all_data: 训练数据集。:param epochs: 训练的总 epoch 数。:param learning_rate: 优化器的学习率。"""self.model = modelself.all_data = all_dataself.epochs = epochsself.learning_rate = learning_rateself.batch_updates = []  # 存储每个批次的参数更新值self.initial_params = {name: param.clone() for name, param in model.named_parameters()}  # 存储初始模型参数self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 设备选择(GPU 或 CPU)def train(self, forgotten_classes):"""训练模型并记录每个批次的参数更新值。:param forgotten_classes: 需要遗忘的类别列表。:return: sensitive_batches: 包含敏感数据的批次索引。"""optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)  # 使用 Adam 优化器self.model.train()  # 将模型设置为训练模式sensitive_batches = {}  # 记录每个 epoch 中包含敏感数据的批次索引# 训练过程for epoch in range(self.epochs):running_loss = 0.0sensitive_batches[epoch] = set()  # 每个 epoch 的敏感批次集for batch_idx, (images, labels) in enumerate(self.all_data):optimizer.zero_grad()  # 清空梯度images, labels = images.to(self.device), labels.to(self.device)  # 将数据移动到设备上# 前向传播和损失计算outputs = self.model(images)loss = nn.CrossEntropyLoss()(outputs, labels)# 反向传播计算梯度loss.backward()running_loss += loss.item()# 记录当前参数值current_params = {name: param.clone() for name, param in self.model.named_parameters()}# 更新参数optimizer.step()# 记录参数更新值(当前参数值 - 更新前的参数值)batch_update = {}for name, param in self.model.named_parameters():if param.requires_grad:batch_update[name] = param.data - current_params[name].data  # 记录参数更新值self.batch_updates.append(batch_update)# 记录包含敏感数据的批次索引if any(label.item() in forgotten_classes for label in labels):sensitive_batches[epoch].add(batch_idx)print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {running_loss/len(self.all_data):.4f}")return sensitive_batchesdef unlearn(self, sensitive_batches):"""撤销与敏感数据相关的批次更新。:param sensitive_batches: 包含敏感数据的批次索引。:return: 更新后的模型。"""# 计算非敏感批次的参数更新总和non_sensitive_updates = {name: torch.zeros_like(param) for name, param in self.model.named_parameters()}for batch_idx, batch_update in enumerate(self.batch_updates):if batch_idx not in {sb for epoch_batches in sensitive_batches.values() for sb in epoch_batches}:for name, update in batch_update.items():non_sensitive_updates[name] += update# 更新模型参数:初始参数 + 非敏感批次的更新for name, param in self.model.named_parameters():param.data = self.initial_params[name].data + non_sensitive_updates[name]return self.model# 全局函数:实现 Amnesiac Forget
def amnesiac_unlearning(model_before, test_loader, forgotten_classes, all_data, epochs=10, learning_rate=0.001):"""执行 Amnesiac Unlearning:训练模型,记录参数更新,并撤销与敏感数据相关的更新。:param model_before: 遗忘前的模型。:param test_loader: 测试数据加载器。:param forgotten_classes: 需要遗忘的类别列表。:param all_data: 训练数据集。:param epochs: 训练的总 epoch 数(默认为 10)。:param learning_rate: 优化器的学习率(默认为 0.001)。:return: 遗忘后的模型。"""# 模拟从头训练的过程,并记录批次更新的过程print("模拟重新训练过程,记录批次更新...")temp_model = MLP().to(device)  # 初始化一个新模型amnesiac_forget = AmnesiacForget(temp_model, all_data, epochs, learning_rate)  # 初始化 AmnesiacForget 类sensitive_batches = amnesiac_forget.train(forgotten_classes)  # 训练模型并记录敏感批次# 测试遗忘前的模型性能overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(amnesiac_forget.model, test_loader)print(f"全部准确率: {overall_acc_before:.2f}%, 保留准确率: {retained_acc_before:.2f}%, 遗忘准确率: {forgotten_acc_before:.2f}%")# 应用遗忘:撤销与敏感数据相关的批次更新model_after = amnesiac_forget.unlearn(sensitive_batches)return model_afterdef main():# 超参数设置batch_size = 256forgotten_classes = [0]  # 需要遗忘的类别ratio = 1model_name = "ResNet18"  # 模型名称# 加载数据if model_name == "MLP":train_loader, test_loader, retain_loader, forget_loader = load_MNIST_data(batch_size, forgotten_classes, ratio)elif model_name == "ResNet18":train_loader, test_loader, retain_loader, forget_loader = load_CIFAR100_data(batch_size, forgotten_classes, ratio)# 初始化模型model_before = init_model(model_name, train_loader)# 在训练之前测试初始模型准确率overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(model_before, test_loader)# 实现遗忘操作print("执行遗忘 Amnesiac...")model_after = amnesiac_unlearning(model_before, test_loader, forgotten_classes, train_loader, epochs=5, learning_rate=0.001)# 测试遗忘后的模型overall_acc_after, forgotten_acc_after, retained_acc_after = test_model(model_after, test_loader)# 输出遗忘前后的准确率变化print(f"Unlearning 前遗忘准确率: {100 * forgotten_acc_before:.2f}%")print(f"Unlearning 后遗忘准确率: {100 * forgotten_acc_after:.2f}%")print(f"Unlearning 前保留准确率: {100 * retained_acc_before:.2f}%")print(f"Unlearning 后保留准确率: {100 * retained_acc_after:.2f}%")if __name__ == "__main__":main()

3 总结

  • 高效性:只需撤销与敏感数据相关的参数更新,避免了从头训练模型的高成本。
  • 精确性:能够精确删除特定数据的学习信息,特别适合删除少量数据。
  • 存储成本:需要存储每个批次的参数更新,存储成本较高,但通常低于从头训练模型的成本。
  • 适用场景:适合删除少量数据(如单个样本或少量样本),而不适合删除大量数据(如整个类别)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/8459.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

51单片机开发:点阵屏显示数字

实验目标:在8x8的点阵屏上显示数字0。 点阵屏的原理图如下图所示,点阵屏的列接在P0端口,行接在74HC595扩展的DP端口上。 扩展口的使用详见:51单片机开发:IO扩展(串转并)实验-CSDN博客 要让点阵屏显示数字&#xff0…

买卖股票的最佳时机 II

hello 大家好!今天开写一个新章节,每一天一道算法题。让我们一起来学习算法思维吧! 问题分析 本题要求计算在可以多次买卖股票(但任何时候最多只能持有一股股票,也可以在同一天买卖)的情况下能获得的最大…

2024年度总结——理想的风,吹进现实

2024年悄然过去,留下了太多美好的回忆,不得不感慨一声时间过得真快啊!旧年风雪尽,新岁星河明。写下这篇博客,记录我独一无二的2024年。这一年,理想的风终于吹进现实! 如果用一句话总结这一年&am…

LosslessScaling-学习版[steam价值30元的游戏无损放大/补帧工具]

LosslessScaling 链接:https://pan.xunlei.com/s/VOHc-yZBgwBOoqtdZAv114ZTA1?pwdxiih# 解压后运行"A-绿化-解压后运行我.cmd"

CVE-2020-0796永恒之蓝2.0(漏洞复现)

目录 前言 产生原因 影响范围 漏洞复现 复现环境 复现步骤 防御措施 总结 前言 在网络安全的战场上,漏洞一直是攻防双方关注的焦点。CVE-2020-0796,这个被称为 “永恒之蓝 2.0” 的漏洞,一度引起了广泛的关注与担忧。它究竟是怎样的…

计算机网络 (61)移动IP

前言 移动IP(Mobile IP)是由Internet工程任务小组(Internet Engineering Task Force,IETF)提出的一个协议,旨在解决移动设备在不同网络间切换时的通信问题,确保移动设备可以在离开原有网络或子网…

node 爬虫开发内存处理 zp_stoken 作为案例分析

声明: 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关! 前言 主要说3种我们补环境过后如果用…

基于Python的哔哩哔哩综合热门数据分析系统的设计与实现

【Django】基于大数据的哔哩哔哩综合热门数据分析系统的设计与实现(完整系统源码开发笔记详细部署教程)✅ 目录 一、项目简介二、项目界面展示三、项目视频展示 一、项目简介 该系统涵盖登录、热门数据展示、数据分析及数据管理等功能。通过大数据处理与…

Object类(2)

大家好,今天我们继续来看看Object类中一些成员方法,这些方法在实际中有很大的用处,话不多说,来看。 注:所有类都默认继承Object类的,所以可调用Object类中的方法,如equals,也可以发生…

C++封装红黑树实现mymap和myset和模拟实现详解

文章目录 map和set的封装map和set的底层 map和set的模拟实现insertiterator实现的思路operatoroperator- -operator[ ] map和set的封装 介绍map和set的底层实现 map和set的底层 一份模版实例化出key的rb_tree和pair<k,v>的rb_tree rb_tree的Key和Value不是我们之前传统意…

单片机基础模块学习——PCF8591芯片

一、A/D、D/A模块 A——Analog 模拟信号:连续变化的信号(很多传感器原始输出的信号都为此类信号)D——Digital 数字信号:只有高电平和低电平两种变化(单片机芯片、微控制芯片所能处理的都是数字信号) 下面是模拟信号和连续信号的区别 为什么需要进行模拟信号和数字信号之…

Blazor-Blazor Web App项目结构

让我们还是从创建项目开始&#xff0c;来一起了解下Blazor Web App的项目情况 创建项目 呈现方式 这里我们可以看到需要选择项目的呈现方式&#xff0c;有以上四种呈现方式 ● WebAssembly ● Server ● Auto(Server and WebAssembly) ● None 纯静态界面静态SSR呈现方式 WebAs…

自动驾驶中的多传感器时间同步

目录 前言 1.多传感器时间特点 2.统一时钟源 2.1 时钟源 2.2 PPSGPRMC 2.3 PTP 2.4 全域架构时间同步方案 3.时间戳误差 3.1 硬件同步 3.2 软件同步 3.2.3 其他方式 ① ROS 中的 message_filters 包 ② 双端队列 std::deque 参考&#xff1a; 前言 对多传感器数据…

神经网络|(一)加权平均法,感知机和神经元

【1】引言 从这篇文章开始&#xff0c;将记述对神经网络知识的探索。相关文章都是学习过程中的感悟和理解&#xff0c;如有雷同或者南辕北辙的表述&#xff0c;请大家多多包涵。 【2】加权平均法 在数学课本和数理统计课本中&#xff0c;我们总会遇到求一组数据平均值的做法…

算法题(48):反转链表

审题&#xff1a; 需要我们将链表反转并返回头结点地址 思路&#xff1a; 一般在面试中&#xff0c;涉及链表的题会主要考察链表的指向改变&#xff0c;所以一般不会允许我们改变节点val值。 这里是单向链表&#xff0c;如果要把指向反过来则需要同时知道前中后三个节点&#x…

DroneXtract:一款针对无人机的网络安全数字取证工具

关于DroneXtract DroneXtract是一款使用 Golang 开发的适用于DJI无人机的综合数字取证套件&#xff0c;该工具可用于分析无人机传感器值和遥测数据、可视化无人机飞行地图、审计威胁活动以及提取多种文件格式中的相关数据。 功能介绍 DroneXtract 具有四个用于无人机取证和审…

SpringBoot中Excel表的导入、导出功能的实现

文章目录 一、easyExcel简介二、Excel表的导出2.1 添加 Maven 依赖2.2 创建导出数据的实体类4. 编写导出接口5. 前端代码6. 实现效果 三、excel表的导出1. Excel表导入的整体流程1.1 配置文件存储路径 2. 前端实现2.1 文件上传组件 2.2 文件上传逻辑3. 后端实现3.1 文件上传接口…

C语言,无法正常释放char*的空间

问题描述 #include <stdio.h> #include <stdio.h>const int STRSIZR 10;int main() {char *str (char *)malloc(STRSIZR*sizeof(char));str "string";printf("%s\n", str);free(str); } 乍一看&#xff0c;这块代码没有什么问题。直接书写…

2025蓝桥杯JAVA编程题练习Day1

1.刑侦科推理试题 题目描述 有以下10道单选题&#xff0c;编程求这10道题的答案。 这道题的答案是&#xff1a; A. A B. B C. C D. D 第5题的答案是&#xff1a; A. C B. D C. A D. B 以下选项中哪一题的答案与其他三项不同&#xff1a; A. 第3题 B. 第6题 C. 第2题 D.…

图漾相机-ROS2-SDK-Ubuntu版本编译(新版本)

官网编译文档链接&#xff1a; https://doc.percipio.xyz/cam/latest/getstarted/sdk-ros2-compile.html 国内gitee下载SDK链接&#xff1a; https://gitee.com/percipioxyz 国外github下载SDK链接&#xff1a; https://github.com/percipioxyz 1.Camport ROS2 SDK 介绍 1.1 …