神经网络案例实战

🔎我们通过一个案例详细使用PyTorch实战 ,案例背景:你创办了一家手机公司,不知道如何估算手机产品的价格。为了解决这个问题,收集了多家公司的手机销售数据:这些数据维度可以包括RAM、存储容量、屏幕尺寸、摄像头像素等。

在这个问题中,我们不需要预测实际价格,而是一个价格范围,它的范围使用 0、1、2、3 来表示,所以该问题也是一个分类问题。

🔎思路:

  1. 数据预处理:对收集到的数据进行清洗和预处理,确保数据的质量和一致性。这包括处理缺失值、异常值和重复数据等。

  2. 特征工程:从原始数据中提取有用的特征,以便用于建模。这可以包括对连续型特征进行归一化或标准化,对分类特征进行编码等。

  3. 模型选择:选择一个适合的机器学习算法来建立模型,这里我们使用神经网络模型。

  4. 模型训练:将收集到的数据划分为训练集和测试集。使用训练集来训练模型,通过调整模型的参数来最小化预测误差。

  5. 模型评估:使用测试集来评估模型的性能。常用的评估指标包括均方误差(MSE)、均方根误差(RMSE)和平均绝对误差(MAE)等。

  6. 价格预测:使用训练好的模型来预测新手机的价格。输入新手机的功能特征,模型将输出预测的价格。

构建数据集 

💬数据共有 2000 条, 其中 1600 条数据作为训练集, 400 条数据用作测试集。 我们使用 sklearn 的数据集划分工作来完成。并使用 PyTorch 的 TensorDataset 来将数据集构建为 Dataset 对象,方便后期构造数据集加载对象。

def create_dataset():data = pd.read_csv('predict.csv')# 特征值和目标值x, y = data.iloc[:, :-1], data.iloc[:, -1]x = x.astype(np.float32)y = y.astype(np.int64)# 数据集划分x_train, x_valid, y_train, y_valid = train_test_split(x, y, train_size=0.8, random_state=88, stratify=y)# 构建数据集train_dataset = TensorDataset(torch.tensor(x_train.values), torch.tensor(y_train.values))valid_dataset = TensorDataset(torch.tensor(x_valid.values), torch.tensor(y_valid.values))return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))train_dataset, valid_dataset, input_dim, class_num = create_dataset()

💬其中 train_test_split 方法中的stratify=y参数的作用是在划分训练集和验证集时,保持类别的比例相同。这样可以确保在训练集和验证集中各类别的比例与原始数据集中的比例相同,有助于提高模型的泛化能力,防止出现一份中某个类别只有几个。

构建分类网络模型

# 构建网络模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()self.linear1 = nn.Linear(input_dim, 128) # 输入层到隐藏层的线性变换self.linear2 = nn.Linear(128, 256)  # 隐藏层到输出层的线性变换self.linear3 = nn.Linear(256, output_dim)  # 输出层到最终输出的线性变换def _activation(self, x):return torch.sigmoid(x)def forward(self, x):x = self._activation(self.linear1(x))# 通过第一层线性变换和激活函数x = self._activation(self.linear2(x))# 通过第二层线性变换和激活函数output = self.linear3(x)# 通过第三层线性变换得到最终输出return output
  • self.linear1self.linear2之间的线性变换将输入维度从input_dim映射到128个神经元,然后再将128个神经元映射到256个神经元。
  • 我们通过self._activation方法对输入数据进行激活函数处理。具体来说,它使用Sigmoid激活函数对输入数据进行非线性变换,将线性变换后的输出映射到0和1之间。
  • 第三层: 输入为维度为 256, 输出维度为: 4

编写训练函数 

def train():# 固定随机数种子torch.manual_seed(66)model = PhonePriceModel(input_dim, class_num)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.SGD(model.parameters(), lr=1e-3)num_epoch = 50for epoch_idx in range(num_epoch):# 初始化数据加载器dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)# 训练时间start = time.time()# 计算损失total_loss = 0.0total_num = 1# 准确率correct = 0for x, y in dataloader:output = model(x)# 计算损失loss = criterion(output, y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_num += len(y)total_loss += loss.item() * len(y)print('epoch: %4s loss: %.2f, time: %.2fs' %(epoch_idx + 1, total_loss / total_num, time.time() - start))# 模型保存torch.save(model.state_dict(), 'price-model.bin')
  • 💬要在PyTorch中查看随机数种子,可以使用torch.random.initial_seed()函数。这个函数会返回当前设置的随机数种子。 

