【chatgpt】train_split_test的random_state

在使用train_test_split函数划分数据集时,random_state参数用于控制随机数生成器的种子,以确保划分结果的可重复性。这样,无论你运行多少次代码,只要使用相同的random_state值,得到的训练集和测试集划分就会是一样的。

使用 train_test_split 示例

以下是一个示例,展示如何使用train_test_split函数进行数据集划分,并设置random_state参数:
程序输出结果
Training set shape: (80, 10), (80,)
Test set shape: (20, 10), (20,)

import numpy as np
from sklearn.model_selection import train_test_split# 假设我们有一些数据
X = np.random.rand(100, 10)  # 100个样本,每个样本10个特征
y = np.random.randint(0, 2, 100)  # 100个样本的标签(0或1)# 使用train_test_split进行数据集划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 打印划分后的数据集形状
print(f'Training set shape: {X_train.shape}, {y_train.shape}')
print(f'Test set shape: {X_test.shape}, {y_test.shape}')

在这个示例中:

  • X 是特征矩阵,包含100个样本,每个样本有10个特征。
  • y 是标签数组,包含100个样本的标签。
  • test_size=0.2 表示将数据集的20%用作测试集,剩下的80%用作训练集。
  • random_state=42 用于确保划分的可重复性。

为什么使用 random_state

使用 random_state 可以确保在多次运行代码时,得到的训练集和测试集划分是一致的,这在以下情况下特别有用:

  1. 调试和开发: 在开发和调试模型时,使用相同的 random_state 可以确保数据划分的一致性,从而使得调试更加容易。
  2. 实验的可重复性: 在进行实验时,使用相同的 random_state 可以确保实验结果的可重复性,使得其他人可以验证你的结果。
  3. 比较模型性能: 在比较不同模型的性能时,使用相同的 random_state 可以确保每个模型都使用相同的训练集和测试集,从而使比较更加公平。
    在这里插入图片描述

示例:比较大数据集和小数据集的模型性能

假设我们有一个大数据集和一个小数据集,我们想要比较它们在同一模型上的性能。我们可以使用 train_test_split 进行数据集划分,并设置 random_state 以确保划分的可重复性。

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import numpy as np
# 创建大数据集和小数据集
X_large = np.random.rand(1000, 10)
y_large = np.random.rand(1000, 1)X_small = np.random.rand(100, 10)
y_small = np.random.rand(100, 1)# 使用train_test_split进行数据集划分
X_train_large, X_test_large, y_train_large, y_test_large = train_test_split(X_large, y_large, test_size=0.2, random_state=42)
X_train_small, X_test_small, y_train_small, y_test_small = train_test_split(X_small, y_small, test_size=0.2, random_state=42)# 转换为张量
X_train_large = torch.tensor(X_train_large, dtype=torch.float32)
y_train_large = torch.tensor(y_train_large, dtype=torch.float32)
X_test_large = torch.tensor(X_test_large, dtype=torch.float32)
y_test_large = torch.tensor(y_test_large, dtype=torch.float32)X_train_small = torch.tensor(X_train_small, dtype=torch.float32)
y_train_small = torch.tensor(y_train_small, dtype=torch.float32)
X_test_small = torch.tensor(X_test_small, dtype=torch.float32)
y_test_small = torch.tensor(y_test_small, dtype=torch.float32)# 创建数据加载器
train_loader_large = DataLoader(TensorDataset(X_train_large, y_train_large), batch_size=32, shuffle=True)
test_loader_large = DataLoader(TensorDataset(X_test_large, y_test_large), batch_size=32, shuffle=False)train_loader_small = DataLoader(TensorDataset(X_train_small, y_train_small), batch_size=32, shuffle=True)
test_loader_small = DataLoader(TensorDataset(X_test_small, y_test_small), batch_size=32, shuffle=False)# 定义简单的线性模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.linear = nn.Linear(10, 1)def forward(self, x):return self.linear(x)# 训练模型的通用函数
def train_model(train_loader, num_epochs=50, learning_rate=0.01):model = SimpleModel()criterion = nn.MSELoss()optimizer = optim.SGD(model.parameters(), lr=learning_rate)train_losses = []for epoch in range(num_epochs):model.train()epoch_train_loss = 0.0for batch_x, batch_y in train_loader:outputs = model(batch_x)loss = criterion(outputs, batch_y)optimizer.zero_grad()loss.backward()optimizer.step()epoch_train_loss += loss.item()epoch_train_loss /= len(train_loader)train_losses.append(epoch_train_loss)print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_train_loss:.4f}')return model, train_losses# 训练大数据集的模型
print("Training on large dataset")
model_large, train_losses_large = train_model(train_loader_large)# 训练小数据集的模型
print("\nTraining on small dataset")
model_small, train_losses_small = train_model(train_loader_small)# 绘制训练损失曲线
plt.figure(figsize=(12, 6))
plt.plot(range(1, len(train_losses_large) + 1), train_losses_large, label='Large Dataset Train Loss')
plt.plot(range(1, len(train_losses_small) + 1), train_losses_small, label='Small Dataset Train Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training Loss Comparison')
plt.savefig("test")# 在测试集上计算最终的评估指标(例如均方误差)
def evaluate_model(model, test_loader):model.eval()test_loss = 0.0criterion = nn.MSELoss()with torch.no_grad():for batch_x, batch_y in test_loader:outputs = model(batch_x)loss = criterion(outputs, batch_y)test_loss += loss.item()test_loss /= len(test_loader)return test_loss# 评估大数据集的模型
final_test_loss_large = evaluate_model(model_large, test_loader_large)# 评估小数据集的模型
final_test_loss_small = evaluate_model(model_small, test_loader_small)print(f'Final Test Loss on Large Dataset: {final_test_loss_large:.4f}')
print(f'Final Test Loss on Small Dataset: {final_test_loss_small:.4f}')

