RNN实现阿尔茨海默症的诊断识别

本文为为🔗365天深度学习训练营内部文章

原作者:K同学啊

 一 导入数据

import torch.nn as nn
import torch.nn.functional as F
import torchvision,torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset,DataLoader
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
import warnings
warnings.filterwarnings('ignore')# 设置硬件设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")df = pd.read_excel('dia.xls')
df

二 数据处理分析

# 删除第一列和最后一列
df = df.iloc[:,1:-1]
print(df)
 Age  Gender  Ethnicity  EducationLevel        BMI  Smoking   
0      73       0          0               2  22.927749        0  \
1      89       0          0               0  26.827681        0   
2      73       0          3               1  17.795882        0   
3      74       1          0               1  33.800817        1   
4      89       0          0               0  20.716974        0   
...   ...     ...        ...             ...        ...      ...   
2144   61       0          0               1  39.121757        0   
2145   75       0          0               2  17.857903        0   
2146   77       0          0               1  15.476479        0   
2147   78       1          3               1  15.299911        0   
2148   72       0          0               2  33.289738        0   AlcoholConsumption  PhysicalActivity  DietQuality  SleepQuality  ...   
0              13.297218          6.327112     1.347214      9.025679  ...  \
1               4.542524          7.619885     0.518767      7.151293  ...   
2              19.555085          7.844988     1.826335      9.673574  ...   
3              12.209266          8.428001     7.435604      8.392554  ...   
4              18.454356          6.310461     0.795498      5.597238  ...   
...                  ...               ...          ...           ...  ...   
2144            1.561126          4.049964     6.555306      7.535540  ...   
2145           18.767261          1.360667     2.904662      8.555256  ...   
2146            4.594670          9.886002     8.120025      5.769464  ...   
2147            8.674505          6.354282     1.263427      8.322874  ...   
2148            7.890703          6.570993     7.941404      9.878711  ...   FunctionalAssessment  MemoryComplaints  BehavioralProblems       ADL   
0                 6.518877                 0                   0  1.725883  \
1                 7.118696                 0                   0  2.592424   
2                 5.895077                 0                   0  7.119548   
3                 8.965106                 0                   1  6.481226   
4                 6.045039                 0                   0  0.014691   
...                    ...               ...                 ...       ...   
2144              0.238667                 0                   0  4.492838   
2145              8.687480                 0                   1  9.204952   
2146              1.972137                 0                   0  5.036334   
2147              5.173891                 0                   0  3.785399   
2148              6.307543                 0                   1  8.327563   Confusion  Disorientation  PersonalityChanges   
0             0               0                   0  \
1             0               0                   0   
2             0               1                   0   
3             0               0                   0   
4             0               0                   1   
...         ...             ...                 ...   
2144          1               0                   0   
2145          0               0                   0   
2146          0               0                   0   
2147          0               0                   0   
2148          0               1                   0   DifficultyCompletingTasks  Forgetfulness  Diagnosis  
0                             1              0          0  
1                             0              1          0  
2                             1              0          0  
3                             0              0          0  
4                             1              0          0  
...                         ...            ...        ...  
2144                          0              0          1  
2145                          0              0          1  
2146                          0              0          1  
2147                          0              1          1  
2148                          0              1          0  [2149 rows x 33 columns]

三 探索性数据分析 

1.得病分布

res = df.groupby('Diabetes')['Age'].count()
print(res)plt.figure(figsize=(8, 6))
plt.pie(res.values, labels=res.index, autopct='%1.1f%%', startangle=90,colors=['#ff9999','#66b3ff','#99ff99'], explode=(0.1,  0),wedgeprops={'edgecolor': 'black', 'linewidth': 1, 'linestyle': 'solid'})
plt.title('是否得阿尔茨海默症', fontsize=16, fontweight='bold')
plt.show()

 2.BMI分布直方图

# BMI分布直方图
sns.displot(df['BMI'], kde=True, color='skyblue', bins=30, height=6, aspect=1.2)
plt.title('BMI Distribution', fontsize=18, fontweight='bold', color='darkblue')
plt.xlabel('BMI', fontsize=14, color='darkgreen')
plt.ylabel('Frequency', fontsize=14, color='darkgreen')
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

 

