😠 Troubleshooting

RuntimeError: Error(s) in loading state_dict for ~ : size mismatch for ~

MIMDOING 2023. 2. 7. 14:09
  • Ubuntu 22.04
  • anaconda virtual environment

  • Full error message
    RuntimeError: Error(s) in loading state_dict for ~:
          size mismatch for ar_local_0.classifier.0.weight_g: copying a param with shape torch.Size([3, 1]) from checkpoint, the shape in current model is torch.Size([2, 1]).
          size mismatch for ar_local_0.classifier.0.weight_v: copying a param with shape torch.Size([3, 128]) from checkpoint, the shape in current model is torch.Size([2, 128]).
          size mismatch for ar_local_1.classifier.0.weight_g: copying a param with shape torch.Size([2, 1]) from checkpoint, the shape in current model is torch.Size([3, 1]).
          size mismatch for ar_local_1.classifier.0.weight_v: copying a param with shape torch.Size([2, 128]) from checkpoint, the shape in current model is torch.Size([3, 128]).
          size mismatch for ar_local_4.classifier.0.weight_g: copying a param with shape torch.Size([12, 1]) from checkpoint, the shape in current model is torch.Size([10, 1]).
          size mismatch for ar_local_4.classifier.0.weight_v: copying a param with shape torch.Size([12, 128]) from checkpoint, the shape in current model is torch.Size([10, 128]).
          size mismatch for ar_local_6.classifier.0.weight_g: copying a param with shape torch.Size([12, 1]) from checkpoint, the shape in current model is torch.Size([8, 1]).
          size mismatch for ar_local_6.classifier.0.weight_v: copying a param with shape torch.Size([12, 128]) from checkpoint, the shape in current model is torch.Size([8, 128]).

I trained on dataset A and am trying to evaluate on dataset B,
and the mismatch appears to happen because the two datasets have different numbers of classes.
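
The shape of the final classifier layer depends directly on the number of classes, which is why the same architecture trained on two datasets ends up with incompatible parameters. A minimal illustration (the 128-dim feature size and the class counts here are made-up examples, not the actual model):

    import torch.nn as nn

    # hypothetical classifier heads with the same feature size but different class counts
    head_a = nn.Linear(128, 3)   # dataset A: 3 classes
    head_b = nn.Linear(128, 2)   # dataset B: 2 classes

    print(head_a.weight.shape)   # torch.Size([3, 128])
    print(head_b.weight.shape)   # torch.Size([2, 128]) -> cannot hold A's checkpoint weights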

When you only need to ignore layers that are missing from the checkpoint or that exist only in the checkpoint,
calling model.load_state_dict(state_dict, strict=False) when loading is enough.
Here, however, the error comes from layers that exist in both the checkpoint and the model but with different shapes, so that approach does not solve it.
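
strict=False only relaxes the checks for missing and unexpected keys; a size mismatch on a key that exists in both the checkpoint and the model still raises. A minimal sketch showing this with toy layers:

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(128, 2))              # current model: 2 classes
    ckpt = nn.Sequential(nn.Linear(128, 3)).state_dict()  # checkpoint saved with 3 classes

    model.load_state_dict(ckpt, strict=False)
    # still raises:
    # RuntimeError: Error(s) in loading state_dict for Sequential:
    #       size mismatch for 0.weight: copying a param with shape torch.Size([3, 128]) ...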

  • Solution
  1. Find the site-packages path of the current Python interpreter (see the snippet after these steps for a quick way to look it up).
  • In my case, since I use an anaconda virtual environment, the path was /home/{username}/anaconda3/envs/{environment name}/lib/python3.9/site-packages.
  2. Find torch/nn/modules/module.py in that directory.
  • In my case this is /home/{username}/anaconda3/envs/{environment name}/lib/python3.9/site-packages/torch/nn/modules/module.py.
  3. Find the load_state_dict function in class Module and paste the following function right below it.

     def on_load_checkpoint(self, checkpoint: dict) -> None:
         # module.py does not import logging at the top, so import it here
         import logging

         state_dict = checkpoint
         model_state_dict = self.state_dict()
         is_changed = False
         for k in state_dict:
             if k in model_state_dict:
                 if state_dict[k].shape != model_state_dict[k].shape:
                     # Shape mismatch: keep the current model's (freshly
                     # initialized) parameter instead of the checkpoint's.
                     logging.info(f"Skip loading parameter: {k}, "
                                  f"required shape: {model_state_dict[k].shape}, "
                                  f"loaded shape: {state_dict[k].shape}")
                     state_dict[k] = model_state_dict[k]
                     is_changed = True
             else:
                 # Parameter exists in the checkpoint but not in the model.
                 logging.info(f"Dropping parameter {k}")
                 is_changed = True

         if is_changed:
             # The saved optimizer state no longer matches the modified
             # parameters, so drop it if present.
             checkpoint.pop("optimizer_states", None)
  4. Save the file, and in your existing code load the model through the on_load_checkpoint function instead of the plain load_state_dict call (a usage sketch follows).
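
As a quick way to locate module.py for steps 1 and 2 (just a convenience; the paths above already show the layout for an anaconda environment), you can ask torch where the module lives:

    import torch.nn.modules.module as m
    print(m.__file__)
    # e.g. /home/{username}/anaconda3/envs/{environment name}/lib/python3.9/site-packages/torch/nn/modules/module.py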
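
And a minimal sketch of how the patched hook might be called for step 4. MyModel, num_classes, and checkpoint.pth are placeholders, and the final load_state_dict call on the sanitized dict is my assumption about how the hook is meant to be applied, since the hook itself only rewrites the mismatched entries:

    import logging
    import torch

    logging.basicConfig(level=logging.INFO)        # make the "Skip loading parameter" logs visible

    model = MyModel(num_classes=2)                 # hypothetical model built for dataset B
    checkpoint = torch.load("checkpoint.pth", map_location="cpu")  # checkpoint trained on dataset A

    # Overwrite every mismatched entry in `checkpoint` with the current
    # model's own parameter, then load the sanitized dict.
    model.on_load_checkpoint(checkpoint)
    model.load_state_dict(checkpoint, strict=False)

Note that the replaced parameters keep the current model's random initialization, so the mismatched classifier heads are effectively untrained after loading.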