结果分析

通过上述代码,可以得到大数据集和小数据集在训练过程中的损失曲线以及最终的测试损失。根据这些信息,可以比较它们的收敛情况和性能。

  • 损失曲线: 通过观察损失曲线,判断模型在两个数据集上的收敛速度和稳定性。如果两者曲线形状相似,并且在同一水平上趋于平稳,可以认为它们收敛到了相似的程度。

  • 最终测试损失: 最终测试损失值可以用于直接比较两个模型的性能。如果两者最终测试损失值接近,则可以认为它们的模型性能相当。

通过使用相同的 random_state 值,确保数据集划分的一致性,从而使得比较结果更加公平和具有可重复性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/357702.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Git】win本地 git bash:Connect reset by 20.205.243.166 port22报错问题解决

win10 git bash 控制台 reset 22端口拒绝连接问题: Connection reset by 20.205.243.166 port 221、22端口 无法连接 ssh -T gitgithub.com2、尝试用443端口 仍然无法连接 ssh -T -P 443 gitgithub.com3、重写 git clone 地址 url,全局添加 https 前缀…

【jenkins1】gitlab与jenkins集成

文章目录 1.Jenkins-docker配置:运行在8080端口上,机器只要安装docker就能装载image并运行容器2.Jenkins与GitLab配置:docker ps查看正在运行,浏览器访问http://10....:8080/2.1 GitLab与Jenkins的Access Token配置:不…

如何关闭软件开机自启,提升电脑开机速度?

如何关闭软件开机自启,提升电脑开机速度?大家知道,很多软件在安装时默认都会设置为开机自动启动。但是,有很多软件在我们开机之后并不是马上需要用到的,开机启动的软件过多会导致电脑开机变慢。那么,如何关…

Cesium如何高性能的实现上万条道路的流光穿梭效果

大家好,我是日拱一卒的攻城师不浪,专注可视化、数字孪生、前端、nodejs、AI学习、GIS等学习沉淀,这是2024年输出的第20/100篇文章; 前言 在智慧城市的项目中,经常会碰到这样一个需求:领导要求将全市的道路…

SCI一区级 | Matlab实现BO-Transformer-LSTM多变量时间序列预测

SCI一区级 | Matlab实现BO-Transformer-LSTM多变量时间序列预测 目录 SCI一区级 | Matlab实现BO-Transformer-LSTM多变量时间序列预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.【SCI一区级】Matlab实现BO-Transformer-LSTM多变量时间序列预测,贝叶斯…

分布式,容错:10台电脑坏了2台

由10台电脑组成的分布式系统,随机、任意坏了2台,剩下的8台电脑仍然储存着全部信息,可以继续服务。这是怎么做到的? 设N台电脑,坏了H台,要保证上述性质,需要有冗余,总的存储量降低为…

【Flink metric】Flink指标系统的系统性知识:以便我们实现特性化数据的指标监控与分析

文章目录 一. Registering metrics:向flink注册新自己的metrics1. 注册metrics2. Metric types:指标类型2.1. Counter2.2. Gauge2.3. Histogram(ing)4. Meter 二. Scope:指标作用域1. User Scope2. System Scope ing3. User Variables 三. Reporter ing四. System m…

华为数通——单臂路由

单臂路由:指在三层设备路由器的一个接口上通过配置子接口(或“逻辑接口”,并不存在真正物理接口)的方式,实现原来相互隔离的不同VLAN(虚拟局域网)之间的互联互通。但是仅仅允许单播通信。 单臂路…

Web3新视野:Lumoz节点的潜力与收益解读

