DeepSpeed Activation Checkpointing


The activation checkpointing API in DeepSpeed enables a set of memory optimizations around activation checkpointing, including partitioning activations across GPUs when model parallelism is used, CPU checkpointing, and contiguous memory optimization.

(Figure: the computation graph before and after enabling activation checkpointing)

As the figure shows:

1. The upper half is the logical subgraph in the normal case. T1 and T2 are the forward-computation parts of Transformer layers. The intermediate activations produced after each op in the subgraph finishes keep occupying memory; when computation reaches the backward pass (T1_grad, T2_grad), these intermediate activations are used for the backward computation.

2. The lower half is the logical subgraph with recomputation (activation checkpointing) enabled. A "fake" subgraph used for recomputation, enclosed in dashed lines, is inserted in the middle. Because of this fake subgraph, the normal forward subgraph no longer has to keep its intermediate activations during the forward pass; when the backward computation needs them, a forward recomputation is run on the fly through the fake subgraph.

https://zhuanlan.zhihu.com/p/373662730
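Besides calling `deepspeed.checkpointing.configure()` directly (as in the snippet below), these features can also be switched on through the `activation_checkpointing` section of the DeepSpeed config JSON. A minimal sketch with illustrative values (key names follow the DeepSpeed config docs):

# Illustrative DeepSpeed config fragment; values are examples, not recommendations.
ds_config = {
    "activation_checkpointing": {
        "partition_activations": True,             # split checkpointed activations across MP ranks
        "cpu_checkpointing": False,                # offload checkpointed activations to CPU
        "contiguous_memory_optimization": False,   # copy checkpoints into contiguous buffers
        "number_checkpoints": None,                # layer count used to size contiguous buffers
        "synchronize_checkpoint_boundary": False,  # device synchronize at checkpoint boundaries
        "profile": False                           # time forward/backward of checkpointed blocks
    }
}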

  
    # Optional DeepSpeed Activation Checkpointing Features  
    #  
    if args.deepspeed and args.deepspeed_activation_checkpointing:  
        set_deepspeed_activation_checkpointing(args)  
  
def set_deepspeed_activation_checkpointing(args):  
  
    deepspeed.checkpointing.configure(mpu,  
                            deepspeed_config=args.deepspeed_config,  
                            partition_activation=True)  
  
    mpu.checkpoint = deepspeed.checkpointing.checkpoint  
    mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker  
    mpu.model_parallel_cuda_manual_seed = \
                    deepspeed.checkpointing.model_parallel_cuda_manual_seed  

  
if deepspeed.checkpointing.is_configured():  
    global get_cuda_rng_tracker, checkpoint  
    get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker  
    checkpoint = deepspeed.checkpointing.checkpoint  

Global Vars

  
# DeepSpeed Checkpointing Enabled or Disabled  
deepspeed_checkpointing_enabled = False  
  
# MP parameters  
mpu = None  
mp_rank = None  
mp_size = None  
mp_group = None  
  
# Model Parameters  
num_layers = None  
  
# Checkpointing buffers  
contiguous_data_buffers = []  
data_offsets = []  
  
contiguous_size_buffers = []  
size_offsets = []  
  
timers = None  
  
# optimization flags  
PARTITION_ACTIVATIONS = False  
CPU_CHECKPOINT = False  
CONTIGUOUS_CHECKPOINTING = False  
SYNCHRONIZE = False  
PROFILE_TIME = False  
  
# Default name for the model parallel rng tracker.  
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'  
transport_stream = None  
cuda_device = None  

extract_tensors & merge_tensors

extract_tensors separates tensors from non-tensors; merge_tensors restores them to the original list or tuple.

  
def extract_tensors(all_objects):  
    """  
    Separate objects in list/tuple into tensors and non-tensors and create a mapping to enable re-aggregation.  
    The order of tensors and non-tensors is preserved in their respective output groups.  
  
    Parameters:  
        all_objects (list/tuple): Objects containing tensors and non-tensors to be split.  
  
    Returns:  
        tuple: Containing tensors, non-tensors, and bools of whether each position in original list/tuple was a tensor.  
  
    """  
    tensor_objects = [v for v in all_objects if torch.is_tensor(v)]  
    non_tensor_objects = [v for v in all_objects if not torch.is_tensor(v)]  
    tensor_flags = [torch.is_tensor(v) for v in all_objects]  
    if type(all_objects) is tuple:  
        return tuple(tensor_objects), tuple(non_tensor_objects), tuple(tensor_flags)  
    return tensor_objects, non_tensor_objects, tensor_flags  

  
def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags):  
    """  
    Merge two lists (or tuples) of tensors and non-tensors using a mapping of positions in merged list (or tuple).  
  
    Parameters:  
        tensor_objects (list/tuple): Tensors to merge.  
        non_tensor_objects (list/tuple): Non-tensors to merge.  
        tensor_flags (list/tuple): Indicates whether each position in output is a tensor.  
  
    Returns:  
        tuple: Merge of tensors and non-tensors  
    """  
    merged_objects = []  
    tensor_idx = 0  
    non_tensor_idx = 0  
  
    real_tensor_flags = None  
  
    # remove the flags that are assigned to the size of the flattened tensors  
    if PARTITION_ACTIVATIONS:  
        # (tensor, size, ..., tensor, size)  
        real_tensor_flags = []  
        previous_flag = False  
        for flag in tensor_flags:  
            if previous_flag:  
                previous_flag = False  
                continue  
            previous_flag = flag  
            real_tensor_flags.append(flag)  
    else:  
        real_tensor_flags = tensor_flags  
  
    for is_tensor in real_tensor_flags:  
        if is_tensor:  
            merged_objects.append(tensor_objects[tensor_idx])  
            tensor_idx += 1  
        else:  
            merged_objects.append(non_tensor_objects[non_tensor_idx])  
            non_tensor_idx += 1  
  
    return tuple(merged_objects)  
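A quick round-trip shows how the two functions compose (a minimal sketch; assumes PARTITION_ACTIVATIONS is False, the module default, so tensor_flags is used as-is):

import torch

mixed = (torch.ones(2), "meta", 3, torch.zeros(1))
tensors, non_tensors, flags = extract_tensors(all_objects=mixed)
# tensors == (ones(2), zeros(1)); non_tensors == ("meta", 3)
# flags   == (True, False, False, True)
restored = merge_tensors(tensor_objects=tensors, non_tensor_objects=non_tensors, tensor_flags=flags)
assert restored[1] == "meta" and torch.equal(restored[0], mixed[0])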

get_partition_start & get_partition_size

  
# model parallel or tensor parallel  
def get_partition_start(item):  
    global mp_rank, mp_size, mp_group  
    size = item.numel()  
    partition_size = size / mp_size  
    start = partition_size * mp_rank  
    return int(start)  
  
  
def get_partition_size(item):  
    global mp_rank, mp_size, mp_group  
    size = item.numel()  
    assert size % mp_size == 0, "Activation partitioning is not supported when item.numel() is not divisible by mp_size"  
    partition_size = size / mp_size  
    return int(partition_size)  
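A worked example of the arithmetic: with mp_size = 4, a 1024-element tensor is divided into 256-element slices, and rank r's slice starts at r * 256. A standalone sketch (mp_size and mp_rank hard-coded for illustration):

import torch

item = torch.arange(1024.0)
mp_size, mp_rank = 4, 1
partition_size = item.numel() // mp_size                    # 1024 / 4 = 256
start = partition_size * mp_rank                            # 256 * 1 = 256
partition = item.view(-1).narrow(0, start, partition_size)  # this rank's slice
assert partition.numel() == 256 and partition[0].item() == 256.0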

partition_activations

Partitions activations across the MP (tensor-parallel) group.

  
def partition_activations(args, cpu_checkpoint: bool, contiguous_checkpoint: bool):  
    global contiguous_data_buffers, data_offsets  
  
    inputs = []  
    num_non_fp_tensors = 0  # count of non-floating-point (non-checkpointed) args  
  
    for arg_index, item in enumerate(args):  
        if not is_activation_to_checkpoint(item):  
            inputs.append(item)  
            num_non_fp_tensors += 1  
            continue  
  
        i = arg_index - num_non_fp_tensors  
        partition_size = get_partition_size(item)  
        partition = item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), partition_size).clone()  
  
        buffer_device = torch.device('cpu') if cpu_checkpoint else partition.device  
  
        if contiguous_checkpoint:  
            # store the checkpointed tensors together in one contiguous buffer  
            # allocate new buffer memory  
            if i >= len(contiguous_data_buffers):  
                tensor_list = [  
                    torch.tensor(()).new_empty([partition_size], dtype=partition.dtype, device=buffer_device)  
                    for _ in range(num_layers)  
                ]  
                contiguous_data_buffers.append(tensor_list)  
                data_offsets.append(0)  
            elif contiguous_data_buffers[i] is None:  
                tensor_list = [  
                    torch.tensor(()).new_empty([partition_size], dtype=partition.dtype, device=buffer_device)  
                    for _ in range(num_layers)  
                ]  
                contiguous_data_buffers[i] = tensor_list  
                data_offsets[i] = 0  
  
            # Because the 'new_empty' returns uninitialized pages,  
            # the pages need to be populated during the cudaMemcpy time  
            # which increases the data copy time. To avoid this, we  
            # pre-populate these pages by simply writing 0 ahead of  
            # the actual cudaMemcpy operation time. Due to the  
            # previously launched GPU kernels, there is a small  
            # window of time here for CPUs to populate pages asynchronously.  
            contiguous_data_buffers[i][data_offsets[i]].data[range(  
                0, contiguous_data_buffers[i][data_offsets[i]].data.shape[0],  
                int(mmap.PAGESIZE / contiguous_data_buffers[i][data_offsets[i]].data.element_size()))] = 0  
  
            contiguous_partition = contiguous_data_buffers[i][data_offsets[i]].data.copy_(partition.data)  
            data_offsets[i] = data_offsets[i] + 1  
            inputs.append(contiguous_partition)  
        else:  
            partition = partition.cpu() if CPU_CHECKPOINT else partition  
            inputs.append(partition)  
  
    return inputs  
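The page pre-touch above is worth isolating: `new_empty` hands back uninitialized pages, so writing one element per OS page forces the kernel to map them before the later `copy_`. A minimal sketch of the same trick (buffer size is arbitrary):

import mmap
import torch

buf = torch.tensor(()).new_empty([1 << 20])       # uninitialized CPU buffer
stride = int(mmap.PAGESIZE / buf.element_size())  # one element per OS page
buf[range(0, buf.shape[0], stride)] = 0           # touch every page up front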

gather_partitioned_activations

Gathers partitioned activations back into the full activations.

  
def is_activation_to_checkpoint(item):  
    """  
        Is an activation to be checkpointed  
    """  
    global mp_size  
    return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size  
  
  
# tensors: [item, size, ..., item, size]  
def gather_partitioned_activations(tensors, device=None):  
    global mp_rank, mp_size, mp_group  
    assert len(tensors) % 2 == 0, f'Expected even count of tensors, instead got {len(tensors)}'  
    inputs = []  
    num_args = int(len(tensors) / 2)  
    for i in range(num_args):  
  
        item = tensors[2 * i]  
        size = tensors[2 * i + 1]  
  
        if not is_activation_to_checkpoint(item):  
            inputs.append(item)  
            continue  
  
        # don't need to do all_gather if model parallel is not enabled  
        if mp_group is None or mp_size == 1:  
            item = item.view(list(size.numpy()))  
            inputs.append(item)  
            continue  
  
        partition_size = item.numel()  
        tensor_size = partition_size * mp_size  
        if device is not None:  
            flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=device)  
        else:  
            flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=item.device)  
        partitions = []  
        for i in range(mp_size):  
            part_i = flat_tensor.narrow(0, partition_size * i, partition_size)  
            if i == mp_rank:  
                part_i.copy_(item)  
            partitions.append(part_i)  
        dist.all_gather(partitions, partitions[mp_rank], group=mp_group)  
        input_tensor = flat_tensor.view(list(size.numpy()))  
        item.data = input_tensor.data  
  
        inputs.append(item)  
  
    return tuple(inputs)  
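Without a multi-process setup, the reconstruction can be sketched single-process: simulate mp_size = 2 by slicing locally and replace dist.all_gather with direct copies into the flat buffer (a minimal sketch, not the distributed path):

import torch

full = torch.arange(8.0).view(2, 4)
size = torch.tensor(full.size())                 # saved alongside the partition
parts = [full.view(-1).narrow(0, 0, 4).clone(),  # what rank 0 would hold
         full.view(-1).narrow(0, 4, 4).clone()]  # what rank 1 would hold

flat = torch.zeros(8)
for i, p in enumerate(parts):                    # stand-in for dist.all_gather
    flat.narrow(0, 4 * i, 4).copy_(p)
restored = flat.view(list(size.numpy()))
assert torch.equal(restored, full)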

get_cpu_activations_for_backward

  
def get_cpu_activations_for_backward(args, inputs):  
    new_args = []  
    for i, (arg, inp) in enumerate(zip(args, inputs)):  
        if not is_activation_to_checkpoint(arg):  
            new_args.append(arg)  
            continue  
  
        # restore the checkpointed storage (inp.data) onto the original arg  
        arg.data = inp.data  
        new_args.append(arg)  
  
    return new_args  
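The `arg.data = inp.data` assignment is the standard trick for swapping a tensor's underlying storage without disturbing its autograd metadata. A minimal illustration (names hypothetical):

import torch

arg = torch.randn(4, requires_grad=True)  # the tensor the autograd graph knows about
inp = torch.arange(4.0)                   # the checkpointed copy (e.g. kept on CPU)
arg.data = inp.data                       # arg now aliases inp's storage...
assert arg.requires_grad                  # ...while its autograd attributes are untouched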

get_partitioned_activations_for_backward

  
def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint: bool):  
    global contiguous_size_buffers, size_offsets  
  
    # (tensor, size, ..., tensor, size)  
    new_args = []  
    num_non_fp_tensors = 0  
  
    for arg_index, (arg, inp) in enumerate(zip(args, inputs)):  
        size = torch.tensor(arg.size()) if torch.is_tensor(arg) else None  
        if not is_activation_to_checkpoint(arg):  
            new_args.append(arg)  
            new_args.append(size)  
            num_non_fp_tensors += 1  
            continue  
  
        arg.data = inp.data  
        new_args.append(arg)  
        i = arg_index - num_non_fp_tensors  
  
        if contiguous_checkpoint:  
            numel = size.numel()  
            if i >= len(contiguous_size_buffers):  
                tmp = torch.tensor(())  
                contiguous_size_buffers.append(  
                    tmp.new_empty([numel * num_layers], dtype=size.dtype, device=size.device))  
                size_offsets.append(0)  
            elif contiguous_size_buffers[i] is None:  
                tmp = torch.tensor(())  
                contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers], dtype=size.dtype, device=size.device)  
                size_offsets[i] = 0  
  
            contiguous_size = contiguous_size_buffers[i].narrow(0, size_offsets[i], numel).data.copy_(size.data)  
            contiguous_size = contiguous_size.view_as(size)  
            size_offsets[i] = size_offsets[i] + numel  
            new_args.append(contiguous_size)  
        else:  
            new_args.append(size)  
  
    return new_args  
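The resulting new_args interleaves each saved partition with its original shape, which is exactly why merge_tensors above drops every second flag when PARTITION_ACTIVATIONS is on. Schematically (hypothetical shapes):

# For two checkpointed activations of shapes (2, 4) and (8,):
#   new_args = [part0, torch.tensor([2, 4]), part1, torch.tensor([8])]
# All four entries are tensors, so tensor_flags comes out as (True,) * 4;
# merge_tensors keeps only the even positions (the activations) and skips
# the size entries when reassembling the original argument list.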

CheckpointFunction

The core of activation checkpointing: an autograd.Function that checkpoints activations during the model's forward pass and, in backward, loads the activation checkpoints and recomputes the forward to produce gradients.
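Before walking through the full implementation, a minimal sketch of the core idea helps: save the inputs rather than the activations in forward, then rerun the function under enable_grad in backward. This strips out RNG handling, partitioning, CPU offload, and non-tensor arguments:

import torch

class MiniCheckpoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, *args):
        ctx.run_function = run_function
        ctx.save_for_backward(*args)
        with torch.no_grad():              # discard the inner graph; it is rebuilt in backward
            return run_function(*args)

    @staticmethod
    def backward(ctx, *grads):
        inputs = [t.detach().requires_grad_(t.requires_grad) for t in ctx.saved_tensors]
        with torch.enable_grad():          # recompute, this time recording the graph
            outputs = ctx.run_function(*inputs)
        if torch.is_tensor(outputs):
            outputs = (outputs,)
        torch.autograd.backward(outputs, grads)
        return (None,) + tuple(inp.grad for inp in inputs)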

  
class CheckpointFunction(torch.autograd.Function):  
    """This function is adapted from torch.utils.checkpoint with  
       two main changes:  
           1) torch.cuda.set\_rng\_state is replaced with `\_set\_cuda\_rng\_state`  #ignore-cuda  
           2) the states in the model parallel tracker are also properly  
              tracked/set/reset.  
           3) Performance activation partitioning, contiguous memory optimization  
           4) CPU Checkpointing  
           5) Profile forward and backward functions  
    """  

Forward Process

  
    @staticmethod  
    def forward(ctx, run_function, all_outputs, *args):  
        global mpu, timers, SYNCHRONIZE, PROFILE_TIME  
  
        def save_args_for_backward(*all_args):  
            # split all_args into tensors, non-tensors, and flags; stash them on ctx for backward  
            tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)  
            ctx.deepspeed_saved_tensors = tensor_args  
            ctx.non_tensor_args = non_tensor_args  
            ctx.tensor_flags = tensor_flags  
  
        if SYNCHRONIZE:  
            get_accelerator().synchronize()  
  
        if timers is None and PROFILE_TIME:  
            timers = Timers()  
  
        if PROFILE_TIME:  
            timers('forward').start()  
  
        ctx.run_function = run_function  
        global num_layers  
        global mp_rank, mp_size, mp_group  
        global contiguous_data_buffers, contiguous_size_buffers  
        global data_offsets, size_offsets  
        # get tensor model parallel infos  
        if mp_rank is None:  
            if mpu is not None:  
                if hasattr(mpu, 'get_tensor_model_parallel_rank'):  
                    mp_rank = mpu.get_tensor_model_parallel_rank()  
                    mp_size = mpu.get_tensor_model_parallel_world_size()  
                    mp_group = mpu.get_tensor_model_parallel_group()  
                else:  
                    mp_rank = mpu.get_model_parallel_rank()  
                    mp_size = mpu.get_model_parallel_world_size()  
                    mp_group = mpu.get_model_parallel_group()  
            else:  
                mp_rank = 0  
                mp_size = 1  
                mp_group = None  
  
        global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset  
  
        if cuda_device is None:  
            see_memory_usage("First Forward Beginning", force=False)  
            if dist.get_rank() == 0:  
                logger.info(f"Activation Checkpointing Information")  
                logger.info(f"----Partition Activations {PARTITION\_ACTIVATIONS}, CPU CHECKPOINTING {CPU\_CHECKPOINT}")  
                logger.info(  
                    f"----contiguous Memory Checkpointing {CONTIGUOUS\_CHECKPOINTING} with {num\_layers} total layers")  
                logger.info(f"----Synchronization {SYNCHRONIZE}")  
                logger.info(f"----Profiling time in checkpointing {PROFILE\_TIME}")  
  
            cuda_device = get_accelerator().current_device_name()  
            transport_stream = get_accelerator().Stream(device=cuda_device)  
  
        if PARTITION_ACTIVATIONS:  
            # partition the input activations across tensor-parallel ranks; each rank keeps one slice  
            inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)  
        elif CPU_CHECKPOINT:  
            # offload the checkpointed activations to CPU; tensors are kept whole  
            inputs = copy_to_device(args, device=torch.device('cpu'), criterion_func=is_activation_to_checkpoint)  
  
        # just in case something funky is happening such as reuse of inputs  
        inputs_cuda = copy_to_device(args, device=cuda_device, criterion_func=is_activation_to_checkpoint)  
  
        # Copy the rng states.  
        ctx.fwd_cpu_rng_state = torch.get_rng_state()  
        ctx.fwd_cuda_rng_state = get_accelerator().get_rng_state()  
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()  
  
        # run the forward function under no_grad and collect the outputs  
        see_memory_usage("Before running forward on the layer", force=False)  
        # ctx.save_for_backward(*args)  
        with torch.no_grad():  
            outputs = run_function(*inputs_cuda)  
  
        see_memory_usage("After running forward on the layer", force=False)  
        del inputs_cuda  
  
        # save input args for backward  
        if PARTITION_ACTIVATIONS:  
            new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)  
            assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'  
            save_args_for_backward(*new_args)  
        elif CPU_CHECKPOINT:  
            new_args = get_cpu_activations_for_backward(args, inputs)  
            save_args_for_backward(*new_args)  
        else:  
            save_args_for_backward(*args)  
  
  
        if PROFILE_TIME:  
            timers('forward').stop()  
            timers.log(['forward'])  
        if SYNCHRONIZE:  
            get_accelerator().synchronize()  
  
        # Tensors returned from forward() may not be differentiable: non-floating-point tensors cannot require grad.  
        if torch.is_tensor(outputs):  
            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []  
        else:  
            non_grad_outputs = [o for o in outputs if torch.is_tensor(o) and not o.is_floating_point()]  
        ctx.mark_non_differentiable(*non_grad_outputs)  
  
        if torch.is_tensor(outputs):  
            all_outputs += [outputs]  
            return outputs  
        else:  
            all_outputs += outputs  
            outputs, _, _ = extract_tensors(all_objects=outputs)  
            return tuple(outputs)  
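The RNG bookkeeping above (fwd_cpu_rng_state and friends) exists so that stochastic ops such as dropout draw exactly the same mask during recompute as in the original forward. The CPU-side mechanics in isolation:

import torch

state = torch.get_rng_state()   # snapshot before the "forward"
a = torch.rand(3)               # e.g. a dropout mask
torch.set_rng_state(state)      # restore before the recompute
b = torch.rand(3)
assert torch.equal(a, b)        # the recompute sees identical randomness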

Backward Process

  
    @staticmethod  
    def backward(ctx, *grads):  
        global timers  
        see_memory_usage("In backward", force=False)  
        # removing pointers to the contiguous buffer memory  
        # so that they can be garbage collected once the checkpoints  
        # have been used  
        if SYNCHRONIZE:  
            get_accelerator().synchronize()  
        if PROFILE_TIME:  
            timers('backward').start()  
  
        if CONTIGUOUS_CHECKPOINTING:  
            global data_offsets, size_offsets  
            global contiguous_data_buffers, contiguous_size_buffers  
  
            for buffers in contiguous_data_buffers:  
                buffers = []  
  
            # frees up all the pointers to the checkpoints except for the ones  
            # stored by save for backward  
            contiguous_data_buffers = []  
            contiguous_size_buffers = []  
            data_offsets = []  
            size_offsets = []  
  
        see_memory_usage("In backward checkpointing code", force=False)  
        if not torch.autograd._is_checkpoint_valid():  
            raise RuntimeError("Checkpointing is not compatible with .grad(), "  
                               "please use .backward() if possible")  
  
        global cuda_device, transport_stream, PARTITION_ACTIVATIONS  
  
        if PARTITION_ACTIVATIONS:  
            # with get_accelerator().stream(transport_stream):  
            inputs = gather_partitioned_activations(ctx.deepspeed_saved_tensors,  
                                                    device=cuda_device if CPU_CHECKPOINT else None)  
            detached_inputs = detach_variable(inputs)  
        elif CPU_CHECKPOINT:  
            inputs = move_to_device(ctx.deepspeed_saved_tensors, cuda_device, is_activation_to_checkpoint)  
            detached_inputs = detach_variable(inputs)  
        else:  
            inputs = ctx.deepspeed_saved_tensors  
            detached_inputs = detach_variable(inputs)  
  
        # Add non tensor input args  
        detached_inputs = merge_tensors(tensor_objects=detached_inputs,  
                                        non_tensor_objects=ctx.non_tensor_args,  
                                        tensor_flags=ctx.tensor_flags)  
  
        # Store the current states.  
        bwd_cpu_rng_state = torch.get_rng_state()  
        bwd_cuda_rng_state = get_accelerator().get_rng_state()  
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()  
  
        # Set the states to what they were before the forward pass.  
        torch.set_rng_state(ctx.fwd_cpu_rng_state)  
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)  
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)  
  
        # if PARTITION_ACTIVATIONS:  
        #     current_stream = get_accelerator().current_stream()  
        #     current_stream.wait_stream(transport_stream)  
  
        see_memory_usage("In backward checkpointing code before forward", force=False)  
  
        # recompute activations  
        with torch.enable_grad():  
            outputs = ctx.run_function(*detached_inputs)  
  
        see_memory_usage("In backward checkpointing code after forward", force=False)  
        # Set the states back to what they were at the start of this function.  
        torch.set_rng_state(bwd_cpu_rng_state)  
        _set_cuda_rng_state(bwd_cuda_rng_state)  
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)  
  
        if isinstance(outputs, torch.Tensor):  
            outputs = (outputs, )  
  
        # Filter out non tensor outputs  
        outputs, _, _ = extract_tensors(all_objects=outputs)  
  
        # Construct arguments to autograd.backward().  
        # This is usually just outputs and grads, but forward() can return tensors that  
        # are not differentiable.  
        output_tensors = []  
        grad_tensors = []  
        for out, grad in zip(outputs, grads):  
            if out.requires_grad:  
                output_tensors.append(out)  
                grad_tensors.append(grad)  
  
        see_memory_usage("In backward checkpointing code before backward", force=False)  
  
        torch.autograd.backward(output_tensors, grad_tensors)  
  
        # Force clear our stashed tensors to prevent a memory leak in certain scenarios  
        ctx.deepspeed_saved_tensors = None  
        ctx.non_tensor_args = None  
        ctx.tensor_flags = None  
  
        see_memory_usage("After backward checkpointing code after backward", force=False)  
  
        if PROFILE_TIME:  
            timers('backward').stop()  
            timers.log(['backward'])  
        if SYNCHRONIZE:  
            get_accelerator().synchronize()  
        ret_list = [None, None]  # placeholders for the run_function and all_outputs args  
        for inp in detached_inputs:  
            if torch.is_tensor(inp):  
                ret_list.append(inp.grad)  
            else:  
                ret_list.append(None)  
  
        return tuple(ret_list)  
  

checkpoint

Entry point for activation checkpointing.

  
def checkpoint(function, *args):  
    """Checkpoint a model or part of the model.  
    This has been directly copied from torch.utils.checkpoint. """  
  
    all_outputs = []  
    CheckpointFunction.apply(function, all_outputs, *args)  
    if len(all_outputs) == 1:  
        return all_outputs[0]  
    else:  
        return tuple(all_outputs)  
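Typical Megatron-style usage wraps a chunk of transformer layers and checkpoints it as one unit (a hedged sketch; `layers_chunk` and `hidden` are placeholders):

import deepspeed

def custom_forward(*inputs):
    hidden = inputs[0]
    for layer in layers_chunk:   # a few consecutive transformer layers (placeholder)
        hidden = layer(hidden)
    return hidden

hidden = deepspeed.checkpointing.checkpoint(custom_forward, hidden)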

test_activation_checkpointing

  
ckpt = deepspeed.checkpointing.checkpoint  
  
  
def _compute(module, *inputs, do_checkpoint=False):  
    if do_checkpoint:  
        outputs = ckpt(module, *inputs)  
    else:  
        outputs = module(*inputs)  
  
    if torch.is_tensor(outputs):  
        outputs = (outputs, )  
  
    sum(o.sum() for o in outputs if torch.is_tensor(o) and o.requires_grad).backward()  
  
    grads = [p.grad for p in module.parameters()]  
    input_grads = [inp.grad for inp in inputs if torch.is_tensor(inp)]  
  
    return {  
        'outputs': outputs,  
        'module_grads': grads,  
        'input_grads': input_grads,  
    }  
  
def _test_activation_checkpoint(module, *inputs):  
    # Move to device  
    module.to(get_accelerator().device_name())  
  
    # Get rid of dropouts until we fork the RNG between tests.  
    module.eval()  
  
    module_ = deepcopy(module)  
    inputs_ = _prep_inputs(*inputs)  
    base = _compute(module_, *inputs_, do_checkpoint=False)  
  
    module_ = deepcopy(module)  
    inputs_ = _prep_inputs(*inputs)  
    test = _compute(module_, *inputs_, do_checkpoint=True)  
  
    for group in base.keys():  
        for b, t in zip(base[group], test[group]):  
            _match_outputs(b, t)  
