【深度学习】TCN,An Empirical Evaluation of Generic Convolutional【二】

文章目录

  • 膨胀卷积
      • 什么是膨胀卷积
      • 膨胀卷积公式
      • PyTorch代码
  • 从零开始手动实现一个1D膨胀卷积,不使用PyTorch的`nn.Conv1d`
      • 1. 基本概念
      • 2. 手动实现1D膨胀卷积
  • TCN结构
  • 如何使用TCN
      • 源码说明
        • 1. Chomp1d 类
        • 2. TemporalBlock 类
        • 3. TemporalConvNet 类
      • 使用方法

膨胀卷积

什么是膨胀卷积

膨胀卷积(Dilated Convolution),也称为空洞卷积(Atrous Convolution),是在标准卷积的基础上通过引入膨胀因子(dilation factor)来扩展感受野,而不增加参数数量或计算复杂度。膨胀卷积通过在滤波器的每两个元素之间插入空洞(即,零值)来实现这一点。

膨胀卷积公式

膨胀卷积的数学公式如下:

F ( s ) = ( x ∗ d f ) ( s ) = ∑ i = 0 k − 1 f ( i ) ⋅ x s − d ⋅ i F(s) = (x *_d f)(s) = \sum_{i=0}^{k-1} f(i) \cdot x_{s - d \cdot i} F(s)=(xdf)(s)=i=0k1f(i)xsdi

其中:

  • (x) 是输入信号。
  • (f) 是卷积滤波器。
  • (s) 是输出信号的位置。
  • (d) 是膨胀因子,表示滤波器元素之间的间隔。
  • (k) 是滤波器的大小。

当 (d=1) 时,膨胀卷积退化为标准卷积。

PyTorch代码

下面是一个使用PyTorch实现膨胀卷积的示例:

import torch
import torch.nn as nnclass DilatedConv1D(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, dilation):super(DilatedConv1D, self).__init__()self.dilated_conv = nn.Conv1d(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,dilation=dilation,padding=(kernel_size - 1) * dilation // 2)def forward(self, x):return self.dilated_conv(x)# 示例输入
batch_size = 1
in_channels = 1
seq_length = 10x = torch.randn(batch_size, in_channels, seq_length)# 创建膨胀卷积层
dilated_conv_layer = DilatedConv1D(in_channels=1, out_channels=1, kernel_size=3, dilation=2)# 前向传播
output = dilated_conv_layer(x)
print(output)

从零开始手动实现一个1D膨胀卷积,不使用PyTorch的nn.Conv1d

1. 基本概念

膨胀卷积的公式为:

y [ t ] = ∑ k x [ t − k ⋅ d ] ⋅ w [ k ] y[t] = \sum_{k} x[t - k \cdot d] \cdot w[k] y[t]=kx[tkd]w[k]

其中:

  • y [ t ] y[t] y[t] 是输出信号。
  • x [ t ] x[t] x[t] 是输入信号。
  • w [ k ] w[k] w[k] 是卷积核的权重。
  • d d d 是膨胀率。

2. 手动实现1D膨胀卷积

下面是手动实现1D膨胀卷积的Python代码:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass ManualDilatedConv1D(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, dilation=1):super(ManualDilatedConv1D, self).__init__()self.in_channels = in_channelsself.out_channels = out_channelsself.kernel_size = kernel_sizeself.dilation = dilation# 初始化卷积核权重self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))self.bias = nn.Parameter(torch.randn(out_channels))def forward(self, x):batch_size, in_channels, length = x.shapeassert in_channels == self.in_channels# 计算输出的长度out_length = length - (self.kernel_size - 1) * self.dilation# 初始化输出张量out = torch.zeros(batch_size, self.out_channels, out_length)# 对每个输出通道进行卷积for b in range(batch_size):for o in range(self.out_channels):for i in range(out_length):sum = 0for k in range(self.kernel_size):sum += x[b, :, i + k * self.dilation] * self.weight[o, :, k]out[b, o, i] = sum + self.bias[o]return out# 示例参数
in_channels = 1
out_channels = 1
kernel_size = 3
dilation = 2# 创建一个输入张量 (batch_size, channels, length)
input_tensor = torch.randn(1, in_channels, 10)# 创建手动膨胀卷积层
manual_dilated_conv = ManualDilatedConv1D(in_channels, out_channels, kernel_size, dilation)# 前向传播
output_tensor = manual_dilated_conv(input_tensor)print(output_tensor)

TCN结构

在这里插入图片描述

import torch
import torch.nn as nn
from torch.nn.utils import weight_normclass Chomp1d(nn.Module):def __init__(self, chomp_size):super(Chomp1d, self).__init__()self.chomp_size = chomp_sizedef forward(self, x):return x[:, :, :-self.chomp_size].contiguous()class TemporalBlock(nn.Module):def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):super(TemporalBlock, self).__init__()self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,stride=stride, padding=padding, dilation=dilation))self.chomp1 = Chomp1d(padding)self.relu1 = nn.ReLU()self.dropout1 = nn.Dropout(dropout)self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,stride=stride, padding=padding, dilation=dilation))self.chomp2 = Chomp1d(padding)self.relu2 = nn.ReLU()self.dropout2 = nn.Dropout(dropout)self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,self.conv2, self.chomp2, self.relu2, self.dropout2)self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else Noneself.relu = nn.ReLU()self.init_weights()def init_weights(self):self.conv1.weight.data.normal_(0, 0.01)self.conv2.weight.data.normal_(0, 0.01)if self.downsample is not None:self.downsample.weight.data.normal_(0, 0.01)def forward(self, x):out = self.net(x)res = x if self.downsample is None else self.downsample(x)return self.relu(out + res)class TemporalConvNet(nn.Module):def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):super(TemporalConvNet, self).__init__()layers = []num_levels = len(num_channels)for i in range(num_levels):dilation_size = 2 ** iin_channels = num_inputs if i == 0 else num_channels[i-1]out_channels = num_channels[i]layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,padding=(kernel_size-1) * dilation_size, dropout=dropout)]self.network = nn.Sequential(*layers)def forward(self, x):return self.network(x)