摘要:低估值、高回报、无条件退款80%...... Lumoz正通过其 zkVerifier 节点销售活动,引领一场ZK计算革命。 长期以来,加密市场以其独特的波动性和增长潜力,持续吸引着全球投资者的目光。而历史数据表明,市场往往在一年…

示例:WPF中应用DependencyPropertyDescriptor监视依赖属性值的改变

一、目的:开发过程中,经常碰到使用别人的控件时有些属性改变没有对应的事件抛出,从而无法做处理。比如TextBlock当修改了IsEnabled属性我们可以用IsEnabledChanged事件去做对应的逻辑处理,那么如果有类似Background属性改变我想找…

vscode用vue框架2,续写登陆页面逻辑,以及首页框架的搭建

目录 前言: 一、实现登录页信息验证逻辑 1.实现登录数据双向绑定 2.验证用户输入数据是否和默认数据相同 补充知识1: 知识点补充2: 二、首页和登录页之间的逻辑(1) 1. 修改路由,使得程序被访问先访问首页 知识点补充3&am…

一、系统学习微服务遇到的问题集合

1、启动了nacos服务&#xff0c;没有在注册列表 应该是版本问题 Alibaba-nacos版本 nacos-文档 Spring Cloud Alibaba-中文 Spring-Cloud-Alibaba-英文 Spring-Cloud-Gateway 写的很好的一篇文章 在Spring initial上面配置 start.aliyun.com 重新下载 < 2、 No Feign…

【Java】Java基础语法

一、注释详解 1.1 注释的语法&#xff1a; // 单行注释/*多行注释 *//**文档注释 */ 1.2 注释的特点&#xff1a; 注释不影响程序的执行&#xff0c;在Javac命令进行编译后会将注释去掉 1.3 注释的快捷键 二、字面量详解 2.1 字面量的概念&#xff1a; 计算机是用来处理…

西木科技Westwood-Robotics人型机器人Bruce配置和真机配置

西木科技Westwood-Robotics人型机器人Bruce配置和真机配置 本文内容机器人介绍Bruce机器人Gazebo中仿真代码部署Bruce真机代码部署 本文内容 人形机器人Brcue相关介绍docker中安装Gazebo并使用Bruce机器人控制器更换环境配置 机器人介绍 公司&#xff1a;西木科技Westwood-R…

「动态规划」如何求环绕字符串中唯一的子字符串个数?

467. 环绕字符串中唯一的子字符串https://leetcode.cn/problems/unique-substrings-in-wraparound-string/description/ 定义字符串base为一个"abcdefghijklmnopqrstuvwxyz"无限环绕的字符串&#xff0c;所以base看起来是这样的&#xff1a;"...zabcdefghijklm…

服务器数据恢复—raid5热备盘同步失败导致阵列崩溃如何恢复数据?

服务器存储数据恢复环境&故障&#xff1a; 某品牌DS5300存储&#xff0c;包含一个存储机头和多个磁盘柜&#xff0c;组建了多组RAID5磁盘阵列。 某个磁盘柜中的一组RAID5阵列由15块数据盘和1块热备硬盘组建。该磁盘柜中的某块硬盘离线&#xff0c;热备盘自动替换并开始同步…

Linux 7种 进程间通信方式

传统进程间通信 通过文件实现进程间通信 必须人为保证先后顺序 A--->硬盘---> B&#xff08;B不知道A什么时候把内容传到硬盘中&#xff09; 1.无名管道 2.有名管道 3.信号 IPC进程间通信 4.消息队列 5.共享内存 6.信号灯集 7.socket通信 一、无名管道&a…

【c2】编译预处理,gdb,makefile,文件,多线程,动静态库

文章目录 1.编译预处理&#xff1a;C源程序 - 编译预处理【#开头指令和特殊符号进行处理&#xff0c;删除程序中注释和多余空白行】- 编译2.gdb调试&#xff1a;多进/线程中无法用3.makefile文件&#xff1a;make是一个解释makefile中指令的命令工具4.文件&#xff1a;fprint/f…

Flink Sql Redis Connector

经常做开发的小伙伴肯定知道用flink连接redis的时候比较麻烦&#xff0c;更麻烦的是解析redis数据&#xff0c;如果rdis可以普通数据库那样用flink sql连接并且数据可以像表格那样展示出来就会非常方便。 历时多天&#xff0c;我终于把flink sql redis connector写出来了&…

预备资金有5000-6000买什么电脑比较好?大学生电脑选购指南

小新pro14 2024 处理器&#xff1a;采用了英特尔酷睿Ultra5 125H或Ultra9 185H两种处理器可选&#xff0c;这是英特尔最新的高性能低功耗处理器&#xff0c;具有18个线程&#xff0c;最高可达4.5GHz的加速频率&#xff0c;支持PCIe 4.0接口&#xff0c;内置了强大的ARC核芯显卡…