机器学习:卷积介绍及代码实现卷积操作

在这里插入图片描述

传统卷积运算是将卷积核以滑动窗口的方式在输入图上滑动,当前窗口内对应元素相乘然后求和得到结果,一个窗口一个结果。相乘然后求和恰好也是向量内积的计算方式,所以可以将每个窗口内的元素拉成向量,通过向量内积进行运算,多个窗口的向量放在一起就成了矩阵,每个卷积核也拉成向量,多个卷积核的向量排在一起也成了矩阵,于是,卷积运算转化成了矩阵乘法运算。下图很好地演示了矩阵乘法的运算过程:

im2col

将卷积运算转化为矩阵乘法,从乘法和加法的运算次数上看,两者没什么差别,但是转化成矩阵后,运算时需要的数据被存在连续的内存上,这样访问速度大大提升(cache),同时,矩阵乘法有很多库提供了高效的实现方法,像BLAS、MKL等,转化成矩阵运算后可以通过这些库进行加速。

缺点呢?这是一种空间换时间的方法,消耗了更多的内存——转化的过程中数据被冗余存储。

代码实现

太久没写python代码,面试的时候居然想用c++来实现,其实肯定能实现,但是比起使用python复杂太多了,所以这里使用python中的numpy来实现。

一、滑动窗口版本实现(这个好理解)

import numpy as np# 为了简化运算,默认batch_size = 1
class my_conv(object):def __init__(self, input_data, weight_data, stride, padding = 'SAME'):self.input = np.asarray(input_data, np.float32)self.weights = np.asarray(weight_data, np.float32)self.stride = strideself.padding = paddingdef my_conv2d(self):"""self.input: c * h * w  # 输入的数据格式self.weights: c * h * w"""[c, h, w] = self.input.shape[kc, k, _] = self.weights.shape  # 这里默认卷积核的长宽相等assert c == kc  # 如果输入的channel与卷积核的channel不一致即报错output = []# 分通道卷积,最后再加起来for i in range(c):  f_map = self.input[i]kernel = self.weights[i]rs = self.compute_conv(f_map, kernel)if output == []:output = rselse:output += rsreturn output# padding和rs的宽高计算全部基于rs_h = (h - k + 2p)//s + 1def compute_conv(self, fm, kernel):[h, w] = fm.shape[k, _] = kernel.shapeif self.padding == 'SAME': # 知道rs_hw,求pad_hwrs_h = h // self.striders_w = w // self.stridepad_h = (self.stride * (rs_h - 1) + k - h) // 2pad_w = (self.stride * (rs_w - 1) + k - w) // 2elif self.padding == 'VALID': # 知道pad_hw,求rspad_h = 0pad_w = 0rs_h = (h - k) // self.stride + 1rs_w = (w - k) // self.stride + 1elif self.padding == 'FULL': # 知道pad_hw,求rs_hwpad_h = k - 1pad_w = k - 1rs_h = (h + 2 * pad_h - k) // self.stride + 1rs_w = (w + 2 * pad_w - k) // self.stride + 1padding_fm = np.zeros([h + 2 * pad_h, w + 2 * pad_w], np.float32)padding_fm[pad_h:pad_h+h, pad_w:pad_w+w] = fm  # 完成对fm的zeros paddingrs = np.zeros([rs_h, rs_w], np.float32)for i in range(rs_h):for j in range(rs_w):roi = padding_fm[i*self.stride:(i*self.stride + k), j*self.stride:(j*self.stride + k)]rs[i, j] = np.sum(roi * kernel) # np.asarray格式下的 * 是对应元素相乘return rsif __name__=='__main__':input_data = [[[1, 0, 1, 2, 1],[0, 2, 1, 0, 1],[1, 1, 0, 2, 0],[2, 2, 1, 1, 0],[2, 0, 1, 2, 0],],[[2, 0, 2, 1, 1],[0, 1, 0, 0, 2],[1, 0, 0, 2, 1],[1, 1, 2, 1, 0],[1, 0, 1, 1, 1],],]weight_data = [[[1, 0, 1],[-1, 1, 0],[0, -1, 0],],[[-1, 0, 1],[0, 0, 1],[1, 1, 1],]]conv = my_conv(input_data, weight_data, 1, 'SAME')print(conv.my_conv2d())

二、矩阵乘法版本实现