如何使用TCN

以下是如何使用上述Temporal Convolutional Network (TCN) 代码的详细讲解和步骤:

源码说明

1. Chomp1d 类

Chomp1d类用于从输入的末端裁剪指定大小的时间步长。

class Chomp1d(nn.Module):def __init__(self, chomp_size):super(Chomp1d, self).__init__()self.chomp_size = chomp_sizedef forward(self, x):return x[:, :, :-self.chomp_size].contiguous()
2. TemporalBlock 类

TemporalBlock类构建了一个基础的时间卷积模块,包括两个卷积层,每个卷积层后都有一个Chomp1d、ReLU激活函数和Dropout。

class TemporalBlock(nn.Module):def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):super(TemporalBlock, self).__init__()self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,stride=stride, padding=padding, dilation=dilation))self.chomp1 = Chomp1d(padding)self.relu1 = nn.ReLU()self.dropout1 = nn.Dropout(dropout)self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,stride=stride, padding=padding, dilation=dilation))self.chomp2 = Chomp1d(padding)self.relu2 = nn.ReLU()self.dropout2 = nn.Dropout(dropout)self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,self.conv2, self.chomp2, self.relu2, self.dropout2)self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else Noneself.relu = nn.ReLU()self.init_weights()def init_weights(self):self.conv1.weight.data.normal_(0, 0.01)self.conv2.weight.data.normal_(0, 0.01)if self.downsample is not None:self.downsample.weight.data.normal_(0, 0.01)def forward(self, x):out = self.net(x)res = x if self.downsample is None else self.downsample(x)return self.relu(out + res)
3. TemporalConvNet 类

TemporalConvNet类将多个TemporalBlock组合在一起,形成完整的TCN模型。

class TemporalConvNet(nn.Module):def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):super(TemporalConvNet, self).__init__()layers = []num_levels = len(num_channels)for i in range(num_levels):dilation_size = 2 ** iin_channels = num_inputs if i == 0 else num_channels[i-1]out_channels = num_channels[i]layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,padding=(kernel_size-1) * dilation_size, dropout=dropout)]self.network = nn.Sequential(*layers)def forward(self, x):return self.network(x)

使用方法

  1. 准备输入数据
    TCN适用于一维序列数据,如时间序列。输入数据的形状应该是(batch_size, num_inputs, sequence_length)

  2. 初始化模型
    定义模型的输入通道数num_inputs,每一层的输出通道数列表num_channels,卷积核大小kernel_sizedropout比例。

num_inputs = 10  # 输入通道数,例如10个特征
num_channels = [16, 32, 64]  # 每个TemporalBlock的输出通道数
kernel_size = 2
dropout = 0.2model = TemporalConvNet(num_inputs, num_channels, kernel_size, dropout)
  1. 训练模型
    使用PyTorch常规的训练步骤,定义损失函数和优化器,然后进行前向传播、计算损失、反向传播和参数更新。
# 示例:随机生成输入数据
batch_size = 8
sequence_length = 30
input_data = torch.randn(batch_size, num_inputs, sequence_length)# 模型输出
output = model(input_data)
print(output.shape)  # 输出形状为(batch_size, num_channels[-1], sequence_length)

