【深度学习实战】kaggle 自动驾驶的假场景分类

本次分享我在kaggle中参与竞赛的历程,这个版本是我的第一版,使用的是vgg。欢迎大家进行建议和交流。

概述

  • 判断自动驾驶场景是真是假,训练神经网络或使用任何算法来分类驾驶场景的图像是真实的还是虚假的。

  • 图像采用 RGB 格式并以 JPEG 格式压缩。

  • 标签显示 (1) 真实和 (0) 虚假

  • 二元分类

数据集描述

文件
train.csv - 训练集标签
Sample_submission.csv - 正确格式的示例提交文件
Train/- 训练图像
Test/ - 测试图像

模型思路

由于是要进行图像的二分类任务,因此考虑使用迁移学习,将vgg16中的卷积层和卷积层的参数完全迁移过来,不包括顶部的全连接层,自己设计适合该任务的头部结构,然后加以训练,绘制图像查看训练结果。

vgg16简介

VGG16 是由牛津大学视觉几何组(VGG)在2014年提出的卷积神经网络(CNN)。它由16个层组成,其中包含13个卷积层和3个全连接层。其特点是使用3x3的小卷积核和2x2的最大池化层,网络深度较深,有效提取图像特征。VGG16在图像分类任务中表现优异,尤其是在ImageNet挑战中取得了良好成绩。尽管计算量大、参数众多,但它因其简单而高效的结构,仍广泛应用于迁移学习和其他计算机视觉任务中。

源码+解析

  1. 第一步,导入所需的库。
import os
import cv2
import numpy as np
import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Dropout, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.applications.vgg16 import preprocess_input
  1. 加载文件
# 路径和文件
data_file = '/kaggle/input/cidaut-ai-fake-scene-classification-2024/train.csv'
image_test = '/kaggle/input/cidaut-ai-fake-scene-classification-2024/Test/'
image_train = '/kaggle/input/cidaut-ai-fake-scene-classification-2024/Train/'# 加载标签数据
df = pd.read_csv(data_file)
df['image_path'] = df['image'].apply(lambda x: os.path.join(image_train, x))n_classes = df['label'].nunique()df.head()  # 显示数据的前几行,检查路径和标签

输出

	image	label	image_path
0	1.jpg	editada	/kaggle/input/cidaut-ai-fake-scene-classificat...
1	2.jpg	real	/kaggle/input/cidaut-ai-fake-scene-classificat...
2	3.jpg	real	/kaggle/input/cidaut-ai-fake-scene-classificat...
3	6.jpg	editada	/kaggle/input/cidaut-ai-fake-scene-classificat...
4	8.jpg	real	/kaggle/input/cidaut-ai-fake-scene-classificat...

原始train.csv文件只有前两列,image 和label 列,为了方便读取图像文件,新添加了一列image_path用来记录图像文件的具体路径。

# 初始化空列表 x 用于存储图像
x = []# 遍历每一行读取图像
for index, row in df.iterrows():image_path = row['image_path']  # 获取图像路径img = cv2.imread(image_path)  # 使用 cv2 读取图像if img is not None:img_resized = cv2.resize(img, (256, 256))  # 调整图像尺寸为 (256, 256)x.append(img_resized)  # 将读取的图像添加到列表 x 中else:print(f"图像 {row['image_path']} 读取失败")  # 打印失败的路径# x 列表现在包含了所有读取的图像
print(f"总共有 {len(x)} 张图像被读取")

输出

总共有 720 张图像被读取

通过输出结果,可以看到图像被正确的读取了。并且将图像的大小调整为vgg所能用的256*256的尺寸,存放在变量x中。

  1. 第三步,进行数据处理
# 将图像转换为 NumPy 数组
x = np.array(x)# 标签映射并进行 one-hot 编码
y = df['label'].map({'real': 1, 'editada': 0})
y = np.array(y)
y = to_categorical(y, num_classes=2)  # 二分类# 分割训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)# 检查转换后的结果
print(f"x_train.shape: {x_train.shape}")
print(f"y_train.shape: {y_train.shape}")
print(f"x_test.shape: {x_test.shape}")
print(f"y_test.shape: {y_test.shape}")

输出

x_train.shape: (576, 256, 256, 3)
y_train.shape: (576, 2)
x_test.shape: (144, 256, 256, 3)
y_test.shape: (144, 2)

