Strassen矩阵乘法——C++

【题目描述】

根据课本“Strassen矩阵乘法”的基本原理,设计并实现一个矩阵快速乘法的工具。并演示至少10000维的矩阵快速乘法对比样例。

【功能要求】
  1. 实现普通矩阵乘法算法和“Strassen矩阵乘法”算法
  2. 对相同的矩阵,分别用普通矩阵乘法算法,“Strassen矩阵乘法”算法和Matlab进行运算,比较时间差异(多次计算求平均值);
【选做功能】
  1. 突破2n的维数限制,能够对其他维数的矩阵进行运算。
  2. 方法不限,实现尽可能快的矩阵计算。
  3. 其他可扩展的功能。
【实验过程】
  1. 首先我们先设计实现普通的矩阵乘法,对于两个矩阵,普通的矩阵相乘做法是:遍历三层矩阵计算:我们设A和B是2个n*n的矩阵,它们的乘积AB同样是一个n*n矩阵。 A和B的乘积矩阵C中元素C[i][j]定义为:

比如,我们以下列的例子作为参考:对于它们的乘积,我们应该使用公式:

所以,从上述的公式中,我们知道如果使用这正常的矩阵相乘,由此得出:

所以我们的计算的时间复杂度是O(n^3)。

计算的代码为:对于数据的输入,我们使用的是将数据存储在data.txt中,每次去读取这个文件中的矩阵规模n和矩阵 arr1[][] 和 arr2[][]

#include<iostream>
#include<time.h>
#include "fstream"
void Multiply(int pInt, long long **pInt1, long long **pInt2, long long **pInt3);void out(int pInt, long long **pInt1);using namespace std;int main() {system("chcp 65001 > nul");std::ios::sync_with_stdio(false);std::cin.tie(0);
//    c++加速流int M;fstream f;f.open("data.txt",ios::in);f >> M;int length = M;if (M % 2 != 0) //若M为奇数,则补零{length++;}long long **A = new long long *[length];long long **B = new long long *[length];long long **C = new long long *[length];for (int i = 0; i < length; i++) {A[i] = new long long[length];B[i] = new long long[length];C[i] = new long long[length];}for (int i = 0; i < M; i++) {for (int j = 0; j < M; j++)f >> A[i][j];}for (int i = 0; i < M; i++) {for (int j = 0; j < M; j++) {C[i][j] = 0;f >> B[i][j];}}clock_t start;clock_t end;start = clock();Multiply(M, A, B, C);end = clock();cout <<"当数据量n为"<<M<<"时,耗费的时间:"<< (end - start) << "ms" << endl;  //输出时间(单位:ms)
//    out(M, C);}void out(int n, long long **arr) {for (int i = 0; i < n; i++) {for (int j = 0; j < n; j++) {cout << arr[i][j] << " ";}cout << endl;}
}void Multiply(int n, long long **A, long long **B, long long **C) {for (int i = 0; i < n; i++) {for (int j = 0; j < n; j++) {for (int k = 0; k < n; k++) {C[i][j] += A[i][k] * B[k][j];}}}
}

2. 观察这个算法之后,我们发现,在计算矩阵相乘的时候,时间复杂度达到了O(n^3)。 如果n过于大的话,需要计算很久才会出结果。对于10000 × 10000的数据量二维数组存储的话会爆栈。因此我们使用更加高效的算法Strassen矩阵乘法。

