【机器学习实战】kaggle 欺诈检测---使用生成对抗网络(GAN)解决欺诈数据中正负样本极度不平衡问题

【机器学习实战】kaggle 欺诈检测---如何解决欺诈数据中正负样本极度不平衡问题icon-default.png?t=O83Ahttps://blog.csdn.net/2302_79308082/article/details/145177242

本篇文章是基于上次文章中提到的对抗生成网络,通过对抗生成网络生成少数类样本,平衡欺诈数据中正类样本极少的问题。

本人主页:机器学习小小白

机器学习专栏:机器学习实战

PyTorch入门专栏:PyTorch入门

深度学习实战:深度学习

ok,话不多说,我们进入正题吧

1. 引言

生成对抗网络(Generative Adversarial Networks,简称GAN)是由Ian Goodfellow等人于2014年提出的一种深度学习模型。它在计算机视觉、自然语言处理、音频生成等领域得到了广泛应用。GAN的核心思想是通过两个神经网络之间的博弈关系来生成新的、仿真的数据。自从GAN提出以来,它已经成为生成模型领域的突破性进展,深刻改变了生成式模型的研究和应用。

2. GAN的基本原理

生成对抗网络的结构包括两个主要部分:生成器(Generator)和判别器(Discriminator)。这两个网络分别充当“对手”,并在训练过程中互相博弈:

  • 生成器(Generator):该网络的目的是通过学习数据分布来生成尽可能接近真实数据的虚假样本。生成器从一个随机的噪声(通常是高维的向量)出发,逐步生成样本。

  • 判别器(Discriminator):该网络的任务是判断一个样本是真实的(来自训练数据)还是虚假的(来自生成器)。判别器输出一个概率值,表示输入样本为真实数据的概率。

3. GAN的训练过程

GAN的训练过程是一个“博弈”过程,生成器和判别器不断互相对抗,从而提升各自的性能。这个过程可以通过以下的数学公式来表示:

  • 判别器的目标:判别器的目标是最大化其对于真实数据的判断概率(即预测为1的概率),同时最小化对生成数据的错误分类(即预测为0的概率)。可以通过以下的交叉熵损失函数表示:

       L_D = - \mathbb{E}_{x \sim p_{\text{data}}} \left[ \log D(x) \right] - \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right]

其中:

  • $x \sim p_{\text{data}}$​ 是从真实数据分布中采样的数据。

  • $G(z)$是生成器生成的样本,$z$是从潜在空间中采样的噪声。

  • $D(x)$ 是判别器对样本 $x$的判别输出,表示其为真实数据的概率。

  • 生成器的目标:生成器的目标是使判别器无法区分生成数据与真实数据,因此它通过最大化判别器对生成数据为真实的概率来进行训练:

  • L_G = - \mathbb{E}_{z \sim p_z} \left[ \log D(G(z)) \right]

  • 其中:G(z) 是生成器生成的虚假样本,D(G(z))是判别器对生成样本的输出,表示其为真实数据的概率。

在训练过程中,生成器和判别器会交替优化这两个损失函数。理想的结果是生成器能够生成与真实数据分布相似的样本,而判别器则无法有效地区分生成数据与真实数据。

4. GAN的应用

GAN具有强大的生成能力,广泛应用于多个领域,以下是一些典型的应用场景:

  • 图像生成:GAN可以用于生成高度逼真的图像,如人脸、风景或艺术作品。典型的例子包括DeepArt和StyleGAN,后者能够生成几乎无法与真实人脸区分的图像。

  • 图像到图像的转换:例如,利用GAN进行图像风格转换(如将照片转化为油画风格)、超分辨率重建(如提高图像的分辨率)、图像修复(如填补丢失部分)等任务。

  • 文本生成:结合自然语言处理技术,GAN也可用于生成文本数据,如诗歌、故事生成等,尤其是文本生成和对话系统中的对抗训练。

  • 音频生成:GAN被广泛应用于音频生成,如音乐生成、语音合成等。

  • 数据增强:GAN可以用于数据增强,特别是在医疗图像领域,生成具有一定变异的图像样本,以增强训练数据集。

  • 模型训练中的对抗样本生成:GAN可以生成对抗样本,即通过对训练数据进行微小扰动,生成能够误导模型的样本,这对提升模型的鲁棒性非常重要。