这里是为了将原始的图像转换为numpy数组,并且将标签进行独热编码,(对分类的标签一定要进行独热编码,转换为矩阵形式),并且切分数据集。

  1. 第四步,设计模型结构
from tensorflow.keras.regularizers import l2
# 加载预训练的VGG16卷积基(不包括顶部的全连接层)
vgg16_model = VGG16(include_top=False, weights='imagenet', input_shape=(256, 256, 3))# 冻结VGG16的卷积层
for layer in vgg16_model.layers:layer.trainable = False# 创建一个新的模型
model_fine_tuning = Sequential()# 将VGG16的卷积基添加到新模型中
model_fine_tuning.add(vgg16_model)  # 添加VGG16卷积基
model_fine_tuning.add(Flatten())  # 将卷积特征图展平# 添加新的全连接层并进行正则化
model_fine_tuning.add(Dense(512, activation='relu', kernel_regularizer=l2(0.01)))  # L2正则化
model_fine_tuning.add(Dropout(0.3))  # Dropout层,减少过拟合
model_fine_tuning.add(Dense(256, activation='relu', kernel_regularizer=l2(0.01)))  # 较小的全连接层
model_fine_tuning.add(Dropout(0.3) ) # 再次使用Dropout层# 输出层
model_fine_tuning.add(Dense(2, activation='softmax'))  # 对于二分类问题,使用softmax# 查看模型架构
model_fine_tuning.summary()

输出:

Layer (type)Output ShapeParam #
vgg16 (Functional)(None, 8, 8, 512)14,714,688
flatten (Flatten)(None, 32768)0
dense (Dense)(None, 512)16,777,728
dropout (Dropout)(None, 512)0
dense_1 (Dense)(None, 256)131,328
dropout_1 (Dropout)(None, 256)0
dense_2 (Dense)(None, 2)514

这里实现了一个基于预训练VGG16模型的迁移学习框架,用于图像分类任务。首先,加载了预训练的VGG16卷积基(不包括全连接层),并通过设置include_top=False来只使用卷积部分,从而利用其在ImageNet数据集上学到的特征。接着,冻结VGG16的卷积层,即通过将trainable属性设为False,使得这些层在训练过程中不进行更新。接下来,创建了一个新的Sequential模型,并将VGG16的卷积基添加进去,随后使用Flatten层将卷积特征图展平,为全连接层准备输入。为了增加模型的表达能力,添加了两个全连接层,每个层都应用了ReLU激活函数,并使用L2正则化来防止过拟合。为了进一步减少过拟合,模型还在每个全连接层后添加了Dropout层,丢弃30%的神经元。最后,输出层是一个具有两个神经元的全连接层,采用softmax激活函数,用于处理二分类问题。model_fine_tuning.summary()方法输出模型架构,帮助查看各层的结构和参数。通过这种方式,模型能够利用VGG16的预训练卷积基进行特征提取,并通过新添加的全连接层进行分类。

  1. 第五步,编译并训练模型
# 编译模型
model_fine_tuning.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])datagen = ImageDataGenerator(rotation_range=40,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,fill_mode='nearest',preprocessing_function=preprocess_input)  # 使用VGG16的预处理函数# 对原始图像进行增强,并进行训练
history = model_fine_tuning.fit(datagen.flow(x_train, y_train, batch_size=32),epochs=20,validation_data=(x_test, y_test),callbacks=[ModelCheckpoint('best_model.keras', save_best_only=True),EarlyStopping(patience=5)])

这里主要完成了对已经构建的模型(model_fine_tuning)的编译与训练过程。

  • 首先,使用compile()方法对模型进行编译,指定损失函数为binary_crossentropy,适用于二分类问题,同时选择Adam优化器,这是一种自适应学习率的优化算法,能够有效提升训练性能。在编译时,还通过metrics=['accuracy']设置了准确率作为评估指标。
  • 接着,创建了一个ImageDataGenerator对象用于数据增强,它包含多种图像变换方式,如旋转、平移、剪切、缩放、水平翻转等,这些操作可以增加数据多样性,减少过拟合,提升模型的泛化能力。
  • 此外,preprocessing_function=preprocess_input使用了VGG16预训练模型的标准预处理函数,确保输入图像的像素范围符合VGG16的训练要求。
  • 随后,通过fit()方法开始训练模型,训练数据通过datagen.flow()进行增强和批量生成,训练将在20个周期(epochs)内进行。在训练过程中,还设置了两个回调函数:ModelCheckpoint,用于保存最好的模型权重文件(best_model.keras),并且只保存验证集上表现最好的模型;
  • EarlyStopping,用于在验证集准确率不再提升时提前停止训练,patience=5表示如果5个周期内没有改进,则停止训练。这样,通过数据增强和回调函数的配合,能够有效提高训练的效果和模型的稳定性。

