티스토리 뷰

개발/Python

pytorch GPU 설정 주피터노트북

시크시크시크 2023. 1. 17. 16:32

1. GPU 사용량 확인

watch -n -1 nvidia-smi

1초마다 GPU 사용량 확인

2. pytorch CUDA SETTING

import os

import torch

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

os.environ["CUDA_VISIBLE_DEVICES"] = "GPU 넘버"

ex) GPU 0,1,2,3 중에 1만 사용 

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

ex)GPU 0,1,2,3 다 사용

os.environ["CUDA_VISIBLE_DEVICES"] = "GPU 넘버"

3. cuda setting 확인 

device = torch.device("cpu") if torch.cuda.is_available() else torch.device('cpu')

print(device) 파이토치에서 CPU사용할지 GPU 사용할지 확인

print(torch.cuda.current_device()) 현재 할당된 GPU 넘버

print(torch.cuda.deivce_count()) 사용할 수 있는 GPU 갯수 

 

4. model train

model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  model = nn.DataParallel(model)

model.to(device)