5. GAN的变种

GAN作为一种框架,已经发展出了多种变种,以满足不同应用的需求。以下是几种常见的GAN变种:

  • CGAN(Conditional GAN):在生成器和判别器中都加入了条件变量,使得生成的样本可以根据某些条件(如标签信息)进行控制。

  • WGAN(Wasserstein GAN):解决了传统GAN在训练过程中可能出现的梯度消失和模式崩溃问题。WGAN使用了Wasserstein距离作为生成器和判别器的损失函数。

  • DCGAN(Deep Convolutional GAN):使用卷积神经网络(CNN)来构建生成器和判别器,增强了GAN在图像生成任务中的表现。

  • CycleGAN:用于无监督学习场景,特别是在图像到图像的转换中,例如将一张照片转换成另一种风格(如马到斑马转换)。

6. 使用生成对抗网络(GAN)生成欺诈数据中少数类数据

1. 数据预处理与特征提取

import pandas as pd
import numpy as nptrain_df = pd.read_csv('/kaggle/input/credit-card-fraud-prediction/train.csv')
test_df = pd.read_csv('/kaggle/input/credit-card-fraud-prediction/test.csv')def time_feature(df):df['Time'] = pd.to_datetime(df['Time'], unit='s')  # 将时间戳转为 datetime 格式# 提取时间特征df['hour'] = df['Time'].dt.hourdf['minute'] = df['Time'].dt.minute return df train_df = time_feature(train_df)
test_df = time_feature(test_df)

在欺诈检测任务中,时间特征(如交易发生的小时和分钟)通常是重要的,因为欺诈交易往往具有不同的时间模式。例如,欺诈交易可能集中在某些特定的时间段。

  • 这里我们通过pd.to_datetime()Time列从Unix时间戳格式转换为日期时间格式。然后,我们提取了小时和分钟作为新的特征,用于训练模型。
train_feature = train_df.drop(columns=['id','IsFraud','Time'])
test_feature = test_df.drop(columns=['id','Time'])label = train_df['IsFraud']

train_feature 是用于训练的特征数据,删除了 id, IsFraudTime 列。IsFraud 是标签列,表示交易是否为欺诈交易;而 idTime 列不包含有用的特征信息,因此可以去掉。

2. 标准化数据

from sklearn.preprocessing import StandardScaler# 标准化特征数据
scaler = StandardScaler()
train_feature_scaled = scaler.fit_transform(train_feature)
  • 标准化(Standardization)是机器学习中常用的预处理步骤。它通过减去均值并除以标准差,使特征数据具有零均值和单位方差。标准化能够加速模型的收敛过程,尤其是在使用像神经网络这样的梯度优化模型时。

  • 这里使用 StandardScaler 来对训练数据进行标准化,以确保所有特征在同一个量级。

3. 生成器与判别器的构建

生成器(Generator)

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Inputdef build_generator(latent_dim, input_dim):model = Sequential()model.add(Input(shape=(latent_dim,)))  # 使用 Input 层来指定输入维度model.add(Dense(256))model.add(LeakyReLU(0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(input_dim, activation='tanh'))  # 输出层与原数据同维度return model

生成器(Generator)是GAN的核心部分,它通过接收随机噪声向量(潜在空间中的点),然后经过一系列的全连接层和激活函数,生成与原始数据分布相似的虚假数据。

  • 在此,我们使用了 LeakyReLU 激活函数,它允许梯度通过负半轴流动,解决了传统ReLU可能出现的“死神经元”问题。BatchNormalization 用于加速网络的训练,并帮助改善模型的稳定性。

判别器(Discriminator)

def build_discriminator(input_dim):model = Sequential()model.add(Input(shape=(input_dim,)))  # 使用 Input 层来指定输入维度model.add(Dense(1024))model.add(LeakyReLU(0.2))model.add(Dense(512))model.add(LeakyReLU(0.2))model.add(Dense(256))model.add(LeakyReLU(0.2))model.add(Dense(1, activation='sigmoid'))  # 输出真假判定return model

