想查看其他题的真题及题解的同学可以前往查看:CCF-CSP真题附题解大全
试题编号: | 202305-2 |
试题名称: | 矩阵运算 |
时间限制: | 5.0s |
内存限制: | 512.0MB |
问题描述: | 题目背景Softmax(Q×KTd)×V 是 Transformer 中注意力模块的核心算式,其中 Q、K 和 V 均是 n 行 d 列的矩阵,KT 表示矩阵 K 的转置,× 表示矩阵乘法。 问题描述为了方便计算,顿顿同学将 Softmax 简化为了点乘一个大小为 n 的一维向量 W: 现给出矩阵 Q、K 和 V 和向量 W,试计算顿顿按简化的算式计算的结果。 输入格式从标准输入读入数据。 输入的第一行包含空格分隔的两个正整数 n 和 d,表示矩阵的大小。 接下来依次输入矩阵 Q、K 和 V。每个矩阵输入 n 行,每行包含空格分隔的 d 个整数,其中第 i 行的第 j 个数对应矩阵的第 i 行、第 j 列。 最后一行输入 n 个整数,表示向量 W。 输出格式输出到标准输出中。 输出共 n 行,每行包含空格分隔的 d 个整数,表示计算的结果。 样例输入
样例输出
子任务70 的测试数据满足:n≤100 且 d≤10;输入矩阵、向量中的元素均为整数,且绝对值均不超过 30。 全部的测试数据满足:n≤104 且 d≤20;输入矩阵、向量中的元素均为整数,且绝对值均不超过 1000。 提示请谨慎评估矩阵乘法运算后的数值范围,并使用适当数据类型存储矩阵中的整数。 |
真题来源:矩阵运算
感兴趣的同学可以如此编码进去进行练习提交
思路讲解:
这道题也不难,再纸上推一下规律就能找到循环去计算的规律。这道题的重点在于时间复杂度,如果先算QK矩阵相乘,会得到n * n的矩阵,会显示超时,所以要先算后面两个矩阵,时间复杂度是可以过的。
python满分题解:
n, d = map(int, input().split())
Q = [[i for i in map(int, input().split())] for j in range(n)]
K = [[i for i in map(int, input().split())] for j in range(n)]
V = [[i for i in map(int, input().split())] for j in range(n)]
W = [i for i in map(int, input().split())]
tmp = []
ans = []# 计算 K的转置 * V = tmp
for i in range(d):tmp.append([])for j in range(d):tmp[i].append(0)for k in range(n):tmp[i][j] += K[k][i]*V[k][j]# 计算 Q * tmp = ans
for i in range(n):ans.append([])for j in range(d):ans[i].append(0)for k in range(d):ans[i][j] += Q[i][k]*tmp[k][j]ans[i][j] *= W[i]for i in range(n):a = []for j in range(d):a.append(ans[i][j])print(*a)
运行结果:
c++满分题解:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 10010, D = 30;
LL tmp[D][D], ans[N][N];
int n, d;
int Q[N][D], K[N][D], V[N][D], W[N];
int main()
{cin >> n >> d;for (int i = 1; i <= n; i ++)for (int j = 1; j <= d; j ++)cin >> Q[i][j];for (int i = 1; i <= n; i ++)for (int j = 1; j <= d; j ++)cin >> K[i][j];for (int i = 1; i <= n; i ++)for (int j = 1; j <= d; j ++)cin >> V[i][j];for (int i = 1; i <= n; i ++) cin >> W[i];// 计算 K的转置 * V = tmpfor (int i = 1; i <= d; i ++)for (int j = 1; j <= d; j ++)for (int k = 1; k <= n; k ++)tmp[i][j] += K[k][i] * V[k][j];// 计算 Q * tmp = ansfor (int i = 1; i <= n; i ++)for (int j = 1; j <= d; j ++){for (int k = 1; k <= d; k ++)ans[i][j] += Q[i][k] * tmp[k][j];ans[i][j] *= (LL) W[i];}for (int i = 1; i <= n; i ++){for (int j = 1; j <= d; j ++)cout << ans[i][j] << " ";cout << endl;}return 0;
}
运行结果: