价格分类(神经网络)

# 1.导入依赖包
import timeimport torch
import torch.nn as nn
import torch.optim as optimfrom torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_splitimport numpy as np
import pandas as pd
import matplotlib.pyplot as pltfrom torchsummary import summary# 2.构建数据集
def create_dataset():# 2.1 读取数据集data = pd.read_csv('dataset/手机价格预测.csv')# 2.2 获取特征值和目标值,类型转化  特征(Float)  标签(Long)x, y = data.iloc[:, :-1], data.iloc[:, -1]x, y = x.astype(np.float32), y.astype(np.int64)# 2.3 数据集划分x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2,random_state=2)# 2.4 数据转Tensortrain_dataset = TensorDataset(torch.from_numpy(x_train.values), torch.tensor(y_train.values))test_dataset = TensorDataset(torch.from_numpy(x_test.values), torch.tensor(y_test.values))return train_dataset, test_dataset, x_train.shape[1], len(np.unique(y))# 3. 构建模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()self.linear1 = nn.Linear(input_dim, 256)self.linear2 = nn.Linear(256, 1024)self.fc = nn.Linear(1024, output_dim)def forward(self, x):x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))output = self.fc(x)# output = torch.softmax(self.fc(x), dim=-1)return output# 4.模型训练(225)
def train(model, train_dataset, num_epochs, batch_size):# 2 初始化参数  损失函数  优化器loss1 = nn.CrossEntropyLoss()# optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.99, 0.99))start = time.time()# 2 2个遍历  epoch  dataloaderfor epoch in range(num_epochs):dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)total_num = 0total_loss = 0.0for x, y in dataloader:# 5 前向传播  损失计算 梯度归零  反向传播 参数更新output = model(x)loss = loss1(output, y)optimizer.zero_grad()loss.backward()optimizer.step()total_num += 1  # 批次total_loss += loss.item()epoch += 1print(f'epoch:{epoch + 1:4d},loss:{total_loss / (total_num * epoch):.4f}, time:{time.time() - start:.2f}s')# 模型持久化torch.save(model.state_dict(), 'model/phone2.pth')# 5.模型预测评估
def test(model, test_dataset, input_dim, output_dim):# 3.导入数据dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)correct = 0# 4.遍历数据for x, y in dataloader:# 4.1 前向传播output = model(x)print(output)# 4.2 获取输出结果(类别)y_pred = torch.argmax(output, dim=1)# print(y_pred)  # 预测错误# 4.3 计算准确率Acccorrect += (y_pred == y).sum()print(correct.item())Acc = correct.item() / len(test_dataset)return Accif __name__ == '__main__':train_dataset, test_dataset, feature_num, label_num = create_dataset()# 1.实例化模型model = PhonePriceModel(feature_num, label_num)# 2.加载模型model.load_state_dict(torch.load('model/phone2.pth'))# 模型训练# train(model, train_dataset, num_epochs=50, batch_size=8)# 模型预测Acc = test(model, test_dataset, feature_num, label_num)print(f'Acc:{Acc:.5f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/478700.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

周志华深度森林deep forest(deep-forest)最新可安装教程,仅需在pycharm中完成,超简单安装教程

1、打开pycharm 没有pycharm的,在站内搜索安装教程即可。 2、点击“文件”“新建项目” 3、创建项目,Python版本中选择Python39。如果没有该版本,选择下面的Python 3.9下载并安装。 4、打开软件包,搜索“deep-forest”软件包&am…

ES 和Kibana-v2 带用户登录验证

1. 前言 ElasticSearch、可视化操作工具Kibana。如果你是Linux centos系统的话,下面的指令可以一路CV完成服务的部署。 2. 服务搭建 2.1. 部署ElasticSearch 拉取docker镜像 docker pull elasticsearch:7.17.21 创建挂载卷目录 mkdir /**/es-data -p mkdir /**/…

分布式kettle调度平台v6.4.0新功能介绍

介绍 Kettle(也称为Pentaho Data Integration)是一款开源的ETL(Extract, Transform, Load)工具,由Pentaho(现为Hitachi Vantara)开发和维护。它提供了一套强大的数据集成和转换功能&#xff0c…

力扣hot100-->排序

排序 1. 56. 合并区间 中等 以数组 intervals 表示若干个区间的集合,其中单个区间为 intervals[i] [starti, endi] 。请你合并所有重叠的区间,并返回 一个不重叠的区间数组,该数组需恰好覆盖输入中的所有区间 。 示例 1: 输…

.net 8使用hangfire实现库存同步任务

C# 使用HangFire 第一章:.net Framework 4.6 WebAPI 使用Hangfire 第二章:net 8使用hangfire实现库存同步任务 文章目录 C# 使用HangFire前言项目源码一、项目架构二、项目服务介绍HangFire服务结构解析HangfireCollectionExtensions 类ModelHangfireSettingsHttpAuthInfoUs…

滑动窗口最大值(java)