3.年龄分布直方图 

# Age分布直方图
sns.displot(df['Age'], kde=True, color='skyblue', bins=30, height=6, aspect=1.2)
plt.title('Age Distribution', fontsize=18, fontweight='bold', color='darkblue')
plt.xlabel('Age', fontsize=14, color='darkgreen')
plt.ylabel('Frequency', fontsize=14, color='darkgreen')
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

 

 

四 构建划分数据集 

X = df.iloc[:,:-1]
y = df.iloc[:,-1]sc = StandardScaler()
X = sc.fit_transform(X)# 划分数据集
X = torch.tensor(np.array(X),dtype=torch.float32)
y = torch.tensor(np.array(y),dtype=torch.int64)X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.1,random_state=1)# 构建数据加载器
train_dl = DataLoader(TensorDataset(X_train,y_train),batch_size=64,shuffle=False)
test_dl = DataLoader(TensorDataset(X_test,y_test),batch_size=64,shuffle=False)

 五 训练模型

1.构建模型

# 构建模型
class model_rnn(nn.Module):def __init__(self):super(model_rnn, self).__init__()self.rnn0 = nn.RNN(input_size=32,hidden_size=200,num_layers=1,batch_first=True)self.fc0 = nn.Linear(200,50)self.fc1 = nn.Linear(50,2)def forward(self,x):out , hidden1 = self.rnn0(x)out = self.fc0(out)out = self.fc1(out)return outmodel = model_rnn().to(device)
print(model)
model_rnn((rnn0): RNN(32, 200, batch_first=True)(fc0): Linear(in_features=200, out_features=50, bias=True)(fc1): Linear(in_features=50, out_features=2, bias=True)
)

2.训练函数 

'''
训练模型
'''
# 训练循环
def train(dataloader,model,loss_fn,optimizer):size = len(dataloader.dataset)   # 训练集的大小num_batches = len(dataloader)      # 批次数目,(size/batchsize,向上取整)train_acc,train_loss = 0,0  # 初始化训练损失和正确率for x,y in dataloader:    # 获取数据X,y = x.to(device),y.to(device)# 计算预测误差pred = model(X)   # 网络输出loss = loss_fn(pred,y)   # 计算误差# 反向传播optimizer.zero_grad()    # grad属性归零loss.backward()   # 反向传播optimizer.step()   # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc,train_loss

3.测试函数 

# 测试循环
def valid(dataloader,model,loss_fn):size = len(dataloader.dataset)  # 训练集的大小num_batches = len(dataloader)  # 批次数目,(size/batchsize,向上取整)test_loss, test_acc = 0, 0  # 初始化训练损失和正确率# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs,target in dataloader:imgs,target = imgs.to(device),target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred,target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc,test_loss

4.正式训练 

