12.2深度学习_项目实战

十、项目实战

鲍勃开了自己的手机公司。他想与苹果、三星等大公司展开硬仗。
他不知道如何估算自己公司生产的手机的价格。在这个竞争激烈的手机市场,你不能简单地假设事情。为了解决这个问题,他收集了各个公司的手机销售数据。
鲍勃想找出手机的特性(例如:RAM、内存等)和售价之间的关系。但他不太擅长机器学习。所以他需要你帮他解决这个问题。
在这个问题中,你不需要预测实际价格,而是要预测一个价格区间,表明价格多高。

需要注意的是: 在这个问题中,我们不需要预测实际价格,而是一个价格范围,它的范围使用 0、1、2、3 来表示,所以该问题也是一个分类问题。

数据说明:https://tianchi.aliyun.com/dataset/157241 Mobile Price Classification

推荐专业的数据集平台:https://www.kaggle.com/datasets/iabhishekofficial/mobile-price-classification

字段说明
battery_power电池容量(mAh)
blue是否支持蓝牙
clock_speed微处理器执行指令的速度
dual_sim是否支持双卡
fc前置摄像头分辨率(百万像素)
four_g是否支持4G
int_memory存储内存(GB)
m_dep手机厚度(厘米)
mobile_wt重量
n_cores核心数
pc主摄像头分辨率(百万像素)
px_height屏幕分辨率高度(像素)
px_width屏幕分辨率宽度(像素)
ram运行内存(MB)
sc_h屏幕长度(厘米)
sc_w屏幕宽度(厘米)
talk_time单次充电最长通话时间
three_g是否支持3G
touch_screen是否是触摸屏
wifi是否支持WIFI
price_range价格区间

1. 构建数据集

数据共有 2000 条, 其中 1600 条数据作为训练集, 400 条数据用作测试集。 我们使用 sklearn 的数据集划分工作来完成。并使用 PyTorch 的 TensorDataset 来将数据集构建为 Dataset 对象,方便构造数据集加载对象。

# 构建数据集
def create_dataset():data = pd.read_csv('data/手机价格预测.csv')# 特征值和目标值x, y = data.iloc[:, :-1], data.iloc[:, -1]x = x.astype(np.float32)y = y.astype(np.int64)# 数据集划分x_train, x_valid, y_train, y_valid=train_test_split(x, y, train_size=0.8, random_state=88, stratify=y)# 构建数据集train_dataset = TensorDataset(torch.from_numpy(x_train.values), torch.tensor(y_train.values))valid_dataset = TensorDataset(torch.from_numpy(x_valid.values), torch.tensor(y_valid.values))return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))train_dataset, valid_dataset, input_dim, class_num = create_dataset()

2. 构建分类网络模型

我们构建的用于手机价格分类的模型叫做全连接神经网络。它主要由三个线性层来构建,在每个线性层后,我们使用的时 sigmoid 激活函数。

# 构建网络模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()self.linear1 = nn.Linear(input_dim, 128)self.linear2 = nn.Linear(128, 256)self.linear3 = nn.Linear(256, output_dim)def _activation(self, x):return torch.sigmoid(x)def forward(self, x):x = self._activation(self.linear1(x))x = self._activation(self.linear2(x))output = self.linear3(x)return output

3. 编写训练函数

网络编写完成之后,我们需要编写训练函数。所谓的训练函数,指的是输入数据读取、送入网络、计算损失、更新参数的流程,该流程较为固定。我们使用的是多分类交叉生损失函数、使用 SGD 优化方法。最终,将训练好的模型持久化到磁盘中。