3. 1969年,Volker Strassen提出了第一个算法时间复杂度低于O(n^3的矩阵乘法算法,算法复杂度为,还是很接近3,因此StrassenStrassen 算法只有在对于维数比较大的矩阵,性能上才可能有优势,可以减少很多乘法计算。StrassenStrassen 算法证明了矩阵乘法存在时间复杂度低于O(n^3的算法的存在,后续学者不断研究发现新的更快的算法,截止目前时间复杂度最低的矩阵乘法算法是 Coppersmith-Winograd 方法的一种扩展方法,其算法复杂度为

4. Strassen原理详解:

假设矩阵 A 和矩阵 B 都是  N×N (N = 2^n)的方矩阵,求 C = AB,如下所示:

其中,8T(2n​) 表示 8 次矩阵乘法,而且相乘的矩阵规模降到了 2n​。

O()表示 4 次矩阵加法的时间复杂度以及合并矩阵 C 的时间复杂度。

最终可计算得到  T(n)= O()。

可以看出每次递归操作都需要 8 次矩阵相乘,而这正是瓶颈的来源。相比加法,矩阵乘法是非常慢的,于是减少矩阵相乘的次数就显得尤为重要。Strassen算法的主要目的其实也是从这个角度出发的,目的就是减少乘法次数,降低时间复杂度。

5. Strassen的实现步骤:

① 对于上述的A、B、C三个矩阵进行分解,分解花费的时间复杂度是O(1)

② 然后我们创建如下的10个 ×  的矩阵 S1 ,S2 ,S3 …… S10 ,花费的时间复杂大约是O(

③ 接下来递归计算七个矩阵P1 ,P2 ,P3 …… P7 每个P都是 n2 × n2 的矩阵。

④ 接着通过Pi 来计算C11 C12 C21 C22 ,花费的时间为O()。

这样就相对减少了一些时间复杂度。代码如下:

#include <iostream>
#include <time.h>
#include <fstream>
void out(int m, int **pInt);using namespace std;void subMatrix(int l, long long **m, long long **n, long long **ans) {for (int i = 0; i < l; i++) {for (int j = 0; j < l; j++) {ans[i][j] = m[i][j] - n[i][j];}}
}void addMatrix(int l, long long **m, long long **n, long long **ans) //两矩阵加法
{for (int i = 0; i < l; i++) {for (int j = 0; j < l; j++) {ans[i][j] = m[i][j] + n[i][j];}}
}void multiMatrix(int l, long long **m, long long **n, long long **ans) {for (int i = 0; i < l; i++) {for (int j = 0; j < l; j++) {ans[i][j] = 0;for (int k = 0; k < l; k++) {ans[i][j] += m[i][k] * n[k][j];}}}
}void Strassen(int M, long long **A, long long **B, long long **C) {int len = M / 2;long long **A11 = new long long  *[len];long long **A12 = new long long  *[len];long long **A21 = new long long  *[len];long long **A22 = new long long  *[len];long long **B11 = new long long  *[len];long long **B12 = new long long  *[len];long long **B21 = new long long  *[len];long long **B22 = new long long  *[len];long long **C11 = new long long  *[len];long long **C12 = new long long  *[len];long long **C21 = new long long  *[len];long long **C22 = new long long  *[len];long long **P1 = new  long long *[len];long long **P2 = new  long long *[len];long long **P3 = new  long long *[len];long long **P4 = new  long long *[len];long long **P5 = new  long long *[len];long long **P6 = new  long long *[len];long long **P7 = new  long long *[len];long long **AR = new long long *[len];long long **BR = new long long *[len];for (int i = 0; i < len; i++) {A11[i] = new long long [len];A12[i] = new long long [len];A21[i] = new long long [len];A22[i] = new long long [len];B11[i] = new long long [len];B12[i] = new long long [len];B21[i] = new long long [len];B22[i] = new long long [len];C11[i] = new long long [len];C12[i] = new long long [len];C21[i] = new long long [len];C22[i] = new long long [len];P1[i] = new  long long [len];P2[i] = new  long long [len];P3[i] = new  long long [len];P4[i] = new  long long [len];P5[i] = new  long long [len];P6[i] = new  long long [len];P7[i] = new  long long [len];AR[i] = new  long long [len];BR[i] = new  long long [len];}for (int i = 0; i < len; i++) {for (int j = 0; j < len; j++) {A11[i][j] = A[i][j];A12[i][j] = A[i][j + len];A21[i][j] = A[i + len][j];A22[i][j] = A[i + len][j + len];B11[i][j] = B[i][j];B12[i][j] = B[i][j + len];B21[i][j] = B[i + len][j];B22[i][j] = B[i + len][j + len];}}addMatrix(len, A11, A22, AR);addMatrix(len, B11, B22, BR);multiMatrix(len, AR, BR, P1);addMatrix(len, A21, A22, AR);multiMatrix(len, AR, B11, P2);subMatrix(len, B12, B22, BR);multiMatrix(len, A11, BR, P3);subMatrix(len, B21, B11, BR);multiMatrix(len, A22, BR, P4);addMatrix(len, A11, A12, AR);multiMatrix(len, AR, B22, P5);subMatrix(len, A21, A11, AR);addMatrix(len, B11, B12, BR);multiMatrix(len, AR, BR, P6);subMatrix(len, A12, A22, AR);addMatrix(len, B21, B22, BR);multiMatrix(len, AR, BR, P7);addMatrix(len, P1, P4, AR);subMatrix(len, P7, P5, BR);addMatrix(len, AR, BR, C11);addMatrix(len, P3, P5, C12);addMatrix(len, P2, P4, C21);addMatrix(len, P1, P3, AR);subMatrix(len, P6, P2, BR);addMatrix(len, AR, BR, C22);for (int i = 0; i < len; i++) {for (int j = 0; j < len; j++) {C[i][j] = C11[i][j];C[i][j + len] = C12[i][j];C[i + len][j] = C21[i][j];C[i + len][j + len] = C22[i][j];}}
}int main() {system("chcp 65001 > nul");std::ios::sync_with_stdio(false);std::cin.tie(0);
//    c++加速流int M;fstream f;f.open("data.txt",ios::in);f >> M;int length = M;if (M % 2 != 0) //若M为奇数,则补零{length++;}long long **A = new long long *[length];long long **B = new long long *[length];long long **C = new long long *[length];for (int i = 0; i < length; i++) {A[i] = new long long [length];B[i] = new long long [length];C[i] = new long long [length];}for (int i = 0; i < M; i++) {for (int j = 0; j < M; j++)f >> A[i][j];}for (int i = 0; i < M; i++) {for (int j = 0; j < M; j++) {f >> B[i][j];}}if (length > M) {for (int i = 0; i < length; i++) {A[i][M] = 0;A[M][i] = 0;B[i][M] = 0;B[M][i] = 0;}}clock_t start;clock_t end;start = clock();Strassen(length, A, B, C);end = clock();cout <<"当数据量n为"<<M<<"时,耗费的时间:"<< (end - start) << "ms" << endl;  //输出时间(单位:ms)
// 输出
//    out(M,C);return 0;
}void out(int M, int **C) {for (int i = 0; i < M; i++){for (int j = 0; j < M; j++){cout << C[i][j] << " \n"[j == M - 1];}}
}

接着,我们通过改变数据量的大小,来比较这两个算法的耗时。

对于测试数据的生成,我们使用makeData.cpp来生成并保存到文件data.txt。代码如下:

//简单的随机制造数据
#include<iostream>
#include <ctime>
#include "stdlib.h"
#include "fstream"
using namespace std;
// 左闭右闭区间
int getRand(int min, int max) {return (rand() % (max - min + 1)) + min;
}int main() {int n;cin >> n;fstream f;f.open("data.txt", ios::out);f << n << endl;srand(time(0));for (int i = 0; i < 2 * n; i++) {for (int j = 0; j < n; j++) {f << getRand(0, 10) << " ";}f << endl;}f.close();return 0;
}

我们使用了上述的矩阵生成代码,随机创建了10000×10000大小的矩阵进行测试,如下图所示:

计算得到结果:

我们再使用Matlab来计算一下两个矩阵相乘的耗时:

统计得到:(其中的数据都是由3次统计求平均值的方式得来的。)

数据量

10

50

100

500

1000

2000

3000

普通

0

0.7

17.5

1644.667

20423.67

209913.5

779112

Strassen

0

2.5

11.5

1426

11692.67

90211.33

521274

matlab

0

0.185

0.328

2.926

16.38

237.273

407.674

【实验结论】

最后,比较得出结论:

1. 在矩阵规模较小的情况下,(例如 n<64),普通的矩阵相乘算法表现更优,耗时更短。

2. 当矩阵规模较大时,Strassen算法表现更优,耗时更短。因为在矩阵规模较大时,Strassen算法所需的递归次数相对较少,而且该算法每一次递归所做的运算规模较小,这些都有利于提高运算效率。

3. 在Matlab中,可以使用自带的矩阵乘法函数*来进行矩阵相乘运算,该函数会根据矩阵规模和计算机硬件等情况自动选择最优算法进行计算。因此,在实际应用中,建议使用内置的矩阵乘法函数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/315800.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言:数据结构(单链表)

目录 1. 链表的概念及结构2. 实现单链表3. 链表的分类 1. 链表的概念及结构 概念&#xff1a;链表是一种物理存储结构上非连续、非顺序的存储结构&#xff0c;数据元素的逻辑顺序是通过链表的指针链接次序实现的。 链表的结构跟火车车厢相似&#xff0c;淡季时车次的车厢会相应…

ULTIMATE VOCAL REMOVER V5 for Mac:专业人声消除软件

ULTIMATE VOCAL REMOVER V5 for Mac是一款专为Mac用户设计的人声消除软件&#xff0c;它凭借强大的功能和卓越的性能&#xff0c;在音乐制作和后期处理领域崭露头角。 ULTIMATE VOCAL REMOVER V5 for Mac v5.6激活版下载 这款软件基于深度神经网络&#xff0c;通过先进的训练模…

Windows使用bat远程操作Linux并执行命令

背景&#xff1a;让客户可以简单在Windows中能自己执行 Linux中的脚本&#xff0c;傻瓜式操作&#xff01; 方法&#xff1a;做一个简单的bat脚本&#xff01;能远程连接到Linux&#xff0c;并执行Linux命令&#xff01;客户双击就能使用&#xff01; 1、原先上网查询到使用P…

hadoop命令

hadoop命令 目录 hadoop命令 1.查看文件下面有哪些文件和目录 2.获取文件信息 查看文件内容 3.创建一个文件夹 4.剪切 1&#xff09;从本地hadoop剪切到hdfs并上传到hdfs 2&#xff09;剪切 从hdfs剪切到本地hadoop目录上 5.删除 1&#xff09;递归删除 2&#xff0…

AI Agent新对决:LangGraph与AutoGen的技术角力

AI Agent变革未来&#xff0c;LangGraph对抗AutoGen ©作者|Blaze 来源|神州问学 引言 比尔.盖茨曾在他的博客上发表一篇文章&#xff1a;《AI is about to completely change how you use computers》。在文章中&#xff0c;比尔盖茨探讨AI Agent对我们未来生活的巨大影…

windows电脑改造为linux

有个大学用的旧笔记本电脑没啥用了&#xff0c;决定把它改成linux搭一个服务器&#xff1b; 一、linux安装盘制作 首先要有一个大于8G的U盘&#xff0c;然后去下载需要的linux系统镜像&#xff0c;我下的是ubuntu&#xff0c;这里自选版本 https://cn.ubuntu.com/download/d…

请求响应案例-员工管理系统

准备工作 1、在pom.xml文件中引入dom4j依赖&#xff0c;解析xml文件。如果该依赖爆红&#xff0c;那么刷新以下就可以。 <!-- 解析XML --><dependency><groupId>org.dom4j</groupId><artifactId>dom4j</artifactId><version>2.1.3…

自然语言处理: 第三十章Hugging face使用指南之——trainer

原文连接: Trainer (huggingface.co) 最近在用HF的transformer库自己做训练&#xff0c;所以用着了transformers.Trainer&#xff0c;这里记录下用法 基本参数 class transformers.Trainer( model: Union None,args: TrainingArguments None,data_collator: Optional Non…

debian和ubuntu的核心系统和系统命令的区别

Debian和Ubuntu虽然有很深的渊源&#xff0c;都是基于Debian的发行版&#xff0c;但它们在核心系统和系统命令上还是有一些差别的。以下是一些主要的不同之处&#xff1a; 1. 发布周期&#xff1a; - Debian&#xff1a; Debian项目采用滚动发布模型&#xff0c;持续更新&a…

[论文笔记]GAUSSIAN ERROR LINEAR UNITS (GELUS)

引言 今天来看一下GELU的原始论文。 作者提出了GELU(Gaussian Error Linear Unit,高斯误差线性单元)非线性激活函数&#xff1a; GELU x Φ ( x ) \text{GELU} x\Phi(x) GELUxΦ(x)&#xff0c;其中 Φ ( x ) \Phi(x) Φ(x)​是标准高斯累积分布函数。与ReLU激活函数通过输入…

Python 爬虫如何配置代理 IP (Py 采集)

在Python中配置代理IP&#xff0c;可以通过设置requests库的proxies参数来实现。以下是一个示例&#xff1a; import requests# 则立可以获取稳定代理Ip&#xff1a;https://www.kuaidaili.com/?refrg3jlsko0ymg # 推荐使用私密动态 IP proxies {"http": "ht…

以太网交换机自学习与转发帧

自学习算法:每次转发帧前先将当前MAC地址以及对应的接口好存入到帧交换表中

ios CI/CD 持续集成 组件化专题一 iOS 将图片打包成bundle

一、 创建 选择 macos 下的Bundle 二 、取名点击下一步 三、Base SDK 选择ios 四 、Build Active Architecture Only 五、Installation后面的内容删除 六、Skip Install 选择NO 七、Strip Debug Symbols During Copy 中"Release"项设置为 "YES" 八、COM…

区块链基础——区块链应用架构概览

目录 区块链应用架构概览&#xff1a; 1、区块链技术回顾 1.1、以太坊结点结构 1.2、多种应用场景 2、区块链应用架构概览 2.1、传统的Web2 应用程序架构 2.2、Web3 应用程序架构——最简架构 2.3、Web3 应用程序架构——前端web3.js ether.js 2.4、Web3 应用程序架构—…

Android Widget开发代码示例详细说明

因为AppWidgetProvider扩展自BroadcastReceiver, 所以你不能保证回调函数完成调用后&#xff0c;AppWidgetProvider还在继续运行。 a. AppWidgetProvider 的实现 /*** Copyright(C):教育电子有限公司 * Project Name: NineSync* Filename: SynWidgetProvider.java * Author(S…

逆向案例三十——webpack登录某游戏

网址&#xff1a;aHR0cHM6Ly93d3cuZ205OS5jb20v 步骤&#xff1a; 进行抓包分析&#xff0c;找到登录接口&#xff0c;发现密码有加密 跟栈分析&#xff0c;从第三个栈进入&#xff0c;打上断点&#xff0c;再次点击登录 明显找到password,它由o赋值&#xff0c;o由a.encode(…

gitee / github 配置git, 实现免密码登录

文章目录 怎么配置公钥和私钥验证配置成功问题 怎么配置公钥和私钥 以下内容参考自 github ssh 配置&#xff0c;gitee的配置也是一样的&#xff1b; 粘贴以下文本&#xff0c;将示例中使用的电子邮件替换为 GitHub 电子邮件地址。 ssh-keygen -t ed25519 -C "your_emai…

R语言的基本图形

一&#xff0c;条形图 安装包 install.packages("vcd") 绘制简单的条形图 barplot(c(1,2,4,5,6,3)) 水平条形图 barplot(c(1,2,4,5,6,3),horiz TRUE) 堆砌条形图 > d1<-c("Placebo","Treated") > d2<-c("None",&qu…

【Flink入门修炼】2-3 Flink Checkpoint 原理机制

如果让你来做一个有状态流式应用的故障恢复&#xff0c;你会如何来做呢&#xff1f; 单机和多机会遇到什么不同的问题&#xff1f; Flink Checkpoint 是做什么用的&#xff1f;原理是什么&#xff1f; 一、什么是 Checkpoint&#xff1f; Checkpoint 是对当前运行状态的完整记…

springboot 集成 activemq

文章目录 一&#xff1a;说明二&#xff1a;e-car项目配置1 引入activemq依赖2 application启动类配置消息监听3 application.yml配置4 MQConfig.java 配置类5 ecar 项目中的监听6 junit 发送消息 三&#xff1a;tcm-chatgpt项目配置5 MQListener.java 监听消息 三 测试启动act…