import numpy as np# 为了简化运算,默认batch_size = 1
class my_conv(object):def __init__(self, input_data, weight_data, stride, padding = 'SAME'):self.input = np.asarray(input_data, np.float32)self.weights = np.asarray(weight_data, np.float32)self.stride = strideself.padding = paddingdef my_conv2d(self):"""self.input: c * h * w  # 输入的数据格式self.weights: c * h * w"""[c, h, w] = self.input.shape[kc, k, _] = self.weights.shape  # 这里默认卷积核的长宽相等assert c == kc  # 如果输入的channel与卷积核的channel不一致即报错# rs_h与rs_w为最后输出的feature map的高与宽if self.padding == 'SAME':pad_h = (self.stride * (h - 1) + k - h) // 2pad_w = (self.stride * (w - 1) + k - w) // 2rs_h = hrs_w = welif self.padding == 'VALID':pad_h = 0pad_w = 0rs_h = (h - k) // self.stride + 1rs_w = (w - k) // self.stride + 1elif self.padding == 'FULL':pad_h = k - 1pad_w = k - 1rs_h = (h + 2 * pad_h - k) // self.stride + 1rs_w = (w + 2 * pad_w - k) // self.stride + 1# 对输入进行zeros padding,注意padding后依然是三维的pad_fm = np.zeros([c, h+2*pad_h, w+2*pad_w], np.float32)pad_fm[:, pad_h:pad_h+h, pad_w:pad_w+w] = self.input# 将输入和卷积核转化为矩阵相乘的规格mat_fm = np.zeros([rs_h*rs_w, kc*k*k], np.float32)mat_kernel = self.weightsmat_kernel.shape = (kc*k*k, 1) # 转化为列向量row = 0   for i in range(rs_h):for j in range(rs_w):roi = pad_fm[:, i*self.stride:(i*self.stride+k), j*self.stride:(j*self.stride+k)]mat_fm[row] = roi.flatten()  # 将roi扁平化,即变为行向量row += 1# 卷积的矩阵乘法实现rs = np.dot(mat_fm, mat_kernel).reshape(rs_h, rs_w) return rsif __name__=='__main__':input_data = [[[1, 0, 1, 2, 1],[0, 2, 1, 0, 1],[1, 1, 0, 2, 0],[2, 2, 1, 1, 0],[2, 0, 1, 2, 0],],[[2, 0, 2, 1, 1],[0, 1, 0, 0, 2],[1, 0, 0, 2, 1],[1, 1, 2, 1, 0],[1, 0, 1, 1, 1],],]weight_data = [[[1, 0, 1],[-1, 1, 0],[0, -1, 0],],[[-1, 0, 1],[0, 0, 1],[1, 1, 1],]]conv = my_conv(input_data, weight_data, 1, 'SAME')print(conv.my_conv2d())

参考资料

1、im2col:将卷积运算转为矩阵相乘
2、面试基础–深度学习 卷积及其代码实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/257619.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STM32控制JQ8400语音播报模块

时间记录:2024/2/7 一、JQ8400引脚介绍 标示说明ONE LINE一线操作引脚BUSY忙信号引脚,正在播放语音时输出高电平RX串口两线操作接收引脚TX串口两线操作发送引脚GND电源地引脚DC-5V电源引脚,3.3-5VDAC-RDAC输出右声道引脚DAC-LDAC输出左声道…

爬虫——ajax和selenuim总结

为什么要写这个博客呢,这个代码前面其实都有,就是结束了。明天搞个qq登录,这个就结束了。 当然也会更新小说爬取,和百度翻译,百度小姐姐的爬取,的对比爬取。总结嘛!!!加…

【运维测试】测试理论+工具总结笔记第1篇:测试理论的主要内容(已分享,附代码)

本系列文章md笔记(已分享)主要讨论测试理论测试工具相关知识。Python测试理论的主要内容,掌握软件测试的基本流程,知道软件测试的V和W模型的优缺点,掌握测试用例设计的要素,掌握等价类划分法、边界值法、因…

可视化工具:将多种数据格式转化为交互式图形展示的利器