batch_size, num_inputs, 和 sequence_length 是与输入数据和模型有关的参数。以下是对它们的详细解释:

  • batch_size:

    • 表示每次输入模型的样本数量。
    • 例如,如果你有1000个样本数据,并且你希望每次输入模型进行训练时处理32个样本,那么batch_size将是32。
    • 这个参数通常用于加速训练过程,并使得计算更高效,因为可以利用并行计算。
  • num_inputs:

    • 表示输入数据的特征数量或通道数。
    • 在时间序列数据中,通常每个时间步可能包含多个特征。例如,如果你的时间序列数据在每个时间步包含温度和湿度两个特征,那么num_inputs将是2。
    • 对于一维时间序列数据,如果每个时间步只有一个值(如单变量时间序列),则num_inputs为1。
  • sequence_length:

    • 表示时间序列的长度,即每个样本中包含的时间步数。
    • 例如,如果你有一个每天记录温度的时间序列数据,记录了30天的数据,那么sequence_length将是30。
    • 这个参数决定了输入数据的时间维度长度。
  1. 定义损失函数和优化器
    可以使用MSELoss或其他适合具体任务的损失函数。
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 示例:随机生成目标数据
target_data = torch.randn(batch_size, num_channels[-1], sequence_length)# 前向传播
output = model(input_data)# 计算损失
loss = criterion(output, target_data)# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()print('Loss:', loss.item())

以上是如何使用Temporal Convolutional Network (TCN)代码的详细步骤和示例。通过这些步骤,你可以定义并训练一个TCN模型来处理一维序列数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/355241.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

DVWA - Brute Force

DVWA - Brute Force 等级:low ​ 直接上bp弱口令爆破,设置变量,攻击类型最后一个,payload为用户名、密码简单列表 ​ 直接run,长度排序下,不一样的就是正确的用户名和密码 ​ 另解: 看一下…

【SpringBoot + Vue 尚庭公寓实战】公寓杂费接口实现(八)

【SpringBoot Vue 尚庭公寓实战】公寓杂费接口实现(八) 文章目录 【SpringBoot Vue 尚庭公寓实战】公寓杂费接口实现(八)1、公寓杂费业务介绍2、公寓杂费逻辑模型介绍3、接口实现3.1、保存或更新杂费值3.2、保存或更新杂费名称3…

Android【SDK】 SDK是如何开发的,怎么打包aar包

文章目录 一、Android SDK开发示例工程二、Android SDK的开发三、打包aar包四、Android SDK的使用 一、Android SDK开发示例工程 本教程工程Git链接:https://gitcode.com/xiaohuihui1400/AndroidSdkExample/overview 二、Android SDK的开发 新建项目,…

Photoshop中图像美化工具的应用

Photoshop中图像美化工具的应用 Photoshop中的裁剪工具Photoshop中的修饰工具模糊工具锐化工具涂抹工具 Photoshop中的颜色调整工具减淡工具加深工具海绵工具 Photoshop中的修复工具仿制图章工具污点修复画笔工具修复画笔工具修补工具内容感知移动工具红眼工具 Photoshop中的裁…

Redis持久化主从哨兵分片集群

文章目录 1. 单点Redis的问题数据丢失问题并发能力问题故障恢复问题存储能力问题 2. Redis持久化 -> 数据丢失问题RDB持久化linux单机安装Redis步骤RDB持久化与恢复示例RDB机制RDB配置示例RDB的fork原理总结 AOF持久化AOF配置示例AOF文件重写RDB与AOF对比 3. Redis主从 ->…

智能制造uwb高精度定位系统模块,飞睿智能3厘米定位测距芯片,无人机高速传输

在科技日新月异的今天,定位技术已经渗透到我们生活的方方面面。从手机导航到自动驾驶,再到无人机定位,都离不开精准的定位系统。然而,随着应用场景的不断拓展,传统的定位技术如GPS、WiFi定位等,因其定位精度…

小摩法兴纷纷转多,看涨港股的时机来了吗?

恒生指数今日高开一度上涨89点报18520点,创近两周高。之后持续震荡下行;恒指临近中 午跌幅扩大,恒生科技指数一度跌近1.5%。截止收盘,恒生指数跌0.52%,盘面上,石油、煤炭、环保、建筑节能等板块涨幅居前&a…

VScode中js关闭烦人的ts检查

类似如下的代码在vscode 会报错,我们可以在前面添加忽略检查或者错误,如下: 但是!!!这太不优雅了!!!,js代码命名没有问题,错在ts上面,…

