NumPy中的transpose()函数用于对数组进行转置:
如果函数中不传递任何参数,它将进行标准的矩阵转置;
如果传递了一个轴序列,NumPy将按照这个序列重新排列轴。
二维的转置很好理解,就是线性代数中的矩阵转置。但高纬度的转置就有点抽象了,反正我一开始确实没太懂它是怎么进行转置/轴变换的。
numpy官方的文档没什么用,就不贴出来了,想看的自己去看。
下面是GPT的解释:
在 NumPy 中,多维数组的 `transpose` 通过重新排列步长(`strides`)来实现轴的重新排列。以下是具体的操作流程:
1. **原始数据和步长**:NumPy 数组的底层数据是连续存储的,每个维度的步长决定了沿该维度移动一个元素的字节数。
2. **获取形状和步长**:`transpose` 函数首先获取数组的当前形状和步长。
3. **根据新轴顺序重新排列**:假设我们有一个 3D 数组 `A`,形状是 `(2, 3, 4)`,步长是 `(48, 16, 4)`。如果新轴顺序是 `(1, 2, 0)`,则 `transpose` 会按顺序排列新步长 `(16, 4, 48)` 和新形状 `(3, 4, 2)`。
4. **创建视图而非复制数据**:`transpose` 会根据新步长和新形状生成一个新的视图,这样读取时将按新的轴顺序访问数据,但数据在内存中的实际位置没有改变。
这种操作方式能高效地实现多维数组的轴变换。
在 NumPy 中,步长(`strides`)表示在每个维度上移动一个元素所需的字节数。对于多维数组,步长帮助确定如何在内存中访问元素。比如,一个数组的步长指示从一个元素跳到下一个元素时,应该跳过多少字节。步长在执行诸如转置(`transpose`)等操作时非常重要,因为它决定了如何根据新的轴顺序访问原始数据而不需复制。
简单来说,轴重新排列就是把原来的轴的读取顺序换成了新的轴的读取顺序,然后又改回了原来(0,1,2,3)顺序的表现。可能还是有点抽象,我直接举个具体的例子来说。
a = np.array([5,15,8,41,39,30,39,18,23,42,25,13,15,6,36,25,14,4,42,20,44,3,19,7,24,36,45,38,14,47,23,42,18,31,8,2,20,21,41,8,8,2,11,33,32,31,32,47]).reshape(2,3,2,4)
a是一个四维的矩阵。打印出来的a是这样的
[[[[ 5 15 8 41][39 30 39 18]][[23 42 25 13][15 6 36 25]][[14 4 42 20][44 3 19 7]]][[[24 36 45 38][14 47 23 42]][[18 31 8 2][20 21 41 8]][[ 8 2 11 33][32 31 32 47]]]]
那么,a.transpose是什么样的呢?
[[[[ 5, 24],[23, 18],[14, 8]],[[39, 14],[15, 20],[44, 32]]],[[[15, 36],[42, 31],[ 4, 2]],[[30, 47],[ 6, 21],[ 3, 31]]],[[[ 8, 45],[25, 8],[42, 11]],[[39, 23],[36, 41],[19, 32]]],[[[41, 38],[13, 2],[20, 33]],[[18, 42],[25, 8],[ 7, 47]]]]
问题来了,这个转置后的a的轴的顺序是什么样的?答案是(3,2,1,0)
不过如果我们在不知道答案的情况下,怎么看出来这个答案呢?
首先,我们以(0,0,0,0)为起点往四根轴看。
3号轴 [5 15 8 41]
2号轴 [5,39]
1号轴 [5 23 14]
0号轴 [5 24]
这应该很容易能看出来。如果不知道轴怎么排的,我在文末有补充。
然后我们看下转置后的4根轴
3号轴 [5 24]
2号轴 [5 23 14]
1号轴 [5 39]
0号轴 [5 15 8 41]
也就是说原来的3号轴现在变成了0号,2号变成了1号,1号变成了2号,0号变成了3号。所以答案是(3,2,1,0)。可以验证:
所以如果我们需要将轴重新排列,也可以用同样的方法进行,只要将主要的几根轴变完了,其他元素按相对位置填进去就可以了。
下面是我用C++实现的transpose。虽然我感觉也许可能会更难理解?只有少量的必要的注释,结合前文自己理解吧,这注释确实不太好写
#include <bits/stdc++.h>
using namespace std;
void printArray(int *arr, const int len, const int dim, int *dims, int *axis)
{
// printf("dims:");for (int i=0; i<dim; i++) printf("%d%c", dims[i], i==dim-1?'\n':' ');int sufMul[dim]; //后缀乘积 用于计算每个维度的步长strideint idx=0;sufMul[dim-1]=1;for (int i=dim-2; i>=0; i--){sufMul[i] = sufMul[i+1]*dims[i+1];}
// printf("sufMul:");for (int i=0; i<dim; i++) printf("%d%c", sufMul[i], i==dim-1?'\n':' ');int stride[dim]; // 步长stride,即沿某一维度走一步,在底层的一维数组移动了多少步 for (int i=0; i<dim; ++i){stride[i] = sufMul[axis[i]];}
// printf("stride:");for (int i=0; i<dim; i++) printf("%d%c", stride[i], i==dim-1?'\n':' ');int newDim[dim]; // 轴变换后,新的每个轴的长度 for (int i=0; i<dim; ++i){newDim[i] = dims[axis[i]];}
// printf("newDim:");for (int i=0; i<dim; i++) printf("%d%c", newDim[i], i==dim-1?'\n':' ');int newSufMul[dim]; // 轴变换后的后缀乘积,只是用于格式打印输出换行 newSufMul[dim-1]=1;for (int i=dim-2; i>=0; i--){newSufMul[i] = newSufMul[i+1]*newDim[i+1];}
// printf("sufMul:");for (int i=0; i<dim; i++) printf("%d%c", newSufMul[i], i==dim-1?'\n':' ');idx = 0; // idx表示输出到第几个元素 while (idx < len){int index=0, tmp=idx, i=dim; // index表示该元素在arr中的下标 while (i--){index += (tmp%newDim[i]) * stride[i]; // tmp%newDim[i] 表示在某一维度的下标// vec.push_back(tmp%newDim[i]) // idx新轴序下的坐标 tmp /= newDim[i];}printf("%d,", arr[index]);
// printf("index=%d, idx=%d\n", index, idx);for (int t=0; t<dim-1; ++t){
// printf("**t=%d, sumMul[t]=%d, dix+1=%d**", t, sufMul[t], idx);if (((idx+1) % newSufMul[t]) == 0){for (int i=0; i<dim-t-1; i++) printf("\n");break;}}idx++;}
}
int main()
{srand(time(0));int dim;printf("input dimension:");scanf("%d", &dim);int dims[dim];int len=1;printf("input shape(split by space):");for (int i=0; i<dim; ++i){scanf("%d", &dims[i]);len *= dims[i];}printf("input %d numbers(split by space):", len);int arr[len];for (int i=0; i<len; ++i){
// scanf("%d", &arr[i]);arr[i] = rand()%50;}for (int i=0; i<len; i++) printf("%d%c", arr[i], i==len-1?'\n':',');int axis[dim];for (int i=0; i<dim; i++) axis[i]=i;printArray(arr, len, dim, dims, axis); //打印原始的形状 printf("input axis[0-%d](split by space):", dim-1);for (int i=0; i<dim; i++){scanf("%d", &axis[i]);}printArray(arr, len, dim, dims, axis); // 打印重排轴之后的形状
}
或许我应该放个python的实现会更合适一点?这里就先挖个坑下次再填吧
如果不太明白轴的顺序,我简单说一下。最先填充的方向轴序号最大,最后填充的方向是0轴。
就像一维是从左到右填充的,最后填充的方向就是从左到右的,所以从左到右就是0号轴。
二维是先从左到右,然后从上到下填充的,最先是从左到右的,所以从左到右是1号轴。一行填完之后从上往下填,所以从上到下是0号轴。
同理,三维的填充顺序就是先填完一层二维的,然后从前往后填充,所以前后方向是0号轴,每一层的填充顺序与二维一致,所以二维的轴的编号加个一就是三维里的编号了。
更高维的也是一样的道理,新的方向是0轴,原来的轴就依次加一。