基于pytoch卷积神经网络水质图像分类实战

具体怎么学习pytorch,看b站刘二大人的视频。

完整代码:

import numpy as np
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)
'''https://zhuanlan.zhihu.com/p/156926543'''
# 定义图片目录
image_dir = 'images'# 初始化图片路径列表
img_list = []# 遍历指定目录及其子目录中的所有文件
for parent, _, filenames in os.walk(image_dir):for filename in filenames:# 拼接文件的完整路径filename_path = os.path.join(parent, filename)img_list.append(filename_path)# 初始化图像张量列表和标签列表
image_tensors = []
y_list = []for image_path in img_list:# 提取标签 (假设标签是文件名的第一个字符)label = int(os.path.basename(image_path)[0])y_list.append(label)# 打开图像img = Image.open(image_path)# 获取图像尺寸width, height = img.size# 定义裁剪的区域(假设要保留图像中心的 100x100 区域)left = (width - 100) / 2top = (height - 100) / 2right = (width + 100) / 2bottom = (height + 100) / 2# 裁剪图像img = img.crop((left, top, right, bottom))# 将图像转换为 NumPy 数组img_array = np.asarray(img)# 将 NumPy 数组转换为 PyTorch 张量img_tensor = torch.from_numpy(img_array).float()# 如果图像是 RGB,将其转换为 (C, H, W) 格式if img_tensor.ndimension() == 3 and img_tensor.shape[2] == 3:img_tensor = img_tensor.permute(2, 0, 1)  # 从 (H, W, C) 变为 (C, H, W)# 增加 batch 维度img_tensor = img_tensor.unsqueeze(0)  # 从 (C, H, W) 变为 (1, C, H, W)# 规范化到0-1之间img_tensor = img_tensor / 255.0# 添加到图像张量列表image_tensors.append(img_tensor)# 打印图像张量的形状print(f"当前图像形状: {img_tensor.shape}")# 将图像张量列表转换为四维张量
x_data = torch.cat(image_tensors, dim=0)
# 遍历 y_list 中的每个元素,并将每个数减去 1
for i in range(len(y_list)):y_list[i] -= 1# 将标签列表转换为张量
y_labels = torch.tensor(y_list).long()  # 注意这里使用 .long() 方法将标签转换为长整型print(x_data.shape,y_labels.shape)
print(y_labels)# 定义数据集和数据加载器
class CustomDataset(torch.utils.data.Dataset):def __init__(self, x_data, y_labels):self.x_data = x_dataself.y_labels = y_labelsdef __len__(self):return len(self.x_data)def __getitem__(self, idx):return self.x_data[idx], self.y_labels[idx]# 使用自定义数据集和数据加载器
custom_dataset = CustomDataset(x_data, y_labels)
train_size = int(0.8 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(custom_dataset, [train_size, val_size])train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)# 定义卷积神经网络模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32 * 25 * 25, 128)self.fc2 = nn.Linear(128, 5)  # 假设有5个类别def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = torch.flatten(x, 1)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 训练模型
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)for epoch in range(50):  # 假设训练50个epochrunning_loss = 0.0for inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}")# 在每个epoch结束后,计算并打印验证集的准确率model.eval()  # 将模型设置为评估模式correct = 0total = 0with torch.no_grad():for inputs, labels in val_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()val_accuracy = correct / totalprint(f'Validation Accuracy after Epoch {epoch + 1}: {val_accuracy}')

本数据集中x_data的维度是四维张量(203,3,100,100),y_labels的维度是一维张量 

代码中需要注意的点,卷积模型接受的是四维张量,因此要转变为四维张量。

全连接层中输入的特征数,需要自己计算,通过前面卷积层和池化层后,计算总的维度数。一般是最后的通道数*高度*宽度

定义模型中的forward函数中,在经过全连接层计算前,需要将四维的x转为2维