loss_fn = nn.CrossEntropyLoss()   # 创建损失函数
learn_rate = 1e-4   # 学习率
opt = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs = 30train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):model.train()epoch_train_acc,epoch_train_loss = train(train_dl,model,loss_fn,opt)model.eval()epoch_test_acc,epoch_test_loss = valid(test_dl,model,loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = opt.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},lr:{:.2E}')print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss,lr))print("="*20,'Done',"="*20)
Epoch: 1,Train_acc:52.9%,Train_loss:0.688,Test_acc:67.0%,Test_loss:0.658,lr:1.00E-04
Epoch: 2,Train_acc:68.7%,Train_loss:0.612,Test_acc:67.4%,Test_loss:0.600,lr:1.00E-04
Epoch: 3,Train_acc:68.7%,Train_loss:0.566,Test_acc:70.7%,Test_loss:0.567,lr:1.00E-04
Epoch: 4,Train_acc:74.4%,Train_loss:0.526,Test_acc:72.6%,Test_loss:0.533,lr:1.00E-04
Epoch: 5,Train_acc:77.9%,Train_loss:0.487,Test_acc:78.1%,Test_loss:0.501,lr:1.00E-04
Epoch: 6,Train_acc:81.1%,Train_loss:0.451,Test_acc:79.5%,Test_loss:0.473,lr:1.00E-04
Epoch: 7,Train_acc:82.3%,Train_loss:0.421,Test_acc:80.0%,Test_loss:0.451,lr:1.00E-04
Epoch: 8,Train_acc:83.4%,Train_loss:0.397,Test_acc:78.6%,Test_loss:0.434,lr:1.00E-04
Epoch: 9,Train_acc:84.7%,Train_loss:0.378,Test_acc:80.0%,Test_loss:0.422,lr:1.00E-04
Epoch:10,Train_acc:85.2%,Train_loss:0.365,Test_acc:80.0%,Test_loss:0.414,lr:1.00E-04
Epoch:11,Train_acc:85.6%,Train_loss:0.354,Test_acc:80.0%,Test_loss:0.408,lr:1.00E-04
Epoch:12,Train_acc:85.9%,Train_loss:0.347,Test_acc:80.0%,Test_loss:0.405,lr:1.00E-04
Epoch:13,Train_acc:86.3%,Train_loss:0.341,Test_acc:78.6%,Test_loss:0.403,lr:1.00E-04
Epoch:14,Train_acc:87.0%,Train_loss:0.335,Test_acc:78.1%,Test_loss:0.403,lr:1.00E-04
Epoch:15,Train_acc:87.1%,Train_loss:0.331,Test_acc:78.6%,Test_loss:0.404,lr:1.00E-04
Epoch:16,Train_acc:87.1%,Train_loss:0.327,Test_acc:78.1%,Test_loss:0.405,lr:1.00E-04
Epoch:17,Train_acc:87.1%,Train_loss:0.324,Test_acc:78.6%,Test_loss:0.407,lr:1.00E-04
Epoch:18,Train_acc:87.3%,Train_loss:0.321,Test_acc:78.6%,Test_loss:0.409,lr:1.00E-04
Epoch:19,Train_acc:87.4%,Train_loss:0.318,Test_acc:77.7%,Test_loss:0.412,lr:1.00E-04
Epoch:20,Train_acc:87.7%,Train_loss:0.315,Test_acc:78.1%,Test_loss:0.415,lr:1.00E-04
Epoch:21,Train_acc:87.8%,Train_loss:0.312,Test_acc:77.7%,Test_loss:0.418,lr:1.00E-04
Epoch:22,Train_acc:88.1%,Train_loss:0.309,Test_acc:78.1%,Test_loss:0.422,lr:1.00E-04
Epoch:23,Train_acc:88.6%,Train_loss:0.306,Test_acc:78.1%,Test_loss:0.425,lr:1.00E-04
Epoch:24,Train_acc:88.6%,Train_loss:0.303,Test_acc:79.1%,Test_loss:0.429,lr:1.00E-04
Epoch:25,Train_acc:88.6%,Train_loss:0.301,Test_acc:79.5%,Test_loss:0.433,lr:1.00E-04
Epoch:26,Train_acc:88.6%,Train_loss:0.298,Test_acc:79.5%,Test_loss:0.437,lr:1.00E-04
Epoch:27,Train_acc:88.8%,Train_loss:0.295,Test_acc:80.0%,Test_loss:0.440,lr:1.00E-04
Epoch:28,Train_acc:89.1%,Train_loss:0.292,Test_acc:79.5%,Test_loss:0.444,lr:1.00E-04
Epoch:29,Train_acc:89.1%,Train_loss:0.290,Test_acc:79.1%,Test_loss:0.449,lr:1.00E-04
Epoch:30,Train_acc:89.2%,Train_loss:0.287,Test_acc:79.1%,Test_loss:0.453,lr:1.00E-04
==================== Done ====================

六 结果可视化 

1.Loss和Acc图

epochs_range = range(30)
plt.figure(figsize=(14,4))
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='training accuracy')
plt.plot(epochs_range,test_acc,label='validation accuracy')
plt.legend(loc='lower right')
plt.title('training and validation accuracy')plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='training loss')
plt.plot(epochs_range,test_loss,label='validation loss')
plt.legend(loc='upper right')
plt.title('training and validation loss')
plt.show()

 2.调用模型预测

