AI & DL/Pytorch Lightning

[Pytorch Lightning] 튜토리얼 2 - 모델을 저장하고 불러오자

Giliit 2024. 7. 21. 21:08
728x90

이번에는 훈련한 모델을 저장하는 방법과 불러오는 방법에 대해 알아보려고 한다. 1편에서 사용한 코드를 거의 그대로 사용하며, Trainer선언 부분만 바뀌는 정도이며, 전체 코드는 깃허브에 있으니 참고하기 바란다.

요약

  • 모델을 특정경로에 체크포인트 저장하기
  • 모델의 하이퍼파라미터 저장하기
  • 체크포인트로 모델을 불러오기

 

모델을 특정경로에 체크포인트 저장하기

모델을 저장하는 것은 각 epoch마다 자동으로 저장된다. Jupyter 기준으로 설명하면, Jupyter 파일의 같은 디렉터리 내에 lightning_logs폴더에 각 validation_step마다 저장되는 것을 알 수 있다. 하지만, 나는 여기서 특정 경로에 저장하는 방법에 대해 소개하려고 한다.

# 이전 코드
trainer = pl.Trainer(max_epochs=2)

위의 코드는 1편에서 사용된 코드이다. 여기서 저장할 특정경로를 지정하지 않았기 때문에 실행파일과 동일한 디렉토리에 모델이 저장된다. 

trainer = pl.Trainer(
    max_epochs=2,
    # 저장 경로 지정
    default_root_dir='ckpt/'
    )

위의 코드는 수정된 코드이며, 'ckpt/' 경로에 모델이 저장되는 것이다. 그림으로 살펴보자.

ckpt 폴더 내에 무언가 저장된 것을 볼 수 있는데 각각 요소에 대해 설명하겠다.

  • version_숫자 : n번째에 실행했을때 저장되는 것들이다. 
  • hparams.yaml : hyperparameter들이 저장되어 있는 yaml 파일이다. 이에 대해서는 밑에서 다루도록 하겠다.
  • metrics.csv : 우리가 측정하려는 metric들이 각 epoch마다 csv 형태로 저장되는 파일이다.
  • checkpoints : 모델이 훈련하고 각 epoch가 끝날 때마다 얼마나 훈련되어 있는지 저장되어 있는 파일이다. 이 파일을 통해 어떠한 오류로 실험이 중단되면 저 ~~~.ckpt 파일을 읽어와서 실험을 진행할 수 있다. (그래서 사용한다.)

version_1에서의 저장되어있는 하이퍼파라미터 없음

우리는 모델의 파라미터만 저장하고 하이퍼 파라미터는 저장하지 않았다. 그래서 하이퍼 파라미터를 저장하는 방법에 대해 다루려고 한다. 

 

모델의 하이퍼파라미터 저장하기

하이퍼 파라미터를 저장하는 방법은 간단하다. 코드 한줄을 추가하면 된다.

class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        # 코드 추가
        self.save_hyperparameters()
        
        
        self.encoder = encoder
        self.decoder = decoder

# 코드 중략

self.save_hyperparameters() 이 함수를 하나 추가하면 된다.  그리고 코드를 돌리면(훈련하면) 다음과 같이 나온다.

version_2에서 하이퍼파라미터가 저장됨

 

체크포인트로 모델을 불러오기

model = LitAutoEncoder(Encoder(), Decoder())
trainer = pl.Trainer()

# 수정 코드
trainer.fit(model,
            train_dataloaders=train_loader,
            val_dataloaders=valid_loader,
            # 체크포인트 저장한 경로
            ckpt_path='ckpt/lightning_logs/version_0/checkpoints/epoch=1-step=96000.ckpt')

여기서 Trainer에서는 건드릴 것이 없다. 그 이유는 ckpt파일에 정보가 저장되어 있다. 그래서 필요한 것은 다음과 같다.

  • model : 훈련시킬 모델의 틀
  • dataloader: 훈련시킬 데이터를 다시 가져와야 함
  • ckpt_path = 훈련되어있는 parameter를 불러올 경로

 

다음 것들을 알면, 이제부터 훈련이 날아가서 처음부터 다시 돌리는 경우가 안 생긴다. 파이토치 라이트닝을 사용한다면, 꼭 이것을 알아주었으면 한다. 다음은 OverFitting을 방지하는 Early stopping 방법에 대해 소개하려고 한다.