搭建全连接网络进行分类(糖尿病为例)

拿来练手,大神请绕道。

1.网上的代码大多都写在一个函数里,但是其实很多好论文都是把网络,数据训练等分开写的。

2.分开写就是有一个需要注意的事情,就是要import 要用到的文件中的模型或者变量等。

3.全连接的回归也写了,有空再上传吧。

4.一般都是先写data或者model

import torch
import torch.nn as nn
import torch.nn.functional as F
#nn.func这个里面很多功能其实nn里就有,可以不导入,而且后面新的版本的torch也取消了cc.functional里面的部分函数#定义网络,需要定义两部分,一部分就是初始化,另一部分就是数据流
class FCNet(nn.Module):def __init__(self):super(FCNet,self).__init__()self.fc1 = nn.Linear(8,16)#初始的这个8,要和你的数据的特征数一样才行,后面的数可以随意设置,但是不要太多,容易过拟合# self.fc2 = nn.Linear(50,20)self.fc3 = nn.Linear(16,2)#二分类,输出2,其实1也可以的#最后的就是分类数,因为用的sigmod和交叉熵损失,就不用额外加softmax了,多分类要用softmaxself.sig = nn.Sigmoid()# self.drop = nn.Dropout(0.3)#可以把用到的放在这里,也可以用nn.Sequential()放在一起,这样后面的话就可以直接用这个,不用写那么多了def forward(self,x):x = self.sig(self.fc1(x))# x = self.sig(self.fc2(x))x = self.sig(self.fc3(x))return x#就是x要怎么在网络中走,要写一遍#可以自己输出测试一下看看网络是不是自己想的那样,在真的调用的时候再屏蔽掉
# net= FCNet()
# print(net)

首先看看数据是是啥样,outcome就是有没有糖尿病

其实可以手动把csv分成train和test

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
#导入pands是为了读数据,当然使用numpy也可以读得,sklearn是为了把训练数据分为训练和验证集data = pd.read_csv('./train.csv')
#就是把对应的数据哪出来,x代表的是feature上的data,y代表的是label,因为pd可以读到最上面的标签,所以从第2行(i=1)开始读就行
x = data.iloc[1:,:-1]
y = data.iloc[1:,[-1]]
#可以输出看看数据对不对,x中不应该包含labels
# print(x)
# print(y)
#test_size就是划分的比例,后面的是种子,意思是每次运行这个函数时候,0.8就是那些,0.2也还是每次一样,如果想要不一样,只要每次运行这个函数时候换个值就行
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=0)
#print(x_train,y_test)
# print(x_test,y_test)
#给数据进行归一化,可以用很多方法,我用最简单的归一到-1到1
x_train = x_train.apply(lambda x: (x - x.mean()) / (x.std()))
x_test = x_test.apply(lambda x: (x - x.mean()) / (x.std()))#写dataset可以用两种方法,第一种就是 每一个数据自己单独处理,第二个就是要自己重写dataset类
#1.
# 可以使用分别的处理,把数据(首先转换为tensor,或者把dataframe.valus拿出来才能转换为tensor)转换为tensor并且数据类型转换为float32,如果测试没有真值,需要单独转换
# x_train = torch.tensor(np.array(x_train),dtype=torch.float32)
# y_train = torch.tensor(np.array(y_train),dtype=torch.float32)
# x_test = torch.tensor(np.array(x_test),dtype=torch.float32)
# y_test = torch.tensor(np.array(x_test),dtype=torch.float32)
# train_dataset = torch.utils.data.TensorDataset(x_train,y_train)
# test_dataset = torch.utils.data.TensorDataset(x_test,y_test)#2.也可以直接重写datasetclass dataset(Dataset):def __init__(self, x, y):#把值拿出来或者变为np类型才能转换为tensor# self.data = torch.tensor(x.values,dtype=torch.float32)# self.labels = torch.tensor(y.values,dtype=torch.float32)self.data = torch.tensor(np.array(x),dtype=torch.float32)self.labels = torch.tensor(np.array(y),dtype=torch.float32)def __len__(self):return len(self.data)def __getitem__(self,idx):return self.data[idx],self.labels[idx]#应该返回的是list类型,不是字典也不是setBATCH_SIZE = 64#验证集一般不用shuffle
train_dataset = dataset(x_train,y_train)
test_dataset = dataset(x_test,y_test)
# print(train_dataset)
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)
test_lodaer = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False)
# print(train_loader)