def train():# 固定随机数种子torch.manual_seed(0)# 初始化模型model = PhonePriceModel(input_dim, class_num)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.SGD(model.parameters(), lr=1e-3)# 训练轮数num_epoch = 50for epoch_idx in range(num_epoch):# 初始化数据加载器dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)# 训练时间start = time.time()# 计算损失total_loss = 0.0total_num = 1# 准确率correct = 0for x, y in dataloader:output = model(x)# 计算损失loss = criterion(output, y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_num += len(y)total_loss += loss.item() * len(y)print('epoch: %4s loss: %.2f, time: %.2fs' %(epoch_idx + 1, total_loss / total_num, time.time() - start))# 模型保存torch.save(model.state_dict(), 'model/phone-price-model.bin')

4. 编写评估函数

评估函数、也叫预测函数、推理函数,主要使用训练好的模型,对未知的样本的进行预测的过程。我们这里使用前面单独划分出来的测试集来进行评估。

def test():# 加载模型model = PhonePriceModel(input_dim, class_num)model.load_state_dict(torch.load('model/phone-price-model.bin'))# 构建加载器dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)# 评估测试集correct = 0for x, y in dataloader:output = model(x)y_pred = torch.argmax(output, dim=1)correct += (y_pred == y).sum()print('Acc: %.5f' % (correct.item() / len(valid_dataset)))

5. 网络性能调优

我们前面的网络模型在测试集的准确率为: 0.54750, 我们可以通过以下方面进行调优:

  1. 对输入数据进行标准化
  2. 调整优化方法
  3. 调整学习率
  4. 增加批量归一化层
  5. 增加网络层数、神经元个数
  6. 增加训练轮数
  7. 等等…

我进行下如下调整:

  1. 优化方法由 SGD 调整为 Adam
  2. 学习率由 1e-3 调整为 1e-4
  3. 对数据数据进行标准化
  4. 增加网络深度, 即: 增加网络参数量

网络模型在测试集的准确率由 0.5475 上升到 0.9625

import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.optim as optim
import numpy as np
import time
from sklearn.preprocessing import StandardScaler# 构建数据集
def create_dataset():data = pd.read_csv('data/手机价格预测.csv')# 特征值和目标值x, y = data.iloc[:, :-1], data.iloc[:, -1]x = x.astype(np.float32)y = y.astype(np.int64)# 数据集划分x_train, x_valid, y_train, y_valid = \train_test_split(x, y, train_size=0.8, random_state=88, stratify=y)# 数据标准化transfer = StandardScaler()x_train = transfer.fit_transform(x_train)x_valid = transfer.transform(x_valid)# 构建数据集train_dataset = TensorDataset(torch.from_numpy(x_train), torch.tensor(y_train.values))valid_dataset = TensorDataset(torch.from_numpy(x_valid), torch.tensor(y_valid.values))return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))train_dataset, valid_dataset, input_dim, class_num = create_dataset()# 构建网络模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()self.linear1 = nn.Linear(input_dim, 128)self.linear2 = nn.Linear(128, 256)self.linear3 = nn.Linear(256, 512)self.linear4 = nn.Linear(512, 128)self.linear5 = nn.Linear(128, output_dim)def _activation(self, x):return torch.sigmoid(x)def forward(self, x):x = self._activation(self.linear1(x))x = self._activation(self.linear2(x))x = self._activation(self.linear3(x))x = self._activation(self.linear4(x))output = self.linear5(x)return output# 编写训练函数
def train():# 固定随机数种子torch.manual_seed(0)# 初始化模型model = PhonePriceModel(input_dim, class_num)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.Adam(model.parameters(), lr=1e-4)# 训练轮数num_epoch = 50for epoch_idx in range(num_epoch):# 初始化数据加载器dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)# 训练时间start = time.time()# 计算损失total_loss = 0.0total_num = 1# 准确率correct = 0for x, y in dataloader:output = model(x)# 计算损失loss = criterion(output, y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_num += len(y)total_loss += loss.item() * len(y)print('epoch: %4s loss: %.2f, time: %.2fs' %(epoch_idx + 1, total_loss / total_num, time.time() - start))# 模型保存torch.save(model.state_dict(), 'model/phone-price-model.bin')def test():# 加载模型model = PhonePriceModel(input_dim, class_num)model.load_state_dict(torch.load('model/phone-price-model.bin'))# 构建加载器dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)# 评估测试集correct = 0for x, y in dataloader:output = model(x)y_pred = torch.argmax(output, dim=1)correct += (y_pred == y).sum()print('Acc: %.5f' % (correct.item() / len(valid_dataset)))if __name__ == '__main__':train()test()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/483461.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习概述,特征工程简述2.1——2.3

