PytorchでDataParallelしたモデルの保存と読み込み

発端

PytorchでDataParallelしたモデルにシングルGPUで学習したパラメタを読み込もうとしたらエラーが出た。

原因

PytorchではDataParallelでモデルを包むと元のモデルがself.moduleに格納される。
そのためDataParallel後のモデルから見るとパラメタ名全てに"module."がついて元のモデルと対応が取れなくなる。
参考リンク

解決方法

モデルの保存時も読み込み時もDataParallelしてたらmoduleを参照すればいい。

モデル保存

if multi_gpu:
    model=model.module
torch.save(model.state_dict(),save_path)

モデル読み込み

if multi_gpu:
    model=model.module
model.load_state_dict(torch.load(load_path))