然后就可以写train或者test了,其实test和train一样

from Model import FCNet
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import data
#导入要调用的net和data,也可以from data import xxx 这样可以直接用xxx,现在的这个需要用data.xxx#看自己的设备,最好用gpu来跑
if (torch.cuda.is_available()):my_device = torch.device('cuda')
else:my_device = torch.device('cpu')print(my_device)
#实例化一个net,并且放到gpu上,需要放到gpu上的有inputs,labels,net,loss
net = FCNet().to(my_device)
# print(net)
#定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
#一开始是不需要weight_decay(也就是l2正则化),可以等出现过拟合在用,也可以先用上
optimizer = optim.Adam(net.parameters(),lr=0.001,weight_decay=0.01)epochs = 600
#定义train,因为一边训练一边验证,所有就把两个loader都放进去了,不过写法很多,也可以不放dataloader,放epoches也可以
def train(dataloader,valloader):losses = []acces = []losses_val = []for epoch in range(epochs):loss_batch = 0for i,data in enumerate(dataloader):#需要注意的,这里的inputs和labels和之前定义的dataset相关,需要是list类型才可以inputs,labels = data#print(data)可以打印出来查看一下inputs,labels = inputs.to(my_device),labels.to(my_device)optimizer.zero_grad()#每次要梯度清零outputs = net(inputs)#print(outputs)#model的最后一层是sigmod#labels的格式需要注意,因为现在是[[1],[0],[1],[1]..]这样得格式,无法放到交叉熵了,需要时[0,1,1,1...]这样得格式才行loss = criterion(outputs,labels.squeeze(1).long()).to(my_device)#print(labels.squeeze(1).long())loss.backward()optimizer.step()loss_batch += loss.item()length = i#验证的时候不用反向传播和梯度下降这些net.eval()count = 0right = 0loss_batch_val =0with torch.no_grad():for j,data2 in enumerate(valloader):val_inputs,val_labels = data2val_inputs,val_labels = val_inputs.to(my_device),val_labels.squeeze(1).long().to(my_device)val_outputs = net(val_inputs)loss_val = criterion(val_outputs,val_labels)#因为net的最后一层是2,所以输出的是2维的【0.6,0.4】这种,但是这个可以直接放到交叉熵中#——中放的是概率,pred中放的是预测的类别,算损失还是要用outputs,但是算准确率就是用pred和真实labels相比了_,pred = torch.max(val_outputs,1)#print(pred)right = (pred == val_labels).sum().item()count = len(val_labels)acc = right/countloss_batch_val += loss_val.item()length2 = jif epoch % 10 == 9:print('train_epoch:',epoch+1,'train_loss:',loss_batch/length,'val_loss:',loss_batch_val/length2,'acc:',acc)losses.append(loss_batch/length)acces.append(acc)losses_val.append(loss_batch_val/length2)#可以画一些曲线,输出一些值plt.plot(range(60),losses,color ='blue',label ='train_loss')plt.plot(range(60),acces, color ='red',label ='val_acc')plt.plot(range(60),losses_val,color ='yellow',label ='val_loss')plt.legend()plt.show()torch.save(net.state_dict(),'./weights_epoch1000.pth')#保存参数train(data.train_loader,data.test_lodaer)

最后看一下结果,最后的准确率在85%左右,还可以,毕竟数据不多,也是简单的全连接。

在这个结果之前出现了很多问题,比如波动很大,损失先降后升等问题,找个有问题的图