112、路径总和

给你二叉树的根节点 root 和一个表示目标和的整数 targetSum 。判断该树中是否存在 根节点到叶子节点 的路径,这条路径上所有节点值相加等于目标和 targetSum 。如果存在,返回 true ;否则,返回 false 。 叶子节点 是指没有子节点…

简单且高效的水域物探轨迹坐标转换程序

简单且高效的水域物探轨迹坐标转换程序 前言 水上测线的高精度定位是水域物探的难题,水域磁法、水域地震实施时常采用船舶拖拽传感器进行走航式观测,GPS仪器放在船舶上测量,造成船舶位置与传感器位置存在偏差,后期资料整理需要校…

QT-QPainter实现一个可切换的开关控件

1、效果 2、核心代码 #ifndef SWITCH_H #define SWITCH_H #include <QWidget> #include <QTimer>

Javase.String 类

String 类 【本节目标】1. String类的重要性2. 常用方法2.1 字符串构造2.2 String对象的比较2.3 字符串查找2.4 转化2.5 字符串替换2.7 字符串截取2.8 其他操作方法2.9 字符串的不可变性2.10 字符串修改 3. StringBuilder和StringBuffer3.2 面试题&#xff1a; 4. String类oj4.…

【MySQL】 -- 用户管理

1. 权限 如果我们只能使用root用户&#xff0c;这样存在安全隐患。这时&#xff0c;就需要使用MySQL的用户管理。创建出非root用户&#xff0c;限制其权限。 权限这个概念拿出来就是用来限制非root用户的。这样从技术手段上保证了数据的安全性和完整性&#xff0c;防止有人删库…

【SAP Abap】一条SQL语句实现支持报表项配置的财务报表

【SAP Abap】一条SQL语句实现支持报表项配置的财务报表 1、业务背景2、配置项特殊处理3、实现方式&#xff08;Hana Studio SQL语句&#xff09;4、实现方式&#xff08;Abap OpenSQL语句&#xff09;5、总结 1、业务背景 在财务三大报表之外&#xff0c;业务需要使用类似的科…

数据库复习——模式分解

模式分解这边主要包括无损分解和保持函数依赖的分解两种形式&#xff0c;简单整理一下。 无损分解 把一个 R R R 分成 ρ { R 1 , R 2 , ⋯ , R k } \rho \{R_1,R_2,\cdots,R_k\} ρ{R1​,R2​,⋯,Rk​}&#xff0c;然后通过自然连接 R 1 ⋈ R 2 ⋈ ⋯ ⋈ R k R_1\bowtie R…

git的远程管理与标签管理

✨前言✨ &#x1f4d8; 博客主页&#xff1a;to Keep博客主页 &#x1f646;欢迎关注&#xff0c;&#x1f44d;点赞&#xff0c;&#x1f4dd;留言评论 ⏳首发时间&#xff1a;2024年6月20日 &#x1f4e8; 博主码云地址&#xff1a;博主码云地址 &#x1f4d5;参考书籍&…

swift使用swift-protobuf协议通讯,使用指北

什么是Protobuf Protobuf&#xff08;Protocol Buffers&#xff09;协议&#x1f609; Protobuf 是一种由 Google 开发的二进制序列化格式和相关的技术&#xff0c;它用于高效地序列化和反序列化结构化数据&#xff0c;通常用于网络通信、数据存储等场景。 为什么要使用Proto…

【python】Sklearn—Cluster

参考学习来自 10种聚类算法的完整python操作示例 文章目录 聚类数据集亲和力传播——AffinityPropagation聚合聚类——AgglomerationClusteringBIRCH——Birch&#xff08;✔&#xff09;DBSCAN——DBSCANK均值——KMeansMini-Batch K-均值——MiniBatchKMeans均值漂移聚类——…

MySQL之复制(七)

复制 定制的复制方案 分离功能 许多应用都混合了在线事务处理(OLTP)和在线数据分析(OLAP)的查询。OLTP查询比较短并且是事务型的。OLAP查询则通常很大&#xff0c;也很慢&#xff0c;并且不要求绝对最新的数据。这两种查询给服务器带来的负担完全不同&#xff0c;因此它们需…

Linux系统部署Samba服务,共享文件夹给Windows

Samba服务是在Linux和UNIX系统上实现SMB协议的一个免费软件&#xff0c;由服务器及客户端程序构成。 Samba服务是连接Linux与Windows的桥梁&#xff0c;它通过实现SMB&#xff08;Server Message Block&#xff09;协议来允许跨平台的文件和打印机共享。该服务不仅支持Linux和…