评估函数

def test():# 加载模型model = PhonePriceModel(input_dim, class_num)model.load_state_dict(torch.load('price-model.bin'))# 构建加载器dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)# 评估测试集correct = 0for x, y in dataloader:output = model(x)y_pred = torch.argmax(output, dim=1)correct += (y_pred == y).sum()print('Acc: %.5f' % (correct.item() / len(valid_dataset)))

网络性能调优

  1. 对输入数据进行标准化
  2. 调整优化方法
  3. 调整学习率
  4. 增加批量归一化层
  5. 增加网络层数、神经元个数
  6. 增加训练轮数

🔎我们可以自行将优化方法由 SGD 调整为 Adam , 学习率由 1e-3 调整为 1e-4 ,对数据数据进行标准化 ,增加网络深度 。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/322267.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在Mars3d实现cesium的ImageryLayer自定义瓦片的层级与原点

需要自定义瓦片层级和原点,所以需要自己写第三方图层,但是之前写的很多方法,图层控制和显隐以及透明度,需要跟之前的交互一直,改动量太大的话不划算,所以直接看Mars3d的layer基类,把重写的image…

字符串函数、内存函数——补充

目录 前言 1、strchr函数 1-1 函数介绍 1-1-1 函数功能 1-1-2 函数原型 1-1-3 函数参数 1-1-4 所属库 1-1-5 函数返回值 1-2 函数简单使用 1-3 函数使用场景 1-4 函数的使用总结 1-4-1 注意事项 2、strrchr函数 2-1 函数介绍 2-1-1 函数功能 2-1-2 函数原型 2…

BACnet到OPC UA的楼宇自动化系统与生产执行系统(MES)整合

在智能制造的浪潮下,一家位于深圳的精密电子制造企业面临着前所未有的挑战:如何高效地将楼宇自动化系统与生产执行系统(MES)整合,实现能源管理与生产流程的精细化控制。这家企业的楼宇控制系统使用的是BACnet协议&…

.OpenNJet应用引擎实践——从 0-1 体验感受

目录 一. 🦁 写在前面二. 🦁 安装使用2.1 安装环境2.2 配置yum源2.3 安装软件包2.4 编译代码2.5 启动 三. 🦁 使用效果3.1 编辑配置文件3.2 编辑 HTML 代码 四. 🦁 使用感受 一. 🦁 写在前面 现在互联网体系越来越往云…

【Docker学习】docker run的端口映射-p和-P选项

docker run的端口映射选项分为-p(小写,全称--publish),-P(大写,全称--publish-all),之前认为只有改变容器发布给宿主机的默认端口号才会进行-p的设置,而不改变默认端口号…

MATLAB 基于规则格网的点云抽稀方法(自定义实现)(65)

MATLAB 基于规则格网的点云抽稀方法(自定义实现)(65) 一、算法介绍二、算法实现1.代码2.结果一、算法介绍 海量点云的处理,需要提前进行抽稀预处理,相比MATLAB预先给出的抽稀方法,这里提供一种基于规则格网的自定义抽稀方法,步骤清晰,便于理解抽稀内涵, 主要涉及到使…

为什么你创业总是失败?2024普通人如何创业?2024创业赛道!2024创业新风口!2024创业方向!2024普通人的机会!

为什么你做项目老是不赚钱,是你不够努力吗?是你运气不好吗? 如果都不是!那一定是你的思维逻辑出了问题! 先想一想你以前做的项目,有没有哪个符合以下条件:对客户有价值、寻找客源成本在可接受…

华为eNSP综合实验-网络地址转换

实验完成之后,在AR1的g0/0/1接口抓包,查看地址转换 实现私网pc访问公网pc 实验命令展示 SW1: vlan batch 12 #创建vlan interface e0/0/1 #进入接口配置vlan端口 port link-type access port default vlan 12 q interface e0/0/2 #进入接口配置vlan端口 port link-type ac…