引言 在数据驱动的时代,数据的分析和理解对于决策过程至关重要。然而,不同的数据格式和结构使得数据的解读变得复杂和困难。为了解决这个问题,一种强大的可视化工具应运而生。这个工具具有将多种数据格式(包括JSON、YAML、XML、C…

专业140+总分420+东北大学841通信专业基础考研经验东大电子信息与通信工程,真题,大纲,参考书。

今年考研顺利上岸,被东北大学通信工程录取,其中专业课841通信专业基础140,数二140,总分420,整体每门课都还是比较均衡,刚开始考研前也和大家一样,焦虑,紧张,面对考研怕失…

关于npmlink的问题

深入浅出关于Npm linl的问题 关键词: vue3报错 Uncaught TypeError: Cannot read properties of null (reading ‘isCE‘) at renderSlot npm link 无法实现热更新 我的开发环境是 “vue”: “^3.2.13” 今天在使用 rollup搭建组件库的时候我发现我的组件库不能…

模拟电子技术——基本放大电路

文章目录 前言一、三极管输入输出特性三极管放大作用三极管电流放大关系三极管的特性曲线 二、基本放大电路-电路结构与工作原理基本放大电路的构成基本放大电路放大原理三种基本放大电路比较 三、基本放大电路静态工作点什么是静态工作点?静态工作点的作用估算法分…

MySQL-----函数篇

目录 ▶ 字符串函数 ▶ 数值函数 ▶ 日期函数 ▶ 流程函数 ▶ 简介 函数是指一段可以直接被另一段程序调用的程序或代码。 ▶ 字符串函数 函数描述实例ASCII(s)返回字符串 s 的第一个字符的 ASCII 码。 返回 CustomerName 字段第一个字母的 ASCII 码: S…

FastJson、Jackson使用AOP切面进行日志打印异常

FastJson、Jackson使用AOP切面进行日志打印异常 一、概述 1、问题详情 使用FastJson、Jackson进行日志打印时分别包如下错误: 源码: //fastjon log.info("\nRequest Info :{} \n", JSON.toJSONString(requestInfo)); //jackson …

无人机概述及系统组成,无人机系统的构成

无人机的定义 无人驾驶航空器,是一架由遥控站管理(包括远程操纵或自主飞行)的航空器,也称遥控驾驶航空器,以下简称无人机。 无人机系统的定义 无人机系统,也称无人驾驶航空器系统,是指一架无人…

【MySQL/Redis】如何实现缓存一致

目录 不实用的方案 1. 先写 MySQL , 再写 Redis 2. 先写 Redis , 再写MySQL 3. 先删 Redis,再写 MySQL 实用的方案 1. 先删 Redis,再写 MySQL, 再删 Redis 2. 先写 MySQL , 再删 Redis 3. 先写MySQL,通过BinLog&#xff0…

anomalib1.0学习纪实

回顾:细分、纵深、高端、上游、积累、极致。 回顾:产品化,资本化,规模化,大干快上,小农思维必死无疑。 春节在深圳新地中央,学习anomalib1.0。 一、安装: 1、常规安装 采用的是…

【MySQL】外键约束的删除和更新总结

🌈个人主页: Aileen_0v0 🔥热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​💫个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-7niJLSFaPo0wso60 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-siz…

CentOS7.9+Kubernetes1.29.2+Docker25.0.3高可用集群二进制部署

CentOS7.9Kubernetes1.29.2Docker25.0.3高可用集群二进制部署 Kubernetes高可用集群(Kubernetes1.29.2Docker25.0.3)二进制部署二进制软件部署flannel v0.22.3网络,使用的etcd是版本3,与之前使用版本2不同。查看官方文档进行了解…

小白学习Halcon100例:如何利用动态阈值分割图像进行PCB印刷缺陷检测?

文章目录 *读入图片*关闭所有窗口*获取图片尺寸*根据图片尺寸打开一个窗口*在窗口中显示图片* 缺陷检测开始 ...*1.开运算 使用选定的遮罩执行灰度值开运算。*2.闭运算 使用选定的遮罩执行灰度值关闭运算*3.动态阈值分割 使用局部阈值分割图像显示结果*显示原图*设置颜色为红色…

一文搞懂“什么是双亲委派”

文章目录 双亲委派介绍类加载器介绍类加载流程验证自定义类加载器为什么要设计这种机制 提前声明:以下介绍都是基于jdk9之前版本的双亲委派机制,jdk9及之后版本双亲委派会有变化,不在本篇介绍。 双亲委派介绍 双亲委派机制(Pare…

前端秘法进阶篇之事件循环

目录 一.浏览器的进程模型 1.进程 2.线程 二.浏览器的进程和线程 1. 浏览器进程 2. 网络进程 3. 渲染进程 三.渲染主线程 四.异步 五.优先级 1. 延时队列: 2.交互队列: 3.微队列: 六.JS 的事件循环 附加:JS 中的计时器能做到精…

MATLAB实现LSTM时间序列预测

LSTM模型可以在一定程度上学习和预测非平稳的时间序列,其具有强大的记忆和非线性建模能力,可以捕捉到时间序列中的复杂模式和趋势[4]。在这种情况下,LSTM模型可能会自动学习到时间序列的非平稳性,并在预测中进行适当的调整。其作为…

微信小程序介绍、账号申请、开发者工具目录结构详解及小程序配置

目录 一、微信小程序介绍 1.什么是小程序? 2.小程序可以干什么? 3.微信小程序特点 二、账号申请 1.账号注册 2.测试号申请 三、安装开发工具 四、开发小程序 五、目录结构 JSON 配置 小程序配置 app.json 工具配置 project.config.json 页…

云备份项目:在云端保护您的数据【二、开发】

☘️过度的信息对一个过着充实生活的人来说,是一种不必要的负担☘️ 文章目录 前言工具类实现文件实用工具类代码实现 Json实用工具类代码实现 服务端单例配置类系统配置信息单例配置类 数据管理类数据信息数据管理 热点管理类业务处理类 客户端数据管理类文件备份类…