[Pytorch] model parameter 개수 확인하는 방법

모델의 파라미터 개수를 확인하는 방법은 크게 두가지가 있습니다

 

1. model.parameters() 함수를 이용한 방법

아래는 예시 코드입니다.

import torchvision.models import *

if __name__ == '__main__':
	model = vgg11()
    
    # 학습 가능한 파라미터 개수
    trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    # 전체 파라미터 개수
    total_parameters = sum(p.numel() for p in model.parameters())

 

 

2. torchsummary

torchsummary 라이브러리를 이용한 방법인데, 해당 방법은 가끔 작동이 안되는 모델이 있을 수 있습니다.

from torchvision.models import *
from torchsummary import summary

if __name__ == '__main__':
	model = vgg11()
    summary(model, (3, 224, 224), batch_size=8)