机器学习概述: 1.1人工智能概述 达特茅斯会议—人工智能的起点 机器学习是人工智能的一个实现途径 深度学习是机器学习的一个方法发展而来 1.1.2 机器学习和深度学习能做什么 传统预测 图像识别 自然语言处理 1.2什么是机器学习 数据 模型 预测 从历史数…

嵌入式蓝桥杯学习1 点亮LED

cubemx配置 1.新建一个STM32G431RBT6文件 2.在System-Core中点击SYS,找到Debug(设置为Serial Wire) 3.在System-Core中点击RCC,找到High Speed Clock(设置为Crystal/Ceramic Resonator) 4.打开Clock Configuration &#xff0…

【MySql】navicat连接报2013错误

navicat连接mysql报2013错误 报错信息1、检验Mysql数据库是否安装成功2、对Mysql的配置文件进行修改配置2.1、找到配置文件2.2、Linux下修改配置文本 3、连接进入mysql服务4、在mysql下执行授权命令 报错信息 Navicat连接mysql报2013错误 2013-Lost connection to MYSQL serve…

Next.js 路由使用完整指南

Next.js 路由使用指南 目录 基础路由 index 路由页面路由布局路由 嵌套路由 文件夹嵌套共享布局 动态路由 单参数路由多参数路由可选参数 路由组 组织结构共享布局 平行路由 同时渲染条件渲染 拦截路由 模态框照片预览 最佳实践 路由组织性能优化类型安全 1. 基础路由 Nex…

Vue2-从零搭建一个项目(项目基本结构介绍)

目录 一、脚手架安装 二、项目结构介绍 1、项目结构介绍 2、组件关系与脚手架入口内置关系 (1)组件关系 (2)脚手架入口内置关系 一、脚手架安装 在默认安装Node.js的前提下,需要进行两两步操作 直接参照下面的…

Redis 之持久化

目录 介绍 RDB RDB生成方式 自动触发 手动触发 AOF(append-only file) Redis 4.0 混合持久化 Redis主从工作原理 总结 介绍 Redis提供了两个持久化数据的能力,RDB Snapshot 和 AOF(Append Only FIle)…

8. Debian系统中显示屏免密码自动登录