下面是一些总结:

1.跳跃很大,波动:增大batch_size,减小lr。

2.降低过拟合:

        a.降低模型的复杂程度,但是修改具体的神经元个数,因为这个网络本身就不大,所有没啥用,模型非常大没准会有用。

        b.batchsize增大,lr减小是有效的。

        c.输入数据进行归一化是有用的,归一化之后lr可以调大一点,收敛变快了。

        d.L2正则化是有用的,很有用。dropout应该也有用,但是模型本来就很小,我试了试没啥差别。而且有正则化之后可以加速收敛,lr可以稍微调大一点,较少的epoches也可以收敛了,而已acc也会更高一点,稳定一点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/146701.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Flink CDC MySQL同步MySQL错误记录

1、启动 Flink SQL [appuserwhtpjfscpt01 flink-1.17.1]$ bin/sql-client.sh2、新建源表 问题1:Encountered “(” 处理方法:去掉int(11),改为int Flink SQL> CREATE TABLE t_user ( > uid int(11) NOT NULL AUTO_INCREMENT COMME…

3D WEB轻量化引擎HOOPS助力3D测量应用蓬勃发展:效率、精度显著提升

在3D开发工具领域,Tech Soft 3D打造的HOOPS SDK已经崭露头角,成为了全球领先的3D领域开发工具提供商。HOOPS SDK包括四种不同的3D软件开发工具,已成为行业的翘楚。 其中,HOOPS Exchange以其CAD数据转换的能力脱颖而出&#xff0c…

最新AI智能问答系统源码/AI绘画系统源码/支持GPT联网提问/Prompt应用+支持国内AI提问模型

一、AI创作系统 SparkAi创作系统是基于国外很火的ChatGPT进行开发的AI智能问答系统和AI绘画系统。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI创作ChatGPT?小编这里写一个详细图…

讲讲项目里的仪表盘编辑器(三)布局组件

布局容器处理 看完前面两章的讲解,我们对仪表盘系统有了一个大概的理解。接着我们讲讲更深入的应用。 上文讲解的编辑器只是局限于平铺的组件集。而在编辑器中,还会有一种组件是布局容器。它允许其他组件拖拽进入在里面形成自己的一套布局。典型的有分页…

【Linux】线程概念

🔥🔥 欢迎来到小林的博客!!       🛰️博客主页:✈️林 子       🛰️博客专栏:✈️ Linux       🛰️社区 :✈️ 进步学堂       &#x1f6f0…

3.物联网射频识别,(高频)RFID应用ISO14443-2协议,(校园卡)Mifare S50卡

一。ISO14443-2协议简介 1.ISO14443协议组成及部分缩略语 (1)14443协议组成(下面的协议简介会详细介绍) 14443-1 物理特性 14443-2 射频功率和信号接口 14443-3 初始化和防冲突 (分为Type A、Type B两种接口&…

【嵌入式】使用MultiButton开源库驱动按键并控制多级界面切换

目录 一 背景说明 二 参考资料 三 MultiButton开源库移植 四 设计实现--驱动按键 五 设计实现--界面处理 一 背景说明 需要做一个通过不同按键控制多级界面切换以及界面动作的程序。 查阅相关资料,发现网上大多数的应用都比较繁琐,且对于多级界面的…

深眸科技基于AI机器视觉实现应用部署,构建铝箔异物检测解决方案

异物的定义指的是影响到产品的外观质量或使用性能的外来或产品内部的物质,其产生的原因有很多种,包括在产品生产使用过程中的污染、腐蚀、氧化,以及由于生产工业控制不规范或人为疏忽等。而异物的产生,是导致产品的不良率增加的根…

ChatGPT必应联网功能正式上线

今日凌晨发现,ChatGPT又支持必应联网了!虽然有人使用过newbing这个阉割版的联网GPT4,但官方版本确实更加便捷好用啊! 尽管 ChatGPT 此前已经展现出了其他人工智能模型无可比拟的智能,但由于其训练数据的限制&#xff…

【python学习第12节 pandas】

文章目录 一,pandas1.1 pd.Series1.2 pd.date_range1.3 pd_DataFrame1.4浏览数据1.5布尔索引1.6设置值1.7操作1.8合并1.8.1concat()函数1.8.2 merge()函数 一,pandas 1.1 pd.Series pd.Series 是 Pandas 库中的一个数据结构&…

海信电视U8KL使用体验:参数卷,画质技术也独有!

每个家庭成员对电视都有不同需求,如何能做到兼顾?看似需求众口难调,其实一台海信电视就能满足所有啦。 海信电视的参数不仅是最卷的,同时画质技术还是国内独有的,能把这样一台优秀的电视搬回家,无论电影、…

云原生Kubernetes:对外服务之 Ingress

目录 一、理论 1.Ingress 2.部署 nginx-ingress-controller(第一种方式) 3.部署 nginx-ingress-controller(第二种方式) 二、实验 1.部署 nginx-ingress-controller(第一种方式) 2.部署 nginx-ingress-controller(第二种方式) 三、问题 1.启动 nginx-ingress-controll…

【MySQL入门到精通-黑马程序员】MySQL基础篇-DML

文章目录 前言一、DML-介绍二、DML-添加数据三、DML-修改数据四、DML-删除数据总结 前言 本专栏文章为观看黑马程序员《MySQL入门到精通》所做笔记,课程地址在这。如有侵权,立即删除。 一、DML-介绍 DML(Data Manipulation Language&#xf…

Cinema 4D 2024更新, 比旧版速度更快!

Cinema 4D 2024 for Mac更新至v2024.0.2版本,其中Cinema 4D核心得到了全面优化,增强了可调的Pyro模拟、增强了真实镜头耀斑和色彩校正工作流程。 Cinema 4D 2024变得更加强大,在交互式播放方面有了巨大的性能改进,对刚性体模拟进行…

手撸RPC【gw-rpc】

文章目录 基于 Netty 的简易版 RPC需求分析简易RPC框架的整体实现协议模块 📖自定义协议 🆕序列化方式 🔢 服务工厂 🏭服务调用方 ❓前置知识——动态代理🕳️Proxy类InvocationHandler 接口 RPC服务代理类内嵌Netty客…

【iptables 实战】01 iptables概念

一、iptables和netfilter iptables其实不是真正的防火墙,我们可以把它理解成一个客户端代理,用户通过iptables这个代理,将用户的安全设定执行到对应的”安全框架”中,这个”安全框架”才是真正的防火墙,这个框架的名字…

TVP专家谈腾讯云 Cloud Studio:开启云端开发新篇章

导语 | 近日,由腾讯云 TVP 团队倾力打造的 TVP 吐槽大会第六期「腾讯云 Cloud Studio」专场圆满落幕,6 位资深的 TVP 专家深度体验腾讯云 Cloud Studio 产品,提出了直击痛点的意见与建议,同时也充分肯定了腾讯云 Cloud Studio 的实…

c与c++中的字符串

在c中,string本质上是一个类; string与char *有些区别: char*是一个指针;string是一个类,类内封装了char*,管理这一个字符串,是一个char*的容器 在使用string类型时,要加上其头文…

【Cesium创造属于你的地球】相机系统

相机系统里面有setView,flyTo,lookAt,viewBoundingsphere这几种方法,以下是相关的使用方法,学起来!!! setView 该方法可以直接切换相机视口,从而不需要通过一个飞入的效…

uniapp - 微信小程序实现腾讯地图位置标点展示,将指定地点进行标记选点并以一个图片图标展示出来(详细示例源码,一键复制开箱即用)

效果图 在uniapp微信小程序平台端开发,简单快速的实现在地图上进行位置标点功能,使用腾讯地图并进行标点创建和设置(可以自定义标记点的图片)。 你只需要复制代码,改个标记图标和位置即可。