test_X = X_test[0].reshape(1,-1)
pred = model(test_X.to(device)).argmax(1).item()
print('模型预测结果:',pred)
print('=='*20)
print('0:未患病')
print('1:已患病')
模型预测结果: 0
========================================
0:未患病
1:已患病

3.绘制混淆矩阵 

'''
绘制混淆矩阵
'''
print('=============输入数据shape为==============')
print('X_test.shape:',X_test.shape)
print('y_test.shape:',y_test.shape)pred = model(X_test.to(device)).argmax(1).cpu().numpy()print('\n==========输出数据shape为==============')
print('pred.shape:',pred.shape)from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay# 计算混淆矩阵
cm = confusion_matrix(y_test,pred)plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm,annot=True,fmt='d',cmap='Blues')
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title('Confusion Matrix',fontsize=12)
plt.xlabel('Pred Label',fontsize=10)
plt.ylabel('True Label',fontsize=10)
plt.tight_layout()
plt.show()
=============输入数据shape为==============
X_test.shape: torch.Size([215, 32])
y_test.shape: torch.Size([215])==========输出数据shape为==============
pred.shape: (215,)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/7891.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ui-automator定位官网文档下载及使用

一、ui-automator定位官网文档简介及下载 AndroidUiAutomator:移动端特有的定位方式,uiautomator是java实现的,定位类型必须写成java类型 官方地址:https://developer.android.com/training/testing/ui-automator.html#ui-autom…

算法每日双题精讲 —— 二分查找(寻找旋转排序数组中的最小值,点名)

🌟快来参与讨论💬,点赞👍、收藏⭐、分享📤,共创活力社区。 🌟 别再犹豫了!快来订阅我们的算法每日双题精讲专栏,一起踏上算法学习的精彩之旅吧💪 在算法的…

macOS使用LLVM官方发布的tar.xz来安装Clang编译器

之前笔者写过一篇博文ubuntu使用LLVM官方发布的tar.xz来安装Clang编译器介绍了Ubuntu下使用官方发布的tar.xz包来安装Clang编译。官方发布的版本中也有MacOS版本的tar.xz,那MacOS应该也是可以安装的。 笔者2015款MBP笔记本,CPU是intel的,出厂…

机器学习周报-文献阅读

文章目录 摘要Abstract 1 相关知识1.1 WDN建模1.2 掩码操作(Masking Operation) 2 论文内容2.1 WDN信息的数据处理2.2 使用所收集的数据构造模型2.2.1 Gated graph neural network2.2.2 Masking operation2.2.3 Training loss2.2.4 Evaluation metrics 2…

Doris Schema Change 常见问题分析

1. 什么是 Schema Change Schema Change 是在数据库中修改表结构的一种操作,例如添加列、删除列、更改列类型等。 ⚠️Schema Change 限制⚠️ 一张表在同一时间只能有一个 Schema Change 作业在运行。分区列和分桶列不能修改。如果聚合表中有 REPLACE 方式聚合的…

我的2024年年度总结

序言 在前不久(应该是上周)的博客之星入围赛中铩羽而归了。虽然心中颇为不甘,觉得这一年兢兢业业,每天都在发文章,不应该是这样的结果(连前300名都进不了)。但人不能总抱怨,总要向前…

C++ DLL注入原理以及示例

0、 前言 0.1 什么是DLL注入 DLL(动态链接库)注入是一种技术,通过将外部的 DLL 文件强行加载到目标进程的地址空间中,使得外部代码可以执行。这种技术常用于修改或扩展应用程序的行为,甚至用于恶意攻击。 0.2 DLL注入…

MATLAB绘图:随机彩色圆点图

这段代码在MATLAB中生成并绘制了500个随机位置和颜色的散点图。通过随机生成的x和y坐标以及颜色,用户可以直观地观察到随机点的分布。这种可视化方式在数据分析、统计学和随机过程的演示中具有广泛的应用。 文章目录 运行结果代码代码讲解 运行结果 代码 clc; clea…

关于使用PHP时WordPress排错——“这意味着您在wp-config.php文件中指定的用户名和密码信息不正确”的解决办法

本来是看到一位好友的自己建站,所以突发奇想,在本地装个WordPress玩玩吧,就尝试着装了一下,因为之前电脑上就有MySQL,所以在自己使用PHP建立MySQL时报错了。 最开始是我的php启动mysql时有问题,也就是启动过…

