前言
最近在学习迁移学习,ADDA算法,由于嫌自己写麻烦,准备先跑通别人的代码。
代码名称:pytorch-adda-master
博客:https://www.cnblogs.com/BlairGrowing/p/17020378.html
github地址:https://github.com/corenel/pytorch-adda
源代码的配置环境为 python3.6 PyTorch 0.2.0
为了方便多个代码放在一起,我直接用我的环境
python3.6,pytorch 1.13.0
这导致了代码报错非常多,经过修改之后终于可以运行。。。
由于修改之处非常多,所以将修改之处放在CSDN里面,以便自己观看
报错一
刚开始运行报错
urllib.error.URLError: 「urlopen error [Errno 11004] getaddrinfo failed」
参考链接,修改DNS就行了,具体修改步骤在这里
报错二
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
这是因为都是mnist数据集的灰度图片需要转变为RGB图片,也就是通道数需要从1变成3,参考链接
修改方法:将mnist.py、usps.py中的
pre_process = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=params.dataset_mean,std=params.dataset_std)])
修改成:
pre_process = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
报错三
IndexError: invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number
这应该是因为python、pytorch版本的不同导致的,参考链接
具体修改就是将报错的.data[0]
修改成 .item()
报错四
RuntimeError: result type Float can't be cast to the desired output type Long
数据类型算法有问题,参考链接
将 acc /= len(data_loader.dataset)
修改成
acc = acc/len(data_loader.dataset)