Linux高级学习(前置 在vmware安装centos7.4)

【小白入门 通俗易懂】2021韩顺平 一周学会Linux 此文章包含第006p-第p007的内容 操作 在安装好的vmware下进行安装 这里使用的是vmware15(win10下),win11可能无法使用15(有几率蓝屏),换成16就行了 用迅雷…

(二刷)代码随想录第1天|704. 二分查找 27. 移除元素

704. 二分查找 704. 二分查找 - 力扣(LeetCode) 代码随想录 (programmercarl.com) 手把手带你撕出正确的二分法 | 二分查找法 | 二分搜索法 | LeetCode:704. 二分查找_哔哩哔哩_bilibili 给定一个 n 个元素有序的(升序&#xff09…

微信小程序个人中心、我的界面(示例四)

微信小程序个人中心、我的界面,九宫格简单布局(示例四) 微信小程序个人中心、我的界面,超简洁的九宫格界面布局,代码粘贴即用。更多微信小程序界面示例,请进入我的主页哦! 1、js代码 Page({…

无意的一次学习,竟让我摆脱了Android控制?

由于鸿蒙的爆火,为了赶上时代先锋。到目前为止也研究过很长一段时间。作为一名Android的研发人员,免不了对其评头论足,指导文档如何写才算专业?页面如何绘制?页面如何跳转?有没有四大组件等等。 而Harmony…

pycharm中导入rospy(ModuleNotFoundError: No module named ‘rospy‘)

1. ubuntu安装对应版本ros ubuntu20.04可参考: https://wiki.ros.org/cn/noetic/Installation/Ubuntuhttps://zhuanlan.zhihu.com/p/515361781 2. 安装python3-roslib sudo apt-get install python3-roslib3.在conda环境中安装rospy pip install rospkg pip in…

pycharm调试bash启动的python项目(远程开发同理)

pycharm调试bash启动的python项目(远程开发同理) 步骤 打开运行/调试配置 选择Python 调试服务器 参考打开的页面,在需要debug的虚拟环境中安装依赖环境:pip install pydevd-pycharm~241.14494.158。端口号可以手动指定&…

探索鸿蒙开发:鸿蒙系统如何引领嵌入式技术革新

嵌入式技术已经成为现代社会不可或缺的一部分。而在这个领域,华为凭借其自主研发的鸿蒙操作系统,正悄然引领着一场技术革新的浪潮。本文将探讨鸿蒙开发的特点、优势以及其对嵌入式技术发展的深远影响。 鸿蒙操作系统的特点 鸿蒙,作为华为推…

pytorch基础: torch.unbind()

1. torch.unbind 作用 说明:移除指定维后,返回一个元组,包含了沿着指定维切片后的各个切片。 参数: tensor(Tensor) – 输入张量dim(int) – 删除的维度 2. 案例 案例1 x torch.rand(1,80,3,360,360)y x.unbind(dim2)print(&…

初学python记录:力扣1652. 拆炸弹

题目: 你有一个炸弹需要拆除,时间紧迫!你的情报员会给你一个长度为 n 的 循环 数组 code 以及一个密钥 k 。 为了获得正确的密码,你需要替换掉每一个数字。所有数字会 同时 被替换。 如果 k > 0 ,将第 i 个数字用…

Java中的异常处理机制

Java中的异常处理机制主要通过try、catch和finally三个关键字来实现。以下是Java异常处理机制的工作原理和正确处理异常的一些基本步骤: ## 异常处理机制的工作原理 1. **try**:包围可能抛出异常的代码块。 2. **catch**:捕获并处理特定类型…

Python | Leetcode Python题解之第70题爬楼梯

题目: 题解: class Solution:def climbStairs(self, n: int) -> int:a, b 1, 1for _ in range(n - 1):a, b b, a breturn b

前端 Android App 上架详细流程 (Android App)

1、准备上架所需要的材料 先在需要上架的官方网站注册账号。提前把手机号,名字,身份证等等材料准备好,完成开发者实名认证;软著是必要的,提前准备好,软著申请时间比较长大概需要1-2周时间才能下来&#xf…