Engineering Series | Learning DDP from a User's Perspective

Abstract

To speed training up as much as possible, we rely on two GPU parallelism schemes: DP and DDP. DP only ever spawns a single process to manage everything, so compute is distributed unevenly; with DDP we prefer one process per GPU, which lets us squeeze the most out of our hardware. This article walks quickly and visually through DP and DDP, then shows how to put them into practice in code.

DP (you only need to know it's bad; feel free to skip)

[Figure: DP data flow: gpu1 scatters data and model replicas, then gathers outputs and gradients]

As the figure shows, during the forward pass gpu1 first receives all the data and scatters it to the other GPUs, keeping one share for itself. It also replicates its model into four copies, one per GPU. Each GPU runs its own forward pass and sends its output back to gpu1, which computes the loss and the gradients for all the outputs, then starts the backward pass by scattering those gradients to the other GPUs. Each GPU runs its own backward computation and hands its final gradients back to gpu1, which performs the parameter update. Notice that while gpu1 is scattering work and updating parameters, every other GPU sits idle, so utilization can never get high: everyone waits for gpu1. Could we find a scheme where each GPU fetches its own data, runs its own forward and backward pass, and updates its own parameters? That is exactly what DDP does.
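For completeness, DP is just `torch.nn.DataParallel`. A minimal sketch of the single-process flow described above (on a CPU-only machine, `DataParallel` simply falls through to the wrapped module):

```python
import torch
import torch.nn as nn

# A tiny model; DataParallel wraps it and performs the scatter / replicate /
# parallel-forward / gather cycle described above on every call.
model = nn.Linear(8, 2)
dp_model = nn.DataParallel(model)  # uses all visible GPUs; plain forward on CPU

x = torch.randn(4, 8)
out = dp_model(x)    # outputs are gathered back onto the primary device
loss = out.sum()
loss.backward()      # gradients are reduced onto the original model's parameters
```

All of this happens inside one Python process, which is precisely why the other GPUs idle while the primary one does the bookkeeping.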

DDP

[Figure: DDP setup: two machines with two GPUs each, one model replica and optimizer per GPU]

In the spirit of less theory and more intuition, look at the figure. We feed the network one batch at a time. There are two machines, each with two GPUs. Every GPU holds its own copy of the model (identical replicas) and its own optimizer. For each incoming batch, a DistributedSampler splits the data and hands each GPU its own shard; each GPU then runs its own forward and backward pass. During the backward pass, DDP's communication backend all-reduces the gradients across all GPUs, so every GPU ends up holding the same averaged gradients; each GPU then applies the identical update locally, keeping the replicas in sync. That's all DDP is.
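The gradient synchronization can be sketched with a plain-Python mock (no real processes or NCCL here; `allreduce_mean` is a hypothetical stand-in for what DDP's backward hooks achieve with an all-reduce):

```python
# Conceptual mock of DDP's gradient sync: each "rank" computes a gradient on
# its own data shard, the all-reduce averages them, and every rank applies the
# identical averaged update, so all model replicas stay in sync.

def allreduce_mean(per_rank_grads):
    """Stand-in for DDP's all-reduce: every rank ends up with the mean gradient."""
    n_ranks = len(per_rank_grads)
    n_params = len(per_rank_grads[0])
    mean = [sum(g[i] for g in per_rank_grads) / n_ranks for i in range(n_params)]
    return [list(mean) for _ in range(n_ranks)]  # each rank holds the same result

# 4 GPUs (ranks), each with a local gradient for 2 parameters
local_grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
synced = allreduce_mean(local_grads)

lr, weights = 0.1, [0.0, 0.0]
# every rank runs the same optimizer step on the same averaged gradient
updated_per_rank = [[w - lr * g for w, g in zip(weights, grads)] for grads in synced]
```

Because every rank sees the same averaged gradient, no rank ever has to ship its parameters to a primary GPU, which is the key difference from DP.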

DDP in code

Concepts to cover before the code

[Figure: key DDP concepts (rank, local rank, world size)]


        
        
            

# 1. Imports
# model wrapping
from torch.nn.parallel import DistributedDataParallel as DDP
# data sharding
from torch.utils.data.distributed import DistributedSampler
# DDP's own machinery
import torch.distributed as dist
from torch.distributed import init_process_group, destroy_process_group

# 2. Backend communication and this process's GPU index (rank)
if DDP_ON:
    init_process_group(backend="nccl")
    LOCAL_RANK = device_id = int(os.environ["LOCAL_RANK"])
    WORLD_SIZE = int(os.environ["WORLD_SIZE"])  # set by torchrun; also correct for multi-node

    device = torch.device('cuda', device_id)  # note that device_id is an int but device is a torch.device
    print(f"Start running basic DDP on rank {LOCAL_RANK}.")
    logging.info(f'Using device {device_id}')

# 3. DDP model
net = DDP(net, device_ids=[device_id], output_device=device_id)

# 4. Feeding data to multiple GPUs
loader_args = dict(batch_size=batch_size, num_workers=WORLD_SIZE * 4, pin_memory=True)  # batch_size is per process (GPU)
if DDP_ON:
    train_sampler = DistributedSampler(train_set)
    train_loader = DataLoader(train_set, sampler=train_sampler, **loader_args)
else:
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)

# no need for a distributed sampler for val
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

# 5. set_epoch: keeps every epoch from replaying the same data order (see figure below)
# ref: https://blog.csdn.net/weixin_41978699/article/details/121742647
for epoch in range(start, start + epochs):
    if LOCAL_RANK == 0:
        print('lr: ', optimizer.param_groups[0]['lr'])

    net.train()
    epoch_loss = 0

    # To avoid duplicated data being sent to multiple GPUs
    train_loader.sampler.set_epoch(epoch)
            

        
      

[Figure: without set_epoch, every epoch sees the same shuffled batches]
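What `DistributedSampler` and `set_epoch` buy you can be inspected without any process group by passing `num_replicas` and `rank` explicitly (a small sketch; the "dataset" is just a list, since the sampler only needs its length):

```python
from torch.utils.data.distributed import DistributedSampler

data = list(range(8))  # toy dataset; the sampler only uses len(data)

# one sampler per simulated GPU; explicit num_replicas/rank means we do not
# need init_process_group just for this demonstration
samplers = [DistributedSampler(data, num_replicas=2, rank=r, shuffle=True)
            for r in (0, 1)]

for s in samplers:
    s.set_epoch(0)
shards_epoch0 = [list(s) for s in samplers]
# the two ranks receive disjoint index shards that together cover the dataset
assert sorted(shards_epoch0[0] + shards_epoch0[1]) == data

for s in samplers:
    s.set_epoch(1)
shards_epoch1 = [list(s) for s in samplers]
# a new epoch reseeds the shuffle, so the shards are drawn in a new order;
# without set_epoch, every epoch would replay the exact same order
```

Internally the sampler shuffles with a generator seeded by `seed + epoch`, which is why forgetting `set_epoch` freezes the data order across epochs.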

How to launch


        
        
            

torchrun --nproc_per_node=4 \
    multigpu_torchrun.py \
    --batch_size 4 \
    --lr 1e-3
            

        
      

torchrun is only available in PyTorch 1.11 and later; earlier versions need:
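Inside each worker, the launcher's coordinates arrive as environment variables (with the legacy launcher, `--local_rank` is instead passed as a command-line argument unless `--use_env` is set). A minimal sketch of reading them, with single-process fallbacks:

```python
import os

# torchrun exports these for every worker process it spawns
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # GPU index on this machine
rank = int(os.environ.get("RANK", 0))              # global index across all machines
world_size = int(os.environ.get("WORLD_SIZE", 1))  # total number of processes

print(f"rank {rank}/{world_size}, local rank {local_rank}")
```

With `--nproc_per_node=4` on a single machine, the four workers see `LOCAL_RANK` 0 through 3 and `WORLD_SIZE=4`.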


        
        
            

python -m torch.distributed.launch \
    --nproc_per_node=4 \
    train.py \
    --batch_size 4
            

        
      

References

PyTorch official

  1. https://pytorch.org/tutorials/beginner/ddp_series_multigpu.html#multi-gpu-training-with-ddp

  2. https://pytorch.org/tutorials/beginner/ddp_series_multigpu.html

For reference, here is the complete code layout:


        
        
            

import argparse
import logging
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss
from evaluate import evaluate
from unet import UNet
import os
import torch.distributed as dist

# for reproducibility
import random
import numpy as np
import torch.backends.cudnn as cudnn

# ABOUT DDP
# for model loading in ddp mode
from torch.nn.parallel import DistributedDataParallel as DDP
# for data loading in ddp mode
from torch.utils.data.distributed import DistributedSampler

import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group


def init_seeds(seed=0, cuda_deterministic=True):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
    if cuda_deterministic:  # slower, more reproducible
        cudnn.deterministic = True
        cudnn.benchmark = False
    else:  # faster, less reproducible
        cudnn.deterministic = False
        cudnn.benchmark = True


def train_net(net,
              device,
              start: int = 0,
              epochs: int = 5,
              batch_size: int = 1,
              learning_rate: float = 1e-5,
              val_percent: float = 0.1,
              save_checkpoint: bool = True,
              img_scale: float = 0.5,
              amp: bool = False):

    if DDP_ON:  # reach through the DDP wrapper for the net's attributes
        net.n_channels = net.module.n_channels
        net.n_classes = net.module.n_classes

    # 1. Create dataset
    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale)

    # 2. Split into train / validation partitions
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=WORLD_SIZE * 4, pin_memory=True)  # batch_size is per process (GPU)

    if DDP_ON:
        train_sampler = DistributedSampler(train_set)
        train_loader = DataLoader(train_set, sampler=train_sampler, **loader_args)
    else:
        train_loader = DataLoader(train_set, shuffle=True, **loader_args)

    # no need for a distributed sampler for val
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging)
    if LOCAL_RANK == 0:
        experiment = wandb.init(project='U-Net-DDP', resume='allow', anonymous='must')
        experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
                                      val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale,
                                      amp=amp))

        logging.info(f'''Starting training:
                Epochs:          {epochs}
                Start from:      {start}
                Batch size:      {batch_size}
                Learning rate:   {learning_rate}
                Training size:   {n_train}
                Validation size: {n_val}
                Checkpoints:     {save_checkpoint}
                Device:          {device.type}
                Images scaling:  {img_scale}
                Mixed Precision: {amp}
            ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    criterion = nn.CrossEntropyLoss()

    optimizer = optim.AdamW(net.parameters(), lr=learning_rate, weight_decay=1e-8)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-7)
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    global_step = 0

    # 5. Begin training
    for epoch in range(start, start + epochs):
        if LOCAL_RANK == 0:
            print('lr: ', optimizer.param_groups[0]['lr'])

        net.train()
        epoch_loss = 0

        if DDP_ON:
            # To avoid duplicated data being sent to multiple GPUs
            train_loader.sampler.set_epoch(epoch)

        disable = LOCAL_RANK != 0  # only rank 0 shows the progress bar

        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs + start}', unit='img', disable=disable) as pbar:
            for batch in train_loader:
                images = batch['image']
                true_masks = batch['mask']

                assert images.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)
                    loss = criterion(masks_pred, true_masks) \
                           + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                       F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                                       multiclass=True)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()

                if LOCAL_RANK == 0:
                    experiment.log({
                        'train loss': loss.item(),
                        'step': global_step,
                        'epoch': epoch
                    })
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round
                division_step = (n_train // (5 * batch_size))
                if division_step > 0:
                    if global_step % division_step == 0:
                        histograms = {}
                        for tag, value in net.named_parameters():
                            tag = tag.replace('/', '.')
                            if not torch.isinf(value).any():
                                histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                            if value.grad is not None and not torch.isinf(value.grad).any():
                                histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                        val_score = evaluate(net, val_loader, device, disable_log=disable)

                        if LOCAL_RANK == 0:
                            logging.info('Validation Dice score: {}'.format(val_score))
                            experiment.log({
                                'learning rate': optimizer.param_groups[0]['lr'],
                                'validation Dice': val_score,
                                'images': wandb.Image(images[0].cpu()),
                                'masks': {
                                    'true': wandb.Image(true_masks[0].float().cpu()),
                                    'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
                                },
                                'step': global_step,
                                'epoch': epoch,
                                **histograms
                            })
        scheduler.step()
        if save_checkpoint and LOCAL_RANK == 0 and (epoch % args.save_every == 0):
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            state_dict = net.module.state_dict() if DDP_ON else net.state_dict()
            torch.save(state_dict, str(dir_checkpoint / 'DDP_checkpoint_epoch{}.pth'.format(epoch)))

            logging.info(f'Checkpoint {epoch} saved!')


##################################### arguments ###########################################
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                    help='Learning rate', dest='lr')
parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                    help='Percent of the data that is used as validation (0-100)')
parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
parser.add_argument('--exp_name', type=str, default='hgb_exp')
parser.add_argument('--ddp_mode', action='store_true')
parser.add_argument('--save_every', type=int, default=5)
parser.add_argument('--start_from', type=int, default=0)


args = parser.parse_args()

dir_img = Path('./data/imgs/')
dir_mask = Path('./data/masks/')
dir_checkpoint = Path('./checkpoints/')

DDP_ON = True if args.ddp_mode else False

#########################################################################################

if DDP_ON:
    init_process_group(backend="nccl")
    LOCAL_RANK = device_id = int(os.environ["LOCAL_RANK"])
    WORLD_SIZE = int(os.environ["WORLD_SIZE"])  # set by torchrun; also correct for multi-node

    device = torch.device('cuda', device_id)  # note that device_id is an int but device is a torch.device
    print(f"Start running basic DDP on rank {LOCAL_RANK}.")
    logging.info(f'Using device {device_id}')
else:
    # single-process fallbacks so the rest of the script still works without DDP
    LOCAL_RANK = device_id = 0
    WORLD_SIZE = 1
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


if __name__ == '__main__':
    # highly recommended refs: the official pytorch ddp tutorials
    # 1. https://pytorch.org/tutorials/beginner/ddp_series_multigpu.html#multi-gpu-training-with-ddp
    # 2. https://pytorch.org/tutorials/beginner/ddp_series_multigpu.html

    init_seeds(0)
    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)

    if LOCAL_RANK == 0:
        print(f'Network:\n'
              f'\t{net.n_channels} input channels\n'
              f'\t{net.n_classes} output channels (classes)\n'
              f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        # ref: https://blog.csdn.net/hustwayne/article/details/120324639  use method 2 with module
        # net.load_state_dict(torch.load(args.load, map_location=device))
        net.load_state_dict({k.replace('module.', ''): v for k, v in
                             torch.load(args.load, map_location=device).items()})

        logging.info(f'Model loaded from {args.load}')

    if torch.cuda.is_available():
        torch.cuda.set_device(LOCAL_RANK)
    net.to(device=device)
    # wrap our model with ddp
    if DDP_ON:
        net = DDP(net, device_ids=[device_id], output_device=device_id)

    try:
        train_net(net=net,
                  start=args.start_from,
                  epochs=args.epochs,
                  batch_size=args.batch_size,
                  learning_rate=args.lr,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100,
                  amp=args.amp)
    except KeyboardInterrupt:
        state_dict = net.module.state_dict() if DDP_ON else net.state_dict()
        torch.save(state_dict, 'INTERRUPTED_DDP.pth')
        logging.info('Saved interrupt')
        raise
    if DDP_ON:
        destroy_process_group()

        
      