본문 바로가기
파이썬 & 머신러닝

[Pytorch] 모델 '적절하게' 저장하기 (+ 여러 모델)

by 두재 2020. 4. 15.

파이토치에는 모델을 저장하는 방법이 여러 가지 있습니다.

단순히

1. torch.save(path, model)

이렇게 하면 모델이 통채로 저장이 되기는 합니다.

2. torch.save(path, model.state_dict())

이렇게 하면 모델 안에 있는 파라미터들을 저장하는 것입니다.

 

첫번째 방식의 경우 

1. model = torch.load(path)

이렇게 하면 load가 가능하고 두 번째 방식의 경우에는 모델을 선언도 해줘야 하고 모델의 구조가 다르면 안됩니다.

model = MODEL().cuda()

model.load_state_dict(torch.load(path))

이렇게 모델을 선언하고, 그 모델에 파라미터들을 대입하는 과정이 필요합니다.

 

아무리 생각해도 첫 번째 방식이 편한 것 같지만 저도 이유는 정확히 모르겠지만 pytorch 사이트에서는 state_dict를 저장하는 것을 추천합니다 (웬만해서는 이렇게 하라고 합니다).

 

 


'적절하게' 저장을 하기 위해서,

혹시 저장한 모델을 나중에 fine-tuning을 하거나 아예 추가로 학습을 시키고 싶다고 하시면 모델만 저장하는 것으로는 부족합니다.

바로 optimizer에 있는 변수들이 없으면 모델을 저장하던 당시의 optimizer와 다시 학습을 재개하는 optimizer가 다르기 때문에 학습이 잘못 됩니다. 학습이 되긴 합니다만 잘못 되겠죠.

때문에 이러한 함수를 사용합니다.

def save_checkpoint(epoch, model, optimizer, filename):
    state = {
        'Epoch': epoch,
        'State_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    torch.save(state, filename)

딕셔너리 형식으로 저장할 당시의 에포크와 모델의 파라미터, 옵티마이저의 파라미터를 하나의 데이터로 만들어 이를 저장하는 것입니다. 모두다 Serialize 가능한 데이터기 때문에 문제 없이 저장됩니다.

 

추가로 모델이 여러 개일 경우,

저는 여러 모델을 저장하기 위해서 함수를 살짝 바꿨는데, 잘 됩니다.

def save_checkpoint(epoch, model, optimizer, filename):
	ms, os = [], []
    for m in model:
    	ms.append(m.state_dict())
    for o in optimizer:
    	os.append(o.state_dict())
    state = {
    	'Epoch' = epcoh,
        'State_dict' : ms,
        'optimizer' : os
    }
    torch.save(state, filename)

 

아마 아실 수 있겠지만 불러오는 과정은 간편합니다.

model = MODEL().cuda() --> 모델 선언하기

optimizer = optim.~~~  --> 옵티마이저 선언하기

checkpoint = torch.load(path)

model.load_state_dict(checkpoint['State_dict'])

optimizer.load_state_dict(checkpoint['optimizer'])