GitHub: zhanghr2001/PromptTA: Source-free Domain Generalization
Paper: [2409.14163] PromptTA: Prompt-driven Text Adapter for Source-free Domain Generalization
My annotated copy: PromptTA: Prompt-driven Text Adapter for Source-free Domain Generalization https://github.com/Unchanged-Originality/Unchanged-Originality/blob/main/Prompt-driven%20Text%20Adapter%20for%20Source-free%20Domain%20Generalization.pdf
[Figure: PromptTA flowchart]
This paper is mainly inspired by Domain-Unified Prompt Representations for Source-Free Domain Generalization and PromptStyler: Prompt-driven Style Generation for Source-free Domain Generalization.
GitHub: my annotated reading notes on these two papers
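Reproduction results (average top-1 accuracy, %):
Backbone | PACS | VLCS | OfficeHome | DomainNet
ResNet-50 | 94.0 | 83.3 | 74.1 | -
ViT-B/16 | 97.4 | 84.9 | 82.9 | -
Per-domain breakdown: each line gives the target domain, the top-1 accuracy (%) of three runs, and their mean.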
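PACS (A = Art painting, C = Cartoon, P = Photo, S = Sketch):
ResNet-50:
A: 94.4, 94.4, 94.9 (mean 94.733)
C: 95.3, 95.2, 95.2 (mean 95.233)
P: 99.5, 99.5, 99.5 (mean 99.5)
S: 86.5, 86.4, 86.7 (mean 86.567)
Average top-1 accuracy: 94.0
ViT-B/16:
A: 97.9, 97.9, 98.0 (mean 97.933)
C: 99.0, 99.1, 99.1 (mean 99.067)
P: 99.9, 99.9, 99.9 (mean 99.9)
S: 92.4, 93.0, 93.6 (mean 93.0)
Average top-1 accuracy: 97.475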
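VLCS (C = Caltech101, L = LabelMe, P = PASCAL VOC2007, S = SUN09):
ResNet-50:
C: 100, 100, 100 (mean 100)
L: 70.9, 70.8, 69.5 (mean 70.4)
P: 88.4, 87.4, 87.2 (mean 87.667)
S: 75.8, 74.8, 74.8 (mean 75.133)
Average top-1 accuracy: 83.3
ViT-B/16:
C: 100, 100, 100 (mean 100)
L: 71.3, 74.7, 75.0 (mean 73.667)
P: 89.8, 90.4, 90.1 (mean 90.1)
S: 77.1, 75.3, 76.0 (mean 76.133)
Average top-1 accuracy: 84.9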
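OfficeHome (A = Art, C = Clipart, P = Product, R = Real World):
ResNet-50:
A: 73.3, 73.3, 73.5 (mean 73.367)
C: 55.3, 55.2, 55.0 (mean 55.167)
P: 84.2, 84.3, 83.9 (mean 84.133)
R: 84.1, 84.1, 83.9 (mean 84.033)
Average top-1 accuracy: 74.1
ViT-B/16:
A: 81.5, 81.7, 81.7 (mean 81.633)
C: 70.1, 70.0, 70.6 (mean 70.233)
P: 89.7, 90.0, 89.8 (mean 89.833)
R: 90.0, 90.3, 90.2 (mean 90.167)
Average top-1 accuracy: 82.9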
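The scores above are a two-level average. Each target domain is first averaged over its three runs, and the dataset score is then the mean of the four per-domain means. Below is a small Python sketch of that arithmetic, using the VLCS / ResNet-50 runs as input. The function name summarize is my own and does not come from the PromptTA repo.

```python
from statistics import mean

# Per-run top-1 accuracies (%) for VLCS / ResNet-50, copied from the
# results above; keys are the target-domain letters used there.
VLCS_RN50 = {
    "C": [100.0, 100.0, 100.0],
    "L": [70.9, 70.8, 69.5],
    "P": [88.4, 87.4, 87.2],
    "S": [75.8, 74.8, 74.8],
}


def summarize(runs_per_domain: dict[str, list[float]]) -> tuple[dict[str, float], float]:
    """Mean over runs for each target domain, then mean over the domains."""
    domain_means = {d: mean(runs) for d, runs in runs_per_domain.items()}
    overall = mean(list(domain_means.values()))
    return domain_means, overall


if __name__ == "__main__":
    domain_means, overall = summarize(VLCS_RN50)
    for domain, m in domain_means.items():
        print(f"{domain}: {m:.3f}")
    print(f"average top-1 accuracy: {overall:.1f}")
```

Running the sketch prints the per-domain means 100.000, 70.400, 87.667 and 75.133, followed by an average of 83.3. This matches the VLCS / ResNet-50 entry.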
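Errors:
Setting up the environment according to the GitHub instructions caused no problems. Training, however, failed with:
AttributeError: module 'torch.utils.data' has no attribute 'collate'
After a lot of trial and error, I found that the cause was a faulty import in the original author's code: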
Traceback (most recent call last):
  File "train.py", line 10, in <module>
    from trainers import *
  File "/opt/data/private/promptta/trainers/__init__.py", line 4, in <module>
    from .prompt_ta import PROMPT_TA
  File "/opt/data/private/promptta/trainers/prompt_ta.py", line 12, in <module>
    from torch.utils.data import *
AttributeError: module 'torch.utils.data' has no attribute 'collate'
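The fix: in trainers/prompt_ta.py, comment out the wildcard import from torch.utils.data import * and replace it with explicit imports:
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import TensorDataset
After that, training runs normally.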
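The traceback shows why the explicit imports help. `from M import *` looks up every name listed in `M.__all__`, so a single name that is listed but never defined aborts the whole import with exactly this kind of AttributeError. The minimal sketch below reproduces that failure mode. The module fake_data is a hypothetical stand-in, not the real torch.utils.data. I have not checked which PyTorch release ships the stray 'collate' entry; the traceback only suggests it is there.

```python
import sys
import types

# Hypothetical stand-in module (NOT the real torch.utils.data) whose
# __all__ advertises a name the module never defines.
fake = types.ModuleType("fake_data")
fake.DataLoader = type("DataLoader", (), {})
fake.Dataset = type("Dataset", (), {})
fake.__all__ = ["DataLoader", "Dataset", "collate"]  # 'collate' does not exist
sys.modules["fake_data"] = fake  # make it importable by name


def try_star_import() -> str:
    """Run a wildcard import and return the error message, or 'ok'."""
    namespace: dict = {}
    try:
        exec("from fake_data import *", namespace)
    except AttributeError as exc:
        return str(exc)
    return "ok"


def try_explicit_import() -> list[str]:
    """Import only the names that are actually needed."""
    namespace: dict = {}
    exec("from fake_data import DataLoader, Dataset", namespace)
    # exec() injects __builtins__; drop dunder keys to keep only our names.
    return sorted(k for k in namespace if not k.startswith("__"))


if __name__ == "__main__":
    print(try_star_import())      # an AttributeError message naming 'collate'
    print(try_explicit_import())  # the two explicitly imported names
```

Explicit imports only touch the names they list, so a bad `__all__` entry no longer matters. This is why listing DataLoader, Dataset and TensorDataset is enough.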
Configuration
1. Run commands
bash scripts/prompt_ta/main_ta_all.sh pacs b128_ep50_pacs ViT-B/16 0 ./resume
bash scripts/prompt_ta/main_ta_all.sh pacs b128_ep50_pacs RN50 0 ./resume
bash scripts/prompt_ta/main_ta_all.sh vlcs b128_ep50_vlcs ViT-B/16 1 ./resume
bash scripts/prompt_ta/main_ta_all.sh vlcs b128_ep50_vlcs RN50 1 ./resume
bash scripts/prompt_ta/main_ta_all.sh office_home b128_ep50_officehome ViT-B/16 1 ./resume
bash scripts/prompt_ta/main_ta_all.sh office_home b128_ep50_officehome RN50 1 ./resume
bash scripts/prompt_ta/main_ta_all.sh domainnet b128_ep50_domainnet ViT-B/16 1 ./resume
bash scripts/prompt_ta/main_ta_all.sh domainnet b128_ep50_domainnet RN50 1 ./resume
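The eight commands differ only in dataset, config and backbone, so they can also be generated from one list. This avoids copy-paste mismatches between dataset and config names. The Python sketch below is a hypothetical helper and is not part of the PromptTA repo. It passes the same five positional arguments as the commands above.

I have not checked what main_ta_all.sh does with the fourth and fifth arguments (0 or 1, and ./resume). The fourth looks like a GPU index, so both are forwarded unchanged. By default the helper only prints the commands (a dry run).

```python
import shlex
import subprocess

SCRIPT = "scripts/prompt_ta/main_ta_all.sh"

# (dataset name, config name) pairs, taken from the commands above.
DATASETS = [
    ("pacs", "b128_ep50_pacs"),
    ("vlcs", "b128_ep50_vlcs"),
    ("office_home", "b128_ep50_officehome"),
    ("domainnet", "b128_ep50_domainnet"),
]
BACKBONES = ["ViT-B/16", "RN50"]


def build_commands(device: str = "0", output_dir: str = "./resume") -> list[list[str]]:
    """Return one argv list per (dataset, backbone) pair.

    device and output_dir are forwarded verbatim as the 4th and 5th
    positional arguments; their meaning inside the script is assumed.
    """
    return [
        ["bash", SCRIPT, dataset, config, backbone, device, output_dir]
        for dataset, config in DATASETS
        for backbone in BACKBONES
    ]


def run_all(dry_run: bool = True) -> None:
    """Print every command; execute them sequentially only if dry_run is False."""
    for argv in build_commands():
        print(shlex.join(argv))
        if not dry_run:
            subprocess.run(argv, check=True)  # stop at the first failing run


if __name__ == "__main__":
    run_all(dry_run=True)
```

Running the commands sequentially with check=True means a crash, such as the import error above, stops the batch instead of silently skipping a dataset.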
2. Environment versions
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
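To confirm that an environment matches these pins, here is a minimal sketch that reads installed package metadata with importlib.metadata from the standard library. It assumes the conda packages expose their usual Python distribution names: torch (the conda package pytorch), torchvision and torchaudio. cudatoolkit has no Python metadata, so it is not checked.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins from the conda command above; keys are Python distribution names.
EXPECTED = {"torch": "1.12.0", "torchvision": "0.13.0", "torchaudio": "0.12.0"}


def check_versions(expected: dict[str, str] = EXPECTED) -> dict[str, str]:
    """Compare installed distribution versions against the pinned ones."""
    report = {}
    for pkg, want in expected.items():
        try:
            have = version(pkg)
        except PackageNotFoundError:
            report[pkg] = "not installed"
            continue
        # Some builds carry a local suffix such as "+cu113"; ignore it.
        base = have.split("+")[0]
        report[pkg] = "ok" if base == want else f"found {have}, expected {want}"
    return report


if __name__ == "__main__":
    for pkg, status in check_versions().items():
        print(f"{pkg}: {status}")
```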