判别器(Discriminator)的任务是判断输入数据是真实的还是由生成器生成的。它是一个二分类模型,输出是一个概率值,表示输入数据为真实的概率。

  • 这里使用 sigmoid 激活函数,输出一个概率值。判别器学习将真实数据和生成数据区分开来。

4. GAN模型的组合与训练

def build_gan(generator, discriminator):discriminator.trainable = False  # 在训练GAN时冻结判别器model = Sequential()model.add(generator)model.add(discriminator)return model# 定义优化器
optimizer = Adam()# 定义输入维度和潜在维度
latent_dim = 100  # 随机噪声的维度
input_dim = 31  # 输入数据的维度,例如欺诈检测数据的特征数# 创建并编译模型
generator = build_generator(latent_dim, input_dim)
discriminator = build_discriminator(input_dim)
gan = build_gan(generator, discriminator)# 编译判别器和GAN模型
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
gan.compile(loss='binary_crossentropy', optimizer=optimizer)
  • 生成对抗训练(Adversarial Training)是GAN的关键。生成器和判别器在一个博弈过程中互相优化。在训练过程中,生成器通过“欺骗”判别器来优化其生成数据的能力,而判别器则不断学习区分真实和生成数据。

  • 在训练过程中,我们冻结判别器的参数,只训练生成器,这样可以避免在训练生成器时更新判别器的权重。

5. GAN训练函数

def train_gan(generator, discriminator, gan, fraud_data_scaled, epochs=10000, batch_size=64):valid = np.ones((batch_size, 1))  # 真数据标签fake = np.zeros((batch_size, 1))  # 假数据标签for epoch in range(epochs):# 随机选择真实欺诈数据idx = np.random.randint(0, fraud_data_scaled.shape[0], batch_size)real_data = fraud_data_scaled[idx]# 生成虚拟数据noise = np.random.normal(0, 1, (batch_size, latent_dim))generated_data = generator.predict(noise)# 训练判别器d_loss_real = discriminator.train_on_batch(real_data, valid)d_loss_fake = discriminator.train_on_batch(generated_data, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, latent_dim))g_loss = gan.train_on_batch(noise, valid)# 输出训练过程的损失if epoch % 1000 == 0:print(f'{epoch}/{epochs} [D loss: {d_loss[0]}] [G loss: {g_loss}]')
  • 训练过程:在每个训练周期中,首先更新判别器的权重(通过训练它区分真实数据和生成数据),然后训练生成器(通过训练它欺骗判别器)。

  • 损失函数:我们使用了 binary_crossentropy 损失函数,它用于二分类任务。在判别器的训练中,我们分别计算真实数据和生成数据的损失,然后平均得到判别器的总损失。生成器的损失则是通过GAN模型进行计算的。

6. 生成虚拟数据

def generate_fake_data(generator, num_samples):noise = np.random.normal(0, 1, (num_samples, latent_dim))  # 随机噪声generated_data = generator.predict(noise)  # 生成虚拟数据# 将生成的数据转换回原始空间generated_data_original = scaler.inverse_transform(generated_data)# 获取原始负样本数据的列名(去除 'id', 'IsFraud', 'Time' 列)feature_columns = [col for col in train_df.columns if col not in ['id', 'IsFraud', 'Time']]# 将生成的数据与原始负样本数据(即非欺诈数据)结合,作为新的训练数据augmented_data = np.concatenate([train_df[train_df['IsFraud'] == 0].drop(columns=['id', 'IsFraud', 'Time']),generated_data_original], axis=0)augmented_label = np.concatenate([np.zeros(train_df[train_df['IsFraud'] == 0].shape[0]), np.ones(generated_data_original.shape[0])], axis=0)# 创建包含生成数据和标签的 DataFrameaugmented_df = pd.DataFrame(augmented_data, columns=feature_columns)augmented_df['IsFraud'] = augmented_labelreturn augmented_df

