摘要
为了尽可能加快训练,我们会使用两种GPU并行手段,DP和DDP,但是DP其实只会开一个进程去管理,计算资源分配不均,DDP上我们倾向于一张卡开一个进程,使得我们的计算资源能够最大化的利用。本次的文章会快速地形象地过一下DP和DDP并且告诉大家如何代码层面上实践。
DP(你只需要知道它垃圾就行,可以跳过)
从图中我们可以看出,在forward环节,gpu1会先把所有的数据拿到,然后分发给其他的gpu,当然它自己也拿一份,接着它把自己的模型也复制成4份,每个gpu也拿一份,每个gpu自己跑自己的forward,跑完后将output传给gpu1,gpu1处理所有的output对应的梯度,然后进行backward,将要反向传播的梯度分配给其他的gpu,然后其他的gpu又各自进行自己的反向计算,计算完后将最后的梯度交给gpu1进行更新。我们可以看到,在gpu1分配任务和更新的时候,其他的gpu其实都是闲置的,所以利用率没法上来,大家都得等gpu1。那么我们可不可以想一种新方法,让每个gpu自己拿到数据后,自己跑前后向,而且自己更新梯度呢?DDP这不就来了嘛!
DDP
秉着尽量少理论,多形象的原则,加速理解,看图。我们将我们的数据以一个一个的batch传入网络,我们有两台machine,两台machine上各有两台gpu。每台gpu上都有自己的model(都是同一个model的复制品)和optimizer。每次来一个batch的数据,我们都会让Distributed sampler去将数据分配好发给指定的gpu,然后gpu们自己跑自己的,跑完前向后,每个gpu通过DDP的后端通讯可以知道其他所有gpu跑的结果,同步了所有gpu的梯度,拿到所有的信息后就吭哧吭哧自己去反向传播更新梯度。DDP就这么简单。
DDP代码实践
代码实践以前需要交代的概念
# 1. 导包:一些需要导入的库
# 模型相关
from torch.nn.parallel import DistributedDataParallel as DDP
# 数据相关
from torch.utils.data.distributed import DistributedSampler
# ddp自身的机制相关
import torch.distributed as dist
# 2. Backend communication for multiple GPUs and process rank setup.
if DDP_ON:
    # NCCL is the recommended backend for CUDA tensors.
    init_process_group(backend="nccl")
    # LOCAL_RANK: index of this process's GPU on the current node (set by torchrun).
    LOCAL_RANK = device_id = int(os.environ["LOCAL_RANK"])
    # Fix: use the launcher-provided WORLD_SIZE (total processes across ALL nodes).
    # torch.cuda.device_count() only counts GPUs on this node, which is wrong
    # in the multi-machine setup this article describes.
    WORLD_SIZE = int(os.environ["WORLD_SIZE"])
    device = torch.device('cuda', device_id)  # note that device_id is an integer but device is a torch.device object.
    print(f"Start running basic DDP on rank {LOCAL_RANK}.")
    logging.info(f'Using device {device_id}')
# 3. DDP model: wrap the raw model so gradients are synchronized across processes.
net = DDP(net, device_ids = [device_id], output_device=device_id)

# 4. Feed data to the multiple GPUs.
loader_args = dict(batch_size=batch_size, num_workers=WORLD_SIZE*4, pin_memory=True) # batchsize is for a single proc
if DDP_ON:
    # DistributedSampler hands each process a disjoint shard of the dataset,
    # so the ranks never see duplicated training samples in one epoch.
    train_sampler = DistributedSampler(train_set)
    train_loader = DataLoader(train_set, sampler=train_sampler, **loader_args)
else:
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)

# no need for distributed sampler for val
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)
# 5. set_epoch: re-seed the DistributedSampler every epoch so each epoch is
#    shuffled differently; without it every epoch yields identical batches (see figure below).
# ref: https://blog.csdn.net/weixin_41978699/article/details/121742647
for epoch in range(start, start+epochs):
    if LOCAL_RANK == 0:
        # Only rank 0 prints, to avoid duplicated console output per process.
        print('lr: ', optimizer.param_groups[0]['lr'])
    net.train()
    epoch_loss = 0
    # To avoid duplicated data sent to multi-gpu
    train_loader.sampler.set_epoch(epoch)
如何启动
torchrun --nproc_per_node=4 \
    multigpu_torchrun.py \
    --batch_size 4 \
    --lr 1e-3
torchrun 是torch 1.11后的版本才可以用,以前的版本需要用
python -m torch.distributed.launch \
    --nproc_per_node=4 \
    train.py \
    --batch_size 4