题目描述 给你一个整数数组 nums,有一个大小为 k 的滑动窗口从数组的最左侧移动到数组的最右侧。你只可以看到在滑动窗口内的 k 个数字。滑动窗口每次只向右移动一位。 返回 滑动窗口中的最大值 。 示例 1: 输入:nums [1,3,-1,-3,5,3,6,7]…

springboot项目使用maven打包,第三方jar问题

springboot项目使用maven package打包为可执行jar后,第三方jar会被打包进去吗? 答案是肯定的。做了实验如下: 第三方jar的项目结构及jar包结构如下:(该第三方jar采用的是maven工程,打包为普通jar&#xf…

常用Rust日志处理工具教程

在本文中,我想讨论Rust中的日志。通过一些背景信息,我将带您了解两个日志库:env_logger和log4rs。最后,我将分享我的建议和github的片段。 Rust log介绍 log包是Rust中日志API的事实标准,共有五个日志级别&#xff1…

嵌入式的C/C++:深入理解 static、const 与 volatile 的用法与特点

目录 一、static 1、static 修饰局部变量 2、 static 修饰全局变量 3、static 修饰函数 4、static 修饰类成员 5、小结 二、const 1、const 修饰普通变量 2、const 修饰指针 3、const 修饰函数参数 4. const 修饰函数返回值 5. const 修饰类成员 6. const 与 #defi…

时间请求参数、响应

(7)时间请求参数 1.默认格式转换 控制器 RequestMapping("/commonDate") ResponseBody public String commonDate(Date date){System.out.println("默认格式时间参数 date > "date);return "{module : commonDate}"; }…

SpringBoot(9)-Dubbo+Zookeeper

目录 一、了解分布式系统 二、RPC 三、Dubbo 四、SpringBootDubboZookeeper 4.1 框架搭建 4.2 实现RPC 一、了解分布式系统 分布式系统:由一组通过网络进行通信,为了完成共同的任务而协调工作的计算机节点组成的系统 二、RPC RPC:远程…

单片机学习笔记 8. 矩阵键盘按键检测

更多单片机学习笔记:单片机学习笔记 1. 点亮一个LED灯单片机学习笔记 2. LED灯闪烁单片机学习笔记 3. LED灯流水灯单片机学习笔记 4. 蜂鸣器滴~滴~滴~单片机学习笔记 5. 数码管静态显示单片机学习笔记 6. 数码管动态显示单片机学习笔记 7. 独立键盘 目录 0、实现的…

道品智能科技移动式水肥一体机:农业灌溉施肥的革新之选

在现代农业的发展进程中,科技的力量正日益凸显。其中,移动式水肥一体机以其独特的可移动性、智能化以及实现水肥一体化的卓越性能,成为了农业领域的一颗璀璨新星。它不仅改变了传统的农业灌溉施肥方式,更为农业生产带来了高效、精…

android 音效可视化--Visualizer

Visualizer 是使应用程序能够检索当前播放音频的一部分以进行可视化。它不是录音接口,仅返回部分低质量的音频内容。但是,为了保护某些音频数据的隐私,使用 Visualizer 需要 android.permission.RECORD_AUDIO权限。传递给构造函数的音频会话 …

计算机网络八股整理(一)

计算机网络八股文整理 一:网络模型 1:网络osi模型和tcp/ip模型分别介绍一下 osi模型是国际标准的网络模型,它由七层组成,从上到下分别是:应用层,表示层,会话层,传输层,…

利用Python爬虫获得1688按关键字搜索商品:技术解析

在电商领域,1688作为中国领先的B2B电商平台,其商品搜索功能对于商家来说具有极高的价值。通过获取搜索结果,商家可以更好地了解市场趋势,优化产品标题,提高搜索排名。本文将介绍如何使用Python编写爬虫,以获…

Spring Boot集成MyBatis-Plus:自定义拦截器实现动态表名切换

Spring Boot集成MyBatis-Plus:自定义拦截器实现动态表名切换 一、引言 介绍动态表名的场景需求,比如多租户系统、分表分库,或者不同业务模块共用一套代码但操作不同表。说明 MyBatis-Plus 默认绑定固定表名的问题。 二、项目配置 1. 集成 M…

(原创)Android Studio新老界面UI切换及老版本下载地址

前言 这两天下载了一个新版的Android Studio,发现整个界面都发生了很大改动: 新的界面的一些设置可参考一些博客: Android Studio新版UI常用设置 但是对于一些急着开发的小伙伴来说,没有时间去适应,那么怎么办呢&am…

数据新时代:如何选择现代数据治理平台(上)

谈现代数据治理系统的十大架构特征 最近一位老友找到我,咨询他的数据治理平台到底该不该换,背景是这样的:若干年前采购了一个市场主流的数据治理平台,功能大概就是数据治理三件套——标准、元数据和质量等经典数据治理的功能。现…

抖音SEO矩阵系统:开发技术分享

市场环境剖析 短视频SEO矩阵系统是一种策略,旨在通过不同平台上的多个账号建立联系,整合同一品牌下的各平台粉丝流量。该系统通过遵循每个平台的规则和内容要求,输出企业和品牌形象,以矩阵形式增强粉丝基础并提升商业价值。抖音作…