在这个函数中,我们使用训练好的生成器来生成新的虚拟欺诈数据,并将它们与真实的非欺诈数据结合,以增强数据集。然后,我们通过逆标准化将生成的数据转换回原始数据空间。

本次例子为了缩短训练时间,只生成了100条虚拟的正样本数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/3514.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ZNS SSD垃圾回收优化方案解读-2

四、Brick-ZNS 关键设计机制解析 Brick-ZNS 作为一种创新的 ZNS SSD 设计,聚焦于解决传统 ZNS SSDs 在垃圾回收(GC)过程中的数据迁移低效问题,其核心特色为存储内数据迁移与地址重映射功能。在应用场景中,针对如 Rock…

浅谈云计算14 | 云存储技术

云存储技术 一、云计算网络存储技术基础1.1 网络存储的基本概念1.2云存储系统结构模型1.1.1 存储层1.1.2 基础管理层1.1.3 应用接口层1.1.4 访问层 1.2 网络存储技术分类 二、云计算网络存储技术特点2.1 超大规模与高可扩展性2.1.1 存储规模优势2.1.2 动态扩展机制 2.2 高可用性…

C++ 强化记忆

1 预处理指令 # include <> #include <filename> 添加系统头文件。 #include "filename" 添加自定义头文件。 # include <iostream> 对应使用cout cin的情况下需要添加 # include <string> 对应使用字符串情况 # include <fstream> …

web worker 前端多线程一、

前言&#xff1a; JavaScript 语言采用的是单线程模型&#xff0c;也就是说&#xff0c;所有任务只能在一个线程上完成&#xff0c;一次只能做一件事。前面的任务没做完&#xff0c;后面的任务只能等着。随着电脑计算能力的增强&#xff0c;尤其是多核 CPU 的出现&#xff0c;单…

Vue项目搭建教程超详细

目录 一. 环境准备 1. 安装node.js 2. 安装Vue cli 二. 创建 Vue 2 项目 1. 命令行方式 2. vue ui方式 一. 环境准备 1. 安装node.js 可参考node.js卸载与安装超详细教程-CSDN博客 2. 安装Vue cli npm install -g vue/cli检查是否安装成功 vue --version Vue CLI …

IM聊天学习资源

文章目录 参考链接使用前端界面简单效果消息窗口平滑滚动至底部vue使用watch监听vuex中的变量变化 websocket握手认证ChatKeyCheckHandlerNettyChatServerNettyChatInitializer 参考链接 zzhua/netty-chat-web - 包括前后端 vue.js实现带表情评论功能前后端实现&#xff08;仿…

彻底理解JVM类加载机制

文章目录 一、类加载器和双亲委派机制1.1、类加载器1.2、双亲委派机制1.3、自定义类加载器1.4、打破双亲委派机制 二、类的加载 图片来源&#xff1a;图灵学院   由上图可知&#xff0c;创建对象&#xff0c;执行其中的方法&#xff0c;在java层面&#xff0c;最重要的有获取…

使用FRP进行内网穿透

一、基本概念 内网穿透&#xff1a;它是一种网络技术或方法&#xff0c;旨在允许外部网络&#xff08;如互联网&#xff09;访问位于内部网络&#xff08;内网&#xff09;中的设备或服务。由于内部网络通常处于NAT&#xff08;网络地址转换&#xff09;、防火墙或其他安全机制…

Mysql常见问题处理集锦

Mysql常见问题处理集锦 root用户密码忘记&#xff0c;重置的操作(windows上的操作)MySQL报错&#xff1a;ERROR 1118 (42000): Row size too large. 或者 Row size too large (&#xff1e; 8126).场景&#xff1a;报错原因解决办法 详解行大小限制示例&#xff1a;内容来源于网…

《计算机网络》课后探研题书面报告_网际校验和算法

网际校验和算法 摘 要 本文旨在研究和实现网际校验和&#xff08;Internet Checksum&#xff09;算法。通过阅读《RFC 1071》文档理解该算法的工作原理&#xff0c;并使用编程语言实现网际校验和的计算过程。本项目将对不同类型的网络报文&#xff08;包括ICMP、TCP、UDP等&a…