参考资料
pytorch官方
- https://pytorch.org/tutorials/beginner/ddp_series_multigpu.html#multi-gpu-training-with-ddp
- https://pytorch.org/tutorials/beginner/ddp_series_multigpu.html
完整代码布局可参考
import argparse
import logging
import sys
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, random\_split
from tqdm import tqdm
from utils.data\_loading import BasicDataset, CarvanaDataset
from utils.dice\_score import dice\_loss
from evaluate import evaluate
from unet import UNet
import os
import torch.distributed as dist
# for reproducibility
import random
import numpy as np
import torch.backends.cudnn as cudnn
# ABOUT DDP
# for model loading in ddp mode
from torch.nn.parallel import DistributedDataParallel as DDP
# for data loading in ddp mode
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp
from torch.distributed import init\_process\_group, destroy\_process\_group
def init_seeds(seed=0, cuda_deterministic=True):
    """Seed python, numpy and torch RNGs for reproducible runs.

    Args:
        seed: value fed to every random generator.
        cuda_deterministic: True forces deterministic cuDNN kernels
            (slower, reproducible); False enables cudnn.benchmark
            (faster, less reproducible).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Speed-reproducibility tradeoff: https://pytorch.org/docs/stable/notes/randomness.html
    cudnn.deterministic = cuda_deterministic
    cudnn.benchmark = not cuda_deterministic
def train_net(net,
              device,
              start: int = 0,
              epochs: int = 5,
              batch_size: int = 1,
              learning_rate: float = 1e-5,
              val_percent: float = 0.1,
              save_checkpoint: bool = True,
              img_scale: float = 0.5,
              amp: bool = False):
    """Train `net`, optionally under DDP.

    Reads the module-level globals DDP_ON, LOCAL_RANK, WORLD_SIZE, args,
    dir_img, dir_mask and dir_checkpoint.

    Args:
        net: model to train (a DDP-wrapped module when DDP_ON is true).
        device: torch.device this process trains on.
        start: epoch index to resume counting from.
        epochs: number of epochs to run.
        batch_size: per-process (per-GPU) batch size.
        learning_rate: initial AdamW learning rate.
        val_percent: fraction (0-1) of the dataset held out for validation.
        save_checkpoint: save rank-0 checkpoints every `args.save_every` epochs.
        img_scale: image downscaling factor passed to the dataset.
        amp: enable mixed-precision training.
    """
    if DDP_ON:  # modify the net's attributes when using ddp
        # DDP wraps the model, so custom attributes live on net.module;
        # re-expose them on the wrapper for convenient access below.
        net.n_channels = net.module.n_channels
        net.n_classes = net.module.n_classes

    # 1. Create dataset
    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError):
        # Fall back to the generic dataset when the Carvana layout is absent.
        dataset = BasicDataset(dir_img, dir_mask, img_scale)

    # 2. Split into train / validation partitions
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    # Fixed generator seed so every process computes the identical split.
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=WORLD_SIZE*4, pin_memory=True) # batchsize is for a single process(GPU)
    if DDP_ON:
        train_sampler = DistributedSampler(train_set)
        train_loader = DataLoader(train_set, sampler=train_sampler, **loader_args)
    else:
        train_loader = DataLoader(train_set, shuffle=True, **loader_args)

    # no need for distributed sampler for val
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging) — only rank 0 talks to wandb.
    if LOCAL_RANK == 0:
        experiment = wandb.init(project='U-Net-DDP', resume='allow', anonymous='must')
        experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
                                      val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale,
                                      amp=amp))

    logging.info(f'''Starting training:
        Epochs: {epochs}
        Start from: {start}
        Batch size: {batch_size}
        Learning rate: {learning_rate}
        Training size: {n_train}
        Validation size: {n_val}
        Checkpoints: {save_checkpoint}
        Device: {device.type}
        Images scaling: {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(net.parameters(), lr=learning_rate, weight_decay=1e-8)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-7)
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    global_step = 0

    # 5. Begin training
    for epoch in range(start, start+epochs):
        if LOCAL_RANK == 0:
            print('lr: ', optimizer.param_groups[0]['lr'])
        net.train()
        epoch_loss = 0
        # To avoid duplicated data sent to multi-gpu.
        # NOTE(review): train_loader.sampler is a DistributedSampler only when
        # DDP_ON; the shuffle=True path yields a RandomSampler with no
        # set_epoch — confirm this function is never called with DDP disabled.
        train_loader.sampler.set_epoch(epoch)
        # Show the progress bar on rank 0 only.
        disable = False if LOCAL_RANK == 0 else True
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs+start}', unit='img', disable=disable) as pbar:
            for batch in train_loader:
                images = batch['image']
                true_masks = batch['mask']

                assert images.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)
                    # Total loss = cross-entropy + dice loss on softmaxed logits.
                    loss = criterion(masks_pred, true_masks) \
                           + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                       F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                                       multiclass=True)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                if LOCAL_RANK == 0:
                    experiment.log({
                        'train loss': loss.item(),
                        'step': global_step,
                        'epoch': epoch
                    })
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round: roughly 5 validation passes per epoch.
                division_step = (n_train // (5 * batch_size))
                if division_step > 0:
                    if global_step % division_step == 0:
                        histograms = {}
                        for tag, value in net.named_parameters():
                            tag = tag.replace('/', '.')
                            # Skip inf-valued tensors, which wandb histograms reject.
                            if not torch.isinf(value).any():
                                histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                            if not torch.isinf(value.grad).any():
                                histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                        # Every rank evaluates (val loader is not sharded);
                        # only rank 0 logs the score below.
                        val_score = evaluate(net, val_loader, device, disable_log = disable)

                        if LOCAL_RANK == 0:
                            logging.info('Validation Dice score: {}'.format(val_score))
                            experiment.log({
                                'learning rate': optimizer.param_groups[0]['lr'],
                                'validation Dice': val_score,
                                'images': wandb.Image(images[0].cpu()),
                                'masks': {
                                    'true': wandb.Image(true_masks[0].float().cpu()),
                                    'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
                                },
                                'step': global_step,
                                'epoch': epoch,
                                **histograms
                            })

        scheduler.step()
        # Only rank 0 writes checkpoints; net.module unwraps the DDP container.
        if save_checkpoint and LOCAL_RANK == 0 and (epoch % args.save_every == 0):
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            torch.save(net.module.state_dict(), str(dir_checkpoint / 'DDP_checkpoint_epoch{}.pth'.format(epoch)))
            logging.info(f'Checkpoint {epoch} saved!')
##################################### arguments ###########################################
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
parser.add\_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
parser.add\_argument('--batch-size', '-b', dest='batch\_size', metavar='B', type=int, default=1, help='Batch size')
parser.add\_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
help='Learning rate', dest='lr')
parser.add\_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
parser.add\_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
parser.add\_argument('--validation', '-v', dest='val', type=float, default=10.0,
help='Percent of the data that is used as validation (0-100)')
parser.add\_argument('--amp', action='store\_true', default=False, help='Use mixed precision')
parser.add\_argument('--bilinear', action='store\_true', default=False, help='Use bilinear upsampling')
parser.add\_argument('--classes', '-c', type=int, default=2, help='Number of classes')
parser.add\_argument('--exp\_name', type=str, default='hgb\_exp')
parser.add\_argument('--ddp\_mode', action='store\_true')
parser.add\_argument('--save\_every', type=int, default=5)
parser.add\_argument('--start\_from', type=int, default=0)
args = parser.parse\_args()
dir\_img = Path('./data/imgs/')
dir\_mask = Path('./data/masks/')
dir\_checkpoint = Path('./checkpoints/')
DDP\_ON = True if args.ddp\_mode else False
#########################################################################################
if DDP\_ON:
init\_process\_group(backend="nccl")
LOCAL\_RANK = device\_id = int(os.environ["LOCAL\_RANK"])
WORLD\_SIZE = torch.cuda.device\_count()
device = torch.device('cuda', device\_id) # note that device\_id is an integer but device is a datetype.
print(f"Start running basic DDP on rank {LOCAL\_RANK}.")
logging.info(f'Using device {device\_id}')
if __name__ == '__main__':
    # Highly recommended references (pytorch.org DDP tutorials):
    # 1. https://pytorch.org/tutorials/beginner/ddp_series_multigpu.html#multi-gpu-training-with-ddp
    # 2. https://pytorch.org/tutorials/beginner/ddp_series_multigpu.html
    init_seeds(0)

    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)

    if LOCAL_RANK == 0:
        print(f'Network:\n'
              f'\t{net.n_channels} input channels\n'
              f'\t{net.n_classes} output channels (classes)\n'
              f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        # Strip a possible 'module.' prefix so checkpoints saved from a
        # DDP-wrapped model load into the bare model as well.
        # ref: https://blog.csdn.net/hustwayne/article/details/120324639 use method 2 with module
        # net.load_state_dict(torch.load(args.load, map_location=device))
        net.load_state_dict({k.replace('module.', ''): v for k, v in
                             torch.load(args.load, map_location=device).items()})
        logging.info(f'Model loaded from {args.load}')

    if DDP_ON:
        # wrap our model with ddp
        torch.cuda.set_device(LOCAL_RANK)
        net.to(device=device)
        net = DDP(net, device_ids = [device_id], output_device=device_id)
    else:
        # Fix: the original called torch.cuda.set_device and wrapped with DDP
        # unconditionally, crashing when --ddp_mode was off (or no GPU present).
        net.to(device=device)

    try:
        train_net(net=net,
                  start=args.start_from,
                  epochs=args.epochs,
                  batch_size=args.batch_size,
                  learning_rate=args.lr,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100,
                  amp=args.amp)
    except KeyboardInterrupt:
        # Unwrap the DDP container (if any) before saving the interrupt snapshot.
        raw_model = net.module if DDP_ON else net
        torch.save(raw_model.state_dict(), 'INTERRUPTED_DDP.pth')
        logging.info('Saved interrupt')
        raise

    # Tear down the process group only if one was created.
    if DDP_ON:
        destroy_process_group()