RabbitMQ 架构分析

文章目录 前言一、RabbitMQ架构分析1、Broker2、Vhost3、Producer4、Messages5、Connections6、Channel7、Exchange7、Queue8、Consumer 二、消息路由机制1、Direct Exchange2、Topic Exchange3、Fanout Exchange4、Headers Exchange5、notice5.1、备用交换机(Alter…

【Uniapp-Vue3】setTabBar设置TabBar和下拉刷新API

一、setTabBar设置 uni.setTabBarItem({ index:"需要修改第几个", text:"修改后的文字内容" }) 二、tabBar的隐藏和显式 // 隐藏tabBar uni.hideTabBar(); // 显示tabBar uni.showTabBar(); 三、为tabBar右上角添加文本 uni.setTabBarBadge({ index:"…

routeros7 adguardhome添加规则报错certificate expired

mikrokit routeros 7添加adguardhome容器。 /container/add remote-imageadguard/adguardhome:latest interfaceveth1 root-dircontainer/adgurdhome loggingyes结果发现添加不了规则,报证书过期。 Error: control/filtering/add_url | Couldn’t fetch filter fro…

壁纸设计过程中如何增加氛围感

在壁纸设计过程中,增加氛围感是提升整体视觉效果和情感传达的关键。以下是一些具体的方法和技巧,帮助你在设计中营造出强烈的氛围感: 一、色彩运用 选择主题色: 根据你想要传达的情感选择主色调。例如,温暖的色调&…

RabbitMQ模块新增消息转换器

文章目录 1.目录结构2.代码1.pom.xml 排除logging2.RabbitMQConfig.java3.RabbitMQAutoConfiguration.java 1.目录结构 2.代码 1.pom.xml 排除logging <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/PO…

2024年度总结:技术探索与个人成长的交织

文章目录 前言年度创作回顾&#xff1a;技术深耕与分享数据库技术&#xff1a;MySQL 与 MyBatisJava 及相关技术栈计算机网络&#xff1a;构建网络知识体系思维方式的转变&#xff1a;构建技术知识体系的桥梁 项目实践&#xff1a;人工智能与智慧医疗的碰撞生活与博客的融合与平…

python:taichi 模拟一维波场

在 Taichi 中模拟一维波场&#xff0c;通常是利用 Taichi 编程语言的特性来对一维空间中的波动现象进行数值模拟&#xff0c;以下是相关介绍&#xff1a; 原理基础 波动方程&#xff1a;一维波动方程的一般形式为 &#xff0c;其中 u(x,t) 表示在位置x 和时间t 处的波的状态&…

基于回归分析法的光伏发电系统最大功率计算simulink建模与仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 5.完整工程文件 1.课题概述 基于回归分析法的光伏发电系统最大功率计算simulink建模与仿真。选择回归法进行最大功率点的追踪&#xff0c;使用光强和温度作为影响因素&#xff0c;电压作为输出进行建模。…

深入MapReduce——引入

引入 前面我们已经深入了HDFS的设计与实现&#xff0c;对于分布式系统也有了不错的理解。 但HDFS仅仅解决了海量数据存储和读写的问题。要想让数据产生价值&#xff0c;一定是需要从数据中挖掘出价值才行&#xff0c;这就需要我们拥有海量数据的计算处理能力。 下面我们还是…

Vue 引入及简单示例

Vue 渐进式JavaScript 框架 学习笔记 - Vue 引入及简单示例 目录 与jquery区别 Vue引入 两种方式引入 下载到本地 代码结构 简单示例 Style中引入vue.js 对vue语法进行解析 对三目运算符支持 设置变量&#xff08;状态&#xff09; 总结 与jquery区别 不需要手动操…

系统思考—问题分析

很多中小企业都在面对转型的难题&#xff1a;市场变化快&#xff0c;资源有限&#xff0c;团队协作不畅……这些问题似乎总是困扰着我们。就像最近和一位企业主交流时&#xff0c;他提到&#xff1a;“我们团队每天都很忙&#xff0c;但效率始终没见提升&#xff0c;感觉像是在…