本文介绍如何在Debian系统上,启动后,自动免密登录,不卡在登录界面。 1. 修改lightDM配置文件 嵌入式Debian系统采用lightDM显示管理器,所以,一般需要修改它的配置文件/etc/lightdm/lightdm.conf,找到[Seat…

Linux下,用ufw实现端口关闭、流量控制(二)

本文是 网安小白的端口关闭实践 的续篇。 海量报文,一手掌握,你值得拥有,让我们开始吧~ ufw 与 iptables的关系 理论介绍: ufw(Uncomplicated Firewall)是一个基于iptables的前端工具&#xf…

MySQL常见面试题(二)

MySQL 索引 MySQL 索引相关的问题比较多,对于面试和工作都比较重要,于是,我单独抽了一篇文章专门来总结 MySQL 索引相关的知识点和问题:MySQL 索引详解 。 MySQL 查询缓存 MySQL 查询缓存是查询结果缓存。执行查询语句的时候&a…

红日靶场vulnstark 2靶机的测试报告

目录 一、测试环境 1、系统环境 2、注意事项 3、使用工具/软件 二、测试目的 三、操作过程 1、信息搜集 2、Weblogic漏洞利用 3、Getshell 4、CS上线 5、内网信息收集 利用zerologon漏洞攻击域控服务器(获取密码) 6、横向移动 ①使用PsExec上线域控服务器 ②使用…

用于LiDAR测量的1.58um单芯片MOPA(一)

--翻译自M. Faugeron、M. Krakowski1等人2014年的文章 1.简介 如今,人们对高功率半导体器件的兴趣日益浓厚,这些器件主要用于遥测、激光雷达系统或自由空间通信等应用。与固态激光器相比,半导体器件更紧凑且功耗更低,这在低功率供…

MFC工控项目实例三十五读取数据库数据

点击按钮打开文件夹中的数据文件生成曲线 相关代码 void CSEAL_PRESSUREDlg::OnTesReport() {CFileDialog dlgOpen(TRUE/*TRUE打开,FALSE保存*/,0,0,OFN_NOCHANGEDIR|OFN_FILEMUSTEXIST,"All Files(mdb.*)|*.*||",//文件过滤器NULL);CString mdb_1, m…

Harnessing Large Language Models for Training-free Video Anomaly Detection

标题:利用大型语言模型实现无训练的视频异常检测 原文链接:https://openaccess.thecvf.com/content/CVPR2024/papers/Zanella_Harnessing_Large_Language_Models_for_Training-free_Video_Anomaly_Detection_CVPR_2024_paper.pdf 源码链接:ht…

Linux笔试题(自己整理,已做完,选择题)

详细Linux内容查看:day04【入门】Linux系统操作-CSDN博客 1、部分笔试题 本文的笔试题,主要是为了复习学习的day04【入门】Linux系统操作-CSDN博客的相关知识点。后续还会更新一些面试相关的题目。 欢迎一起学习

BA是什么?

1.BA的定义 BA的中文译为“光束法平差”,也有翻译为“束调整”、“捆绑调整”等,是一种用于计算机视觉和机器人领域的优化技术,主要用于精确优化相机参数(包括内参数和外参数)和三维空间中特征点的位置。BA的目标是通过最小化重投影误差来提高三维重建的精度和一致性。重投影误…

Windows系统搭建Docker

Windows系统搭建Docker 一、系统虚拟化1.1启用虚拟化1.2启用Hyper-v并开启虚拟任务 二、安装WSL2.1 检验安装2.2 命令安装WSL(与2.3选其一)2.3 手动安装WSL(与2.2选其一)2.4 将 WSL 2 设置为默认版本 三、docker安装 一、系统虚拟…

洛谷二刷P4715 【深基16.例1】淘汰赛(c嘎嘎)

题目链接:P4715 【深基16.例1】淘汰赛 - 洛谷 | 计算机科学教育新生态 题目难度:普及 刷题心得:本题是我二刷,之前第一次刷是在洛谷线性表那个题单,当时印象深刻第 一篇题解是用的树来做,当时我不屑一顾&…

基于Matlab BP神经网络的电力负荷预测模型研究与实现

随着电力系统的复杂性和规模的不断增长,准确的电力负荷预测对于电网的稳定性和运行效率至关重要。传统的负荷预测方法依赖于历史数据和简单的统计模型,但这些方法在处理非线性和动态变化的负荷数据时,表现出较大的局限性。近年来,…

非标自动化行业ERP选型与案例展示!

非标自动化行业,那么使用的就是非标设备,什么是非标设备呢?用一句话来说明就是指设计制造方面没有形成国家标准的设备。 在如今追求高效的社会,各行各业都朝着提高效率精益工艺,缩减流程,调整业务,用各种…

十、软件设计架构-微服务-服务调用Dubbo

文章目录 前言一、Dubbo介绍1. 什么是Dubbo 二、实现1. 提供统一业务api2. 提供服务提供者3. 提供服务消费者 前言 服务调用方案--Dubbo‌ 基于 Java 的高性能 RPC分布式服务框架,致力于提供高性能和透明化的RPC远程服务调用方案,以及SOA服务治理方案。…