PytorchでDataParallelしたモデルの保存と読み込み
発端
PytorchでDataParallelしたモデルにシングルGPUで学習したパラメタを読み込もうとしたらエラーが出た。
原因
PytorchではDataParallelでモデルを包むと元のモデルがself.moduleに格納される。
そのためDataParallel後のモデルから見るとパラメタ名全てに"module."がついて元のモデルと対応が取れなくなる。
参考リンク
解決方法
モデルの保存時も読み込み時もDataParallelしてたらmoduleを参照すればいい。
モデル保存
if multi_gpu:
model=model.module
torch.save(model.state_dict(),save_path)
モデル読み込み
if multi_gpu:
model=model.module
model.load_state_dict(torch.load(load_path))