到这里,整个部分就基本完成了。

  1. 绘制损失和准确率图像
import matplotlib.pyplot as plt# 获取训练过程中的损失和准确率数据
history_dict = history.history
loss = history_dict['loss']
accuracy = history_dict['accuracy']
val_loss = history_dict['val_loss']
val_accuracy = history_dict['val_accuracy']# 绘制损失图
plt.figure(figsize=(12, 6))# 损失图
plt.subplot(1, 2, 1)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Loss over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()# 准确率图
plt.subplot(1, 2, 2)
plt.plot(accuracy, label='Training Accuracy')
plt.plot(val_accuracy, label='Validation Accuracy')
plt.title('Accuracy over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()# 展示图像
plt.tight_layout()
plt.show()

在这里插入图片描述
数据文件已经上传,感兴趣的小伙伴可以下载后自己尝试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/2387.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

下载文件,浏览器阻止不安全下载

背景: 在项目开发中,遇到需要下载文件的情况,文件类型可能是图片、excell表、pdf、zip等文件类型,但浏览器会阻止不安全的下载链接。 效果展示: 下载文件的两种方式: 一、根据接口的相对url,拼…

【漏洞分析】DDOS攻防分析

0x00 UDP攻击实例 2013年12月30日,网游界发生了一起“追杀”事件。事件的主角是PhantmL0rd(这名字一看就是个玩家)和黑客组织DERP Trolling。 PhantomL0rd,人称“鬼王”,本名James Varga,某专业游戏小组的…

低代码独特架构带来的编译难点及多线程解决方案

前言 在当今软件开发领域,低代码平台以其快速构建应用的能力,吸引了众多开发者与企业的目光。然而,低代码平台独特的架构在带来便捷的同时,也给编译过程带来了一系列棘手的难点。 一,低代码编译的难点 (1…

Android BitmapShader更简易的实现刮刮乐功能,Kotlin

Android BitmapShader更简易的实现刮刮乐功能,Kotlin 比这种方式 Android使用PorterDuffXfermode模式PorterDuff.Mode.SRC_OUT橡皮擦实现“刮刮乐”效果,Kotlin(2)-CSDN博客 更简单实现刮刮乐效果。 import android.content.Cont…

【DB-GPT】开启数据库交互新篇章的技术探索与实践

一、引言:AI原生数据应用开发的挑战与机遇 在数字化转型的浪潮中,企业对于智能化应用的需求日益增长。然而,传统的数据应用开发方式面临着诸多挑战,如技术栈复杂、开发周期长、成本高昂、难以维护等。这些问题限制了智能化应用的…

客户案例:某家居制造企业跨境电商,解决业务端(亚马逊平台)、易仓ERP与财务端(金蝶ERP)系统间的业务财务数据对账互通

一、系统定义 1、系统定位: 数据中台系统是一种战略选择和组织形式,通过有型的产品支撑和实施方法论,解决企业面临的数据孤岛、数据维护混乱、数据价值利用低的问题,依据企业特有的业务和架构,构建一套从数据汇聚、开…

springboot程序快速入门

1.新建springboot项目 一上来输入项目名字语言选javaType选Mavenjdk 1.8java选8packaging选jar 选择对应的springboot版本2.6.13Web里面勾上Spring Web 点击创建即可。 2.手工编辑一个控制器 手动创建一个Controller类: package com.example.springbootgate.con…

【Linux】常见指令(一)

Linux常见指令 01.whoami02.pwd03.ls04.mkdir05.cd 本文LInux环境为,使用XShell远程登陆到Linux。 具体如何环境搭建,大家可以查看其他博客。 01.whoami whoami 指令用来查看当前账户是谁。 如上图所示,使用whoami指令,查看到现在…

鸿蒙UI开发——键盘弹出避让模式设置

1、概 述 我们在鸿蒙开发时,不免会遇到用户输入场景,当用户准备输入时,会涉及到输入法的弹出,我们的界面针对输入法的弹出有两种避让模式:上抬模式、压缩模式。 下面针对输入法的两种避让模式的设置做简单介绍。 2、…

【零基础入门unity游戏开发——unity3D篇】地形Terrain的使用介绍

考虑到每个人基础可能不一样,且并不是所有人都有同时做2D、3D开发的需求,所以我把 【零基础入门unity游戏开发】 分为成了C#篇、unity通用篇、unity3D篇、unity2D篇。 【C#篇】:主要讲解C#的基础语法,包括变量、数据类型、运算符、…

微服务之松耦合

参考:https://microservices.io/post/architecture/2023/03/28/microservice-architecture-essentials-loose-coupling.html There’s actually two different types of coupling: runtime coupling - influences availability design-time coupling - influences…

数据结构之双链表(C语言)

​ 数据结构之双链表(C语言) 1 链表的分类2 双向链表的结构3 双向链表的节点创建与初始化3.1 节点创建函数3.2 初始化函数 4 双向链表插入节点与删除节点的前序分析5 双向链表尾插法与头插法5.1 尾插函数5.2 头插函数 6 双向链表的尾删法与头删法6.1尾删…

Banana Pi BPI-RV2 RISC-V路由开发板采用矽昌通信SF2H8898芯片

Banana Pi BPI-RV2 开源网关是⼀款基于矽昌SF2H8898 SoC的设备,1 2.5 G WAN⽹络接⼝、5 个千兆LAN ⽹络接⼝、板载 512MB DDR3 内存 、128 MiB NAND、16 MiB NOR、M.2接⼝,MINI PCIE和USB 2.0接⼝等。 Banana Pi BPI-RV2 开源网关是矽昌和⾹蕉派开源社…

C语言:数据的存储

本文重点: 1. 数据类型详细介绍 2. 整形在内存中的存储:原码、反码、补码 3. 大小端字节序介绍及判断 4. 浮点型在内存中的存储解析 数据类型结构的介绍: 类型的基本归类: 整型家族 浮点家族 构造类型: 指针类型&…

从代码层面熟悉UniAD,开始学习了解端到端整体架构

0. 简介 最近端到端已经是越来越火了,以UniAD为代表的很多工作不断地在不断刷新端到端的指标,比如最近SparseDrive又重新刷新了所有任务的指标。在端到端火热起来之前,成熟的模块化自动驾驶系统被分解为不同的独立任务,例如感知、…

Go-Zero整合Goose实现MySQL数据库版本管理

推荐阅读 【系列好文】go-zero从入门到精通(看了就会) 教程地址:https://blog.csdn.net/u011019141/article/details/139619172 Go-Zero整合Goose实现MySQL数据库版本管理的教程 在开发中,数据库迁移和版本管理是必不可少的工作。…

day 27 日志文件(枚举,时间函数),目录io,多文件管理

0## 1.获得当前时间 # include <stdio.h> #include <stdlib.h> #include <time.h>int main() {struct tm* ptm;time_t sec time(NULL);ptm localtime(&sec);printf("%d-%d-%d %d:%d:%d\n",ptm->tm_year1900,ptm->tm_mon1,ptm->tm_…

使用Flink-JDBC将数据同步到Doris

在现代数据分析和处理环境中&#xff0c;数据同步是一个至关重要的环节。Apache Flink和Doris是两个强大的工具&#xff0c;分别用于实时数据处理和大规模并行处理&#xff08;MPP&#xff09;SQL数据库。本文将介绍如何使用Flink-JDBC连接器将数据同步到Doris。 一、背景介绍…

【python】OpenCV—Local Translation Warps

文章目录 1、功能描述2、原理分析3、代码实现4、效果展示5、完整代码6、参考 1、功能描述 利用液化效果实现瘦脸美颜 交互式的液化效果原理来自 Gustafsson A. Interactive image warping[D]. , 1993. 2、原理分析 上面描述很清晰了&#xff0c;鼠标初始在 C&#xff0c;也即…

灵活妙想学数学

灵活妙想学数学 题1&#xff1a;海星有几只&#xff1f; 一共有12只海洋生物&#xff0c;分别是5只脚的海星&#xff0c;8只脚的章鱼和10只脚的鱿鱼&#xff0c;这些海洋动物的脚一共有87只&#xff0c;每种生物至少有1只&#xff0c;问海星有几只&#xff1f; 解&#xff1a…