如果 x 的形状是 (64, 32, 28, 28),表示一个批次大小为64的图像张量,其中每个图像有32个通道,高度和宽度都是28像素。现在,我们希望将这个张量展平为一个二维张量,以便输入到全连接层进行进一步处理。

通过 torch.flatten(x, 1) 操作,我们将在指定维度(这里是第一个维度,也就是通道维度)上对张量进行展平。展平后的张量形状将变为 (64, 32*28*28),其中64是批次大小,而 32*28*28 是展平后的特征数量,即每个图像的特征数量。这与前面定义的全连接层的输入特征数要一致。

Dataloader中batch_size就是设置第一个维度,比如这里的batch_size是32,那么

for inputs, labels in train_loader:

 这里的inputs维度是(32,3,100,100)

新学习pytorch中的分割数据集与测试集方法。

# 使用自定义数据集和数据加载器
custom_dataset = CustomDataset(x_data, y_labels)
train_size = int(0.8 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(custom_dataset, [train_size, val_size])train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

结果展现,可以看见准确率有0.82:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/345879.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

时序数据库是Niche Market吗?

引言 DB-Engines的流行程度排行从其评估标准[4]可以看出完全不能够做为市场规模的评估标准。甚至于在知道市场规模后可以用这个排行作为一个避雷手册。毕竟现存市场小,可预见增长规模小,竞争大,创新不足,那只能卷价格&#xff0c…

【文献阅读】LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS

目录 1. motivation2. overall3. model3.1 low rank parametrized update matrices3.2 applying lora to transformer 4. limitation5. experiment6. 代码7. 补充参考文献 1. motivation 常规的adaptation需要的微调成本过大现有方法的不足: Adapter Layers Introd…

.net core 使用js,.net core 使用javascript,在.net core项目中怎么使用javascript

.net core 使用js,.net core 使用javascript,在.net core项目中怎么使用javascript 我项目里需要用到“文字编码”,为了保证前端和后端的编码解码不处bug, 所以,我在项目中用了这个 下面推荐之前在.net F4.0时的方法 文章一&#…

操作系统真象还原:中断

第7章-中断 这是一个网站有所有小节的代码实现,同时也包含了Bochs等文件 7.2操作系统是中断驱动的 没有中断,操作系统几乎什么都做不了,操作系统是中断驱动。 7.3中断分类 7.3.1外部中断 外部中断是指来自 CPU 外部的中断,而…

Python基础教程(八):迭代器与生成器编程

💝💝💝首先,欢迎各位来到我的博客,很高兴能够在这里和您见面!希望您在这里不仅可以有所收获,同时也能感受到一份轻松欢乐的氛围,祝你生活愉快! 💝&#x1f49…

时隔很久运行苍穹外卖项目,出现很多错误

中途运行了很多其他项目,maven的配置文件还被我修改了一次。导致再次运行苍穹外卖项目出现很多错误。 发现没有办法,把本地的仓库删了个干干净净。然后点击clean发现报错: Cannot access alimaven (http://mavejavascript:void(0);n.aliyun.…

力扣每日一题 6/10

881.救生艇[中等] 题目: 给定数组 people 。people[i]表示第 i 个人的体重 ,船的数量不限,每艘船可以承载的最大重量为 limit。 每艘船最多可同时载两人,但条件是这些人的重量之和最多为 limit。 返回 承载所有人所需的最小船…

Vue15-watch对比计算属性

一、姓名案例 1-1、watch实现 1-2、计算属性 对比发现: 计算属性比watch属性更简略一些。 1-3、计算属性 VS 侦听属性 1-4、需求变更 计算属性中不能开启异步任务!!!因为计算属性靠return返回值。但是watch靠亲自写代码去改。 1-…

streamlit:如何快速构建一个应用,不会前端也能写出好看的界面

通过本文你可以了解到: 如何安装streamlit,运行起来第一个demo熟悉streamlit的基本语法,常用的一些组件使用streamlit库构建应用 大模型学习参考: 大模型学习资料整理:如何从0到1学习大模型,搭建个人或企业…

设计模式之观察者模式ObserverPattern(十一)

一、概述 观察者模式 (Observer Pattern) 是一种行为型设计模式,又被称为发布-订阅 (Publish/Subscribe) 模式,它定义了对象之间的一种一对多的依赖关系,使得当一个对象的状态发生变化时,所有依赖于它的对象都会自动收到通知并更新…

Duck Bro的第512天创作纪念日

Tips:发布的文章将会展示至 里程碑专区 ,也可以在 专区 内查看其他创作者的纪念日文章 我的创作纪念日第512天 文章目录 我的创作纪念日第512天一、与CSDN平台的相遇1. 为什么在CSDN这个平台进行创作?2. 创作这些文章是为了赚钱吗&#xff1f…

手写kNN算法的实现-用欧几里德空间来度量距离

kNN的算法思路:找K个离预测点最近的点,然后让它们进行投票决定预测点的类型。 step 1: kNN存储样本点的特征数据和标签数据step 2: 计算预测点到所有样本点的距离,关于这个距离,我们用欧几里德距离来度量(其实还有很多…

ui自动化中,selenium进行元素定位,以及CSS,xpath定位总结

几种定位方式 简单代码 from selenium import webdriver import time# 创建浏览器驱动对象 from selenium.webdriver.common.by import Bydriver webdriver.Chrome() # 参数写浏览器驱动文件的路径,若配置到环境变量就不用写了 # 访问网址 driver.get…

WPF-UI布局

WPF布局元素有如下几个: Grid:网格。可以自定义行和列并通过行列的数量、行高和列宽来调整控件的布局。StackPanel:栈式面板。可将包含的元素在竖直或水平方向上排成一条直线,当移除一个元素后,后面的元素会自动向前移…

两台电脑通过网线直连共享数据(超详细)- 我的实践记录

原文链接 按照原文的操作,成功通过直连网线连接了两台windows电脑并共享传输数据。 ping不通可能是防火墙没关闭导致的,但是完全关闭防火墙又不安全。 那么有没有不关闭防火墙,能够上网,又能直连另一台电脑呢? 我们…

Linux启动KKfileview文件在线浏览时报错:启动office组件失败,请检查office组件是否可用

目录 1、导论 2、报错信息 3、问题分析 4、解决方法 4.1、下载 4.2、安装步骤 1、导论 今天进行项目部署时,遇到了一个问题。在启动kkfileview时,出现了报错异常: 2024-06-09 06:36:44.765 ERROR 1 --- [ main] cn.keking.service.Of…

全球溃败,苹果可能要全球大降价了,试图摆脱中国制造的后果

苹果一季度在中国市场的出货量暴跌,导致它不得不在中国市场大降价,从3月份就在中国市场大幅度降价,然而目前它在美国和欧洲两大市场也出现大幅衰退,降价可能将成为苹果在全球的举措。 市调机构Canalys公布的一季度数据显示&#x…

Warning: `ReactDOMTestUtils.act` is deprecated in favor of `React.act`.

问题:在代码中使用jest进行单元测试时,报错如下: 解决思路: 根据报错提示出来的 react-dom/test-utils 进行全局搜索,发现没有该引用,故进入该代码块中分析。发现代码中引入testing-library/react &#…

Git使用总结(git使用,git实操,git命令和常用指令)

简介:Git是一款代码版本管理工具,可以记录每次提交的代码,防止代码丢失,可实现版本迭代,解决代码冲突,常用的远程Git仓库:Gitee(国内)、GitHub(国外&#xff…

基于websocket与node搭建简易聊天室

一、前言 上一篇文章介绍了websocket的详细用法与工具类的封装,本篇就基于websocket搭建一个简易实时的聊天室。 在本篇开始之前也可以去回顾一下websocket详细用法:WebSocket详解与封装工具类 二、基于node搭建后台websocket服务 首先确认本机电脑中…