前馈神经网络正则化例子

直接看代码:

import torch  
import numpy as np  
import random  
from IPython import display  
from matplotlib import pyplot as plt  
import torchvision  
import torchvision.transforms as transforms   mnist_train = torchvision.datasets.MNIST(root='/MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  batch_size = 256 train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  num_inputs,num_hiddens,num_outputs =784, 256,10def init_param():W1 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens,num_inputs)), dtype=torch.float32)  b1 = torch.zeros(1, dtype=torch.float32)  W2 = torch.tensor(np.random.normal(0, 0.01, (num_outputs,num_hiddens)), dtype=torch.float32)  b2 = torch.zeros(1, dtype=torch.float32)  params =[W1,b1,W2,b2]for param in params:param.requires_grad_(requires_grad=True)  return W1,b1,W2,b2def relu(x):x = torch.max(input=x,other=torch.tensor(0.0))  return xdef net(X):  X = X.view((-1,num_inputs))  H = relu(torch.matmul(X,W1.t())+b1)  #myrelu =((matmal x,w1)+b1),return  matmal(myrelu,w2 )+ b2return relu(torch.matmul(H,W2.t())+b2 )return torch.matmul(H,W2.t())+b2def SGD(paras,lr):  for param in params:  param.data -= lr * param.grad  def l2_penalty(w):return (w**2).sum()/2def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr=None,optimizer=None,mylambda=0):  train_ls, test_ls = [], []for epoch in range(num_epochs):ls, count = 0, 0for X,y in train_iter :X = X.reshape(-1,num_inputs)l=loss(net(X),y)+ mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)optimizer.zero_grad()l.backward()optimizer.step()ls += l.item()count += y.shape[0]train_ls.append(ls)ls, count = 0, 0for X,y in test_iter:X = X.reshape(-1,num_inputs)l=loss(net(X),y) + mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)ls += l.item()count += y.shape[0]test_ls.append(ls)if(epoch)%2==0:print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))return train_ls,test_lslr = 0.01num_epochs = 20Lamda = [0,0.1,0.2,0.3,0.4,0.5]Train_ls, Test_ls = [], []for lamda in Lamda:print("current lambda is %f"%lamda)W1,b1,W2,b2 = init_param()loss = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD([W1,b1,W2,b2],lr = 0.001)train_ls, test_ls = train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr,optimizer,lamda)   Train_ls.append(train_ls)Test_ls.append(test_ls)x = np.linspace(0,len(Train_ls[1]),len(Train_ls[1]))plt.figure(figsize=(10,8))for i in range(0,len(Lamda)):plt.plot(x,Train_ls[i],label= f'L2_Regularization:{Lamda [i]}',linewidth=1.5)plt.xlabel('different epoch')plt.ylabel('loss')plt.legend(loc=2, bbox_to_anchor=(1.1,1.0),borderAxesPad = 0.)plt.title('train loss with L2_penalty')plt.show()

运行结果:

在这里插入图片描述

疑问和心得:

  1. 画图的实现和细节还是有些模糊。
  2. 正则化系数一般是一个可以根据算法有一定变动的常数。
  3. 前馈神经网络中,二分类最后使用logistic函数返回,多分类一般返回softmax值,若是一般的回归任务,一般是直接relu返回。
  4. 前馈神经网络的实现,从物理层上应该是全连接的,但是网上的代码一般都是两层单个神经元,这个容易产生误解。个人感觉,还是要使用nn封装的函数比较正宗。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/96562.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

产品经理必知必会0.2

Q1:产品经理需要具备的能力? A:硬实力:产品设计、需求分析、竞品分析、数据分析、撰写文档 软实力:沟通能力、学习能力、用户思维、主动性、好奇心、同理心、责任心、抗压能力、目标导向.... 扩展能力:商业思维、市场敏感度... Q…

hive高频使用的拼接函数及“避坑”

hive高频使用的拼接函数及“避坑” 说到拼接函数应用场景和使用频次还是非常高,比如一个员工在公司充当多个角色,我们在底层存数的时候往往是多行,但是应用的时候我们通常会只需要一行,角色字段进行拼接,这样join其他…

STM32 F103C8T6学习笔记8:0.96寸单色OLED显示屏显示字符

使用STM32F103 C8T6 驱动0.96寸单色OLED显示屏: OLED显示屏的驱动,在设计开发中OLED显示屏十分常见,因此今日学习一下。一篇文章从程序到显示都讲通。 文章提供源码、原理解释、测试工程下载,测试效果图展示。 目录 OLED驱动原理—IIC通信…

通讯录实现【C语言】

目录 前言 一、整体逻辑分析 二、实现步骤 1、创建菜单和多次操作问题 2、创建通讯录 3、初始化通讯录 4、添加联系人 5、显示联系人 6、删除指定联系人 ​7、查找指定联系人 8、修改联系人信息 9、排序联系人信息 三、全部源码 前言 我们上期已经详细的介绍了自定…

mongodb.使用自带命令工具导出导入数据

在一次数据更新中,同事把老数据进行了清空操作,但是新的逻辑数据由于某种原因(好像是她的电脑中病毒了),一直无法正常连接数据库进行数据插入,然后下午2点左右要给甲方演示,所以要紧急恢复本地的…

于vue3+vite+element pro + pnpm开源项目

河码桌面是一个基于vue3viteelement pro pnpm 创建的monorepo项目,项目采用的是类操作系统的web界面,操作起来简单又方便,符合用户习惯,又没有操作系统的复杂! 有两个两个分支,一个是web版本,…

机器学习深度学习——机器翻译(序列生成策略)

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er 🌌上期文章:机器学习&&深度学习——seq2seq实现机器翻译(详细实现与原理推导) 📚订阅专栏:机…

【Freertos基础入门】队列(queue)的使用

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、队列是什么?二、队列的操作二、示例代码总结 前言 本系列基于stm32系列单片机来使用freerots FreeRTOS是一个广泛使用的开源实时操作系统&…

Linux网络编程:Socket套接字编程

文章目录: 一:定义和流程分析 1.定义 2.流程分析 3.网络字节序 二:相关函数 IP地址转换函数inet_pton inet_ntop(本地字节序 网络字节序) socket函数(创建一个套接字) bind函数(给socket绑定一个服务器地址结…

企业数据库遭到360后缀勒索病毒攻击,360勒索病毒解密

在当今数字化时代,企业的数据安全变得尤为重要。随着数字化办公的推进,企业的生产运行效率得到了很大提升,然而针对网络安全威胁,企业也开始慢慢引起重视。近期,我们收到很多企业的求助,企业的服务器遭到了…

go 协程并发数控制

错误的写法&#xff1a; 这里的<-ch 是为了从channel 中读取 数据&#xff0c;为了不使channel通道被写满&#xff0c;阻塞 go 协程数的创建。但是请注意&#xff0c;go workForDraw(v, &wg) 是不阻塞后续的<-ch 执行的&#xff0c;所以就一直go workForDraw(v, &…

Find My资讯|苹果Vision Pro开发者需将设备配对 AirTag

最近苹果Vision Pro获开发者申请&#xff0c;苹果要求获批的申请者使用 Measure and Fit 应用确认合适的佩戴尺寸&#xff0c;并会根据申请者提交的信息&#xff0c;定制不同的 Vision Pro 开发者套件&#xff0c;以便于契合申请者的面部特征&#xff0c;提供更好的佩戴体验。 …

iPhone 15受益:骁龙8 Gen 3可能缺席部分安卓旗舰机

明年一批领先的安卓手机的性能可能与今年的机型非常相似。硅成本的上涨可能是原因。 你可以想象&#xff0c;2024年许多最好的手机都会在Snapdragon 8 Gen 3上运行&#xff0c;这是高通公司针对移动设备的顶级芯片系统的更新&#xff0c;尚未宣布。然而&#xff0c;来自中国的…

基于libevent的tcp服务器

libevent使用教程_evutil_make_socket_nonblocking_易方达蓝筹的博客-CSDN博客 一、准备 centos7下安装libevent库 yum install libevent yum install -y libevent-devel 二、代码 server.cpp /** You need libevent2 to compile this piece of code Please see: http://li…

多种方法实现 Nginx 隐藏式跳转(隐式URL,即浏览器 URL 跳转后保持不变)

多种方法实现 Nginx 隐藏式跳转(隐式URL,即浏览器 URL 跳转后保持不变)。 一个新项目,后端使用 PHP 实现,前端不做路由,提供一个模板,由后端路由控制。 Route::get(pages/{name}, [\App\Http\Controllers\ResourceController::class, getResourceVersion])

独立站SEO是什么意思?自主网站SEO的含义?

什么是独立站SEO优化&#xff1f;自建站搜索引擎优化是指什么&#xff1f; 独立站SEO&#xff0c;作为网络营销的重要一环&#xff0c;正在逐渐引起人们的关注。在当今数字化时代&#xff0c;独立站已经成为许多企业、个人宣传推广的首选平台之一。那么&#xff0c;究竟什么是…

【c语言】文件操作

朋友们&#xff0c;大家好&#xff0c;今天分享给大家的是文件操作的相关知识&#xff0c;跟着我一起学习吧&#xff01;&#xff01; &#x1f388;什么是文件 磁盘上的文件是文件。 但是在程序设计中&#xff0c;我们一般谈的文件有两种&#xff1a;程序文件、数据文件 程序文…

如何用输入函数为数组赋值

在编写程序时我们经常使用数组&#xff0c;而数组的大小可能是很大的但是我们并不需要为每个元素都自己赋值&#xff0c;我们可能会自定义输入数组元素个数&#xff0c;我们应该如何实现通过输入函数为数组赋值呢&#xff1f; 目录 第一种&#xff1a; 第二种&#xff1a; 第一…

Floyd(多源汇最短路)

Floyd求最短路 给定一个 n 个点 m 条边的有向图&#xff0c;图中可能存在重边和自环&#xff0c;边权可能为负数。 再给定 k 个询问&#xff0c;每个询问包含两个整数 x 和 y&#xff0c;表示查询从点 x 到点 y 的最短距离&#xff0c;如果路径不存在&#xff0c;则输出 impo…

深入理解python虚拟机:程序执行的载体——栈帧

栈帧&#xff08;Stack Frame&#xff09;是 Python 虚拟机中程序执行的载体之一&#xff0c;也是 Python 中的一种执行上下文。每当 Python 执行一个函数或方法时&#xff0c;都会创建一个栈帧来表示当前的函数调用&#xff0c;并将其压入一个称为调用栈&#xff08;Call Stac…