想在 pytorch 中实现一个张量变换
输入是
输出是
这是我在草稿纸上演算的结果,想在 pytorch 中高效实现。于是求助 chatGPT。
一开始我用语言描述了一下我想实现的功能,chatGPT 给出了结果,看着是对的,不过漏掉了中间维度。
我不想用 for 循环,因为它太慢了。
这个代码我试了,已经和我想要的答案很接近了,不过可能还需要变换一步。
结果不对。
换个思路,直接告诉它输入输出。
终于给出了想要的答案。原来这个操作只需要一步 .transpose(0, 1)
。我对 pytorch 中这么多张量变换还不熟,让我自己想可能要很久了。