Python毕业设计选题:基于django+vue的智能租房系统的设计与实现

开发语言&#xff1a;Python框架&#xff1a;djangoPython版本&#xff1a;python3.7.7数据库&#xff1a;mysql 5.7数据库工具&#xff1a;Navicat11开发软件&#xff1a;PyCharm 系统展示 租客注册 添加租客界面 租客管理 房屋类型管理 房屋信息管理 系统管理 摘要 本文首…

联发科MTK6762/MT6762安卓核心板_4G智能模块应用

MT6762安卓核心板是一款工业级高性能、可运行 android9.0 操作系统的 4G智能模块。MT6762平台打造具备 AI 体验、先进双摄像头拍摄效果且具备丰富连接功能的智能手机主板。 MT6762安卓核心板 是一款髙性能低功耗的 4G 全网通安卓智能模块。此模块支持 2G/3G/4G 移动&#xff0c…

彩色图像面积计算一般方法及MATLAB实现

一、引言 在数字图像处理中&#xff0c;经常需要获取感兴趣区域的面积属性&#xff0c;下面给出图像处理的一般步骤。 1.读入的彩色图像 2.将彩色图像转化为灰度图像 3.灰度图像转化为二值图像 4.区域标记 5.对每个区域的面积进行计算和显示 二、程序代码 %面积计算 cle…

重拾Python学习,先从把python删除开始。。。

自己折腾就是不行啊&#xff0c;屡战屡败&#xff0c;最近终于找到前辈教我 第一步 删除Python 先把前阵子折腾的WSL和VScode删掉。还是得用spyder&#xff0c;跟matlab最像&#xff0c;也最容易入手。 从VScode上搞python&#xff0c;最后安装到appdata上&#xff0c;安装插…

Redis系列之底层数据结构字典Dict

Redis系列之底层数据结构字典Dict Dict数据结构 Dict是Redis数据结构中使用最为频繁的复合型数据结构&#xff0c;本质上是一个哈希表 查看redis6.0版本的源码&#xff0c;链接&#xff1a;https://github.com/redis/redis/blob/6.0/src/dict.h 哈希表的结构定义&#xff1…

《贪心算法:原理剖析与典型例题精解》

必刷的贪心算法典型例题&#xff01; 算法竞赛&#xff08;蓝桥杯&#xff09;贪心算法1——数塔问题-CSDN博客 算法竞赛&#xff08;蓝桥杯&#xff09;贪心算法2——需要安排几位师傅加工零件-CSDN博客 算法&#xff08;蓝桥杯&#xff09;贪心算法3——二维数组排序与贪心算…

基于 Python 的深度学习的车俩特征分析系统,附源码

博主介绍&#xff1a;✌stormjun、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&…

VSCode 的部署

一、VSCode部署 (1)、简介 vsCode 全称 Visual Studio Code&#xff0c;是微软出的一款轻量级代码编辑器&#xff0c;免费、开源而且功能强大。它支持几乎所有主流的程序语言的语法高亮、智能代码补全、自定义热键、括号匹配、代码片段、代码对比Diff、版本管理GIT等特性&…

【开源免费】基于SpringBoot+Vue.JS欢迪迈手机商城(JAVA毕业设计)

本文项目编号 T 141 &#xff0c;文末自助获取源码 \color{red}{T141&#xff0c;文末自助获取源码} T141&#xff0c;文末自助获取源码 目录 一、系统介绍二、数据库设计三、配套教程3.1 启动教程3.2 讲解视频3.3 二次开发教程 四、功能截图五、文案资料5.1 选题背景5.2 国内…

Transformer创新模型!Transformer+BO-SVR多变量回归预测,添加气泡图、散点密度图(Matlab)

Transformer创新模型&#xff01;TransformerBO-SVR多变量回归预测&#xff0c;添加气泡图、散点密度图&#xff08;Matlab&#xff09; 目录 Transformer创新模型&#xff01;TransformerBO-SVR多变量回归预测&#xff0c;添加气泡图、散点密度图&#xff08;Matlab&#xff0…