저장하기나 불러오기를 통해 모델의 상태를 유지(persist)하고 모델의 예측을 실행하는 방법을 알아보자.
모델 가중치 저장하고 불러오기
PyTorch 모델은 학습한 매개변수를 state_dict라고 불리는 내부 상태 사전(internal state dictionary)에 저장한다. 이 상태 값들은 torch.save 메소드를 사용하여 저장(persist)할 수 있다.
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')
모델 가중치를 불러오기 위해서는, 먼저 동일한 모델의 인스턴스(instance)를 생성한 다음에 load_state_dict() 메소드를 사용하여 매개변수들을 불러온다.
model = models.vgg16() # 기본 가중치를 불러오지 않으므로 pretrained=True를 지정하지 않습니다.
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
추론(inference)을 하기 전에 model.eval() 메소드를 호출하여 드롭아웃(dropout)과 배치 정규화(batch normalization)를 평가 모드(evaluation mode)로 설정해야 한다. 그렇지 않으면 일관성 없는 추론 결과가 생성된다.
모델의 형태를 포함하여 저장하고 불러오기
모델의 가중치를 불러올 때, 신경망의 구조를 정의하기 위해 모델 클래스를 먼저 생성(instantiate)해야 했다. 이 클래스의 구조를 모델과 함께 저장하고 싶으면, (model.state_dict()가 아닌) model 을 저장 함수에 전달한다.
torch.save(model, 'model.pth')
다음과 같이 모델을 불러올 수 있습니다.
model = torch.load('model.pth')
이 접근 방식은 Python pickle 모듈을 사용하여 모델을 직렬화(serialize)하므로, 모델을 불러올 때 실제 클래스 정의(definition)를 적용(rely on)한다.
참고: https://tutorials.pytorch.kr/beginner/basics/saveloadrun_tutorial.html
'Python > PyTorch' 카테고리의 다른 글
[Pytroch] Check point 저장하고 불러오기 (0) | 2021.09.28 |
---|---|
[Pytorch] 모델 매개변수 최적화하기 (0) | 2021.09.28 |
[Pytorch] COCO Data format과 Pycocotools (0) | 2021.09.27 |
[Pytorch] TORCH.AUTOGRAD를 사용한 자동 미분 (0) | 2021.09.26 |
[Pytorch] 신경망 모델 구성하기 (0) | 2021.09.25 |