- Ubuntu 22.04
- anaconda 가상환경
- 전체 에러 메시지
RuntimeError: Error(s) in loading state_dict for ~: size mismatch for ar_local_0.classifier.0.weight_g: copying a param with shape torch.Size([3, 1]) from checkpoint, the shape in current model is torch.Size([2, 1]). size mismatch for ar_local_0.classifier.0.weight_v: copying a param with shape torch.Size([3, 128]) from checkpoint, the shape in current model is torch.Size([2, 128]). size mismatch for ar_local_1.classifier.0.weight_g: copying a param with shape torch.Size([2, 1]) from checkpoint, the shape in current model is torch.Size([3, 1]). size mismatch for ar_local_1.classifier.0.weight_v: copying a param with shape torch.Size([2, 128]) from checkpoint, the shape in current model is torch.Size([3, 128]). size mismatch for ar_local_4.classifier.0.weight_g: copying a param with shape torch.Size([12, 1]) from checkpoint, the shape in current model is torch.Size([10, 1]). size mismatch for ar_local_4.classifier.0.weight_v: copying a param with shape torch.Size([12, 128]) from checkpoint, the shape in current model is torch.Size([10, 128]). size mismatch for ar_local_6.classifier.0.weight_g: copying a param with shape torch.Size([12, 1]) from checkpoint, the shape in current model is torch.Size([8, 1]). size mismatch for ar_local_6.classifier.0.weight_v: copying a param with shape torch.Size([12, 128]) from checkpoint, the shape in current model is torch.Size([8, 128]).
A 라는 데이터셋으로 학습 시키고, B 라는 데이터셋으로 평가를 하려고 하는데,
class의 갯수가 달라서 mismatch가 일어나는 것 처럼 보인다.
없는 layer나 있는 layer을 무시할 때는
모델을 load할 때, load_state_dict(model, strict=False)를 해주면 되지만,
지금은 같은 레이어에서 오류가 나는 경우이므로 이 방법으로 해결되지 않는다.
- 해결방법
- 현재 python interpreter의 site-packages 경로를 찾는다.
- 나의 경우는 anaconda 가상환경을 쓰고 있어서 경로가
home/{username}/anaconda3/env/{environment name}/lib/python3.9/site-packages이었다.
- 그 디렉토리의
torch/nn/modules/module.py를 찾는다.
- 나의 경우
home/{username}/anaconda3/env/{environment name}/lib/python3.9/site-packages/torch/nn/modules/module.py가 된다.
class Module의load_state_dict함수를 찾고 그 밑에 아래 함수를 복붙하여 넣어준다.def on_load_checkpoint(self, checkpoint: dict) -> None: state_dict = checkpoint model_state_dict = self.state_dict() is_changed = False for k in state_dict: if k in model_state_dict: if state_dict[k].shape != model_state_dict[k].shape: logging.info(f"Skip loading parameter: {k}, " f"required shape: {model_state_dict[k].shape}, " f"loaded shape: {state_dict[k].shape}") state_dict[k] = model_state_dict[k] is_changed = True else: logging.info(f"Dropping parameter {k}") is_changed = True if is_changed: checkpoint.pop("optimizer_states", None)저장하고, 기존 코드에서 load_state_dict 로 불러왔던 모델을 on_load_checkpoint 함수로 불러와준다.
'😠 Troubleshooting' 카테고리의 다른 글
| mac에서 삼성 ssd T7 mount 안됨 (ExFAT 형식) (0) | 2023.02.21 |
|---|---|
| Linux 터미널에서 현재 가상환경과 깃 브랜치 출력하게 하기 (0) | 2023.01.27 |
| Ubuntu 22.04에서 TeamViewer 연결 시 '디스플레이 파라미터 초기화 중' (듀얼모니터 사용) (0) | 2023.01.05 |
| tensorboard로 학습 과정 모니터링 하기 (1) | 2022.10.06 |
| VSCode에서 서버 연결하기 (0) | 2022.10.06 |