DeepSpeed's activation checkpointing API enables a set of memory optimizations related to activation checkpointing, including partitioning activations across GPUs when model parallelism is used, CPU checkpointing, and contiguous memory optimization.
As shown in the figure:
1. The upper half is the logical subgraph in the normal case. T1 and T2 are the forward-computation parts of Transformer layers. The intermediate activations produced by each op keep occupying memory after the op finishes; only when computation reaches the backward pass (T1_grad, T2_grad) are these intermediate activations consumed to compute gradients.
2. The lower half is the logical subgraph after recomputation (activation checkpointing) is enabled. A "fake" subgraph used for recomputation (the dashed box) is inserted in the middle. Because of this fake subgraph, the normal forward subgraph no longer needs to keep the intermediate activations; when the backward pass needs them, they are recomputed on the fly by running the fake subgraph forward again.
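For intuition, here is a minimal sketch of the same trade-off using the stock torch.utils.checkpoint API, which is the baseline DeepSpeed's implementation extends; the layer shapes and the use_reentrant flag are illustrative assumptions, not DeepSpeed code:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# A toy "transformer layer" standing in for T1/T2 above.
layer = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
x = torch.randn(8, 1024, requires_grad=True)

# Normal forward: every intermediate activation inside `layer` stays alive
# until backward consumes it.
y = layer(x)

# Checkpointed forward: only the input is saved; the intermediates are
# recomputed by re-running `layer` when backward reaches this node.
y_ckpt = checkpoint(layer, x, use_reentrant=False)
y_ckpt.sum().backward()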
# Optional DeepSpeed Activation Checkpointing Features
#
if args.deepspeed and args.deepspeed_activation_checkpointing:
    set_deepspeed_activation_checkpointing(args)

def set_deepspeed_activation_checkpointing(args):
    deepspeed.checkpointing.configure(mpu,
                                      deepspeed_config=args.deepspeed_config,
                                      partition_activations=True)
    mpu.checkpoint = deepspeed.checkpointing.checkpoint
    mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
    mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed

if deepspeed.checkpointing.is_configured():
    global get_cuda_rng_tracker, checkpoint
    get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
    checkpoint = deepspeed.checkpointing.checkpoint
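The same switches can also be driven from the DeepSpeed JSON config instead of keyword arguments. A sketch of the activation_checkpointing section follows; the key names are as documented in the DeepSpeed configuration docs, while the values here are purely illustrative:

"activation_checkpointing": {
    "partition_activations": true,
    "cpu_checkpointing": true,
    "contiguous_memory_optimization": false,
    "number_checkpoints": null,
    "synchronize_checkpoint_boundary": false,
    "profile": false
}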
Global Vars
# DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False
# MP parameters
mpu = None
mp_rank = None
mp_size = None
mp_group = None
# Model Parameters
num_layers = None
# Checkpointing buffers
contiguous_data_buffers = []
data_offsets = []
contiguous_size_buffers = []
size_offsets = []
timers = None
# optimization flags
PARTITION_ACTIVATIONS = False
CPU_CHECKPOINT = False
CONTIGUOUS_CHECKPOINTING = False
SYNCHRONIZE = False
PROFILE_TIME = False
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
transport_stream = None
cuda_device = None
extract_tensors & merge_tensors
extract_tensors separates tensors from non-tensors; merge_tensors restores the original list or tuple.
def extract_tensors(all_objects):
    """
    Separate objects in list/tuple into tensors and non-tensors and create a mapping to enable re-aggregation.
    The order of tensors and non-tensors is preserved in their respective output groups.

    Parameters:
        all_objects (list/tuple): Objects containing tensors and non-tensors to be split.

    Returns:
        tuple: Containing tensors, non-tensors, and bools of whether each position in original list/tuple was a tensor.
    """
    tensor_objects = [v for v in all_objects if torch.is_tensor(v)]
    non_tensor_objects = [v for v in all_objects if not torch.is_tensor(v)]
    tensor_flags = [torch.is_tensor(v) for v in all_objects]
    if type(all_objects) is tuple:
        return tuple(tensor_objects), tuple(non_tensor_objects), tuple(tensor_flags)
    return tensor_objects, non_tensor_objects, tensor_flags
def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags):
    """
    Merge two lists (or tuples) of tensors and non-tensors using a mapping of positions in merged list (or tuple).

    Parameters:
        tensor_objects (list/tuple): Tensors to merge.
        non_tensor_objects (list/tuple): Non-tensors to merge.
        tensor_flags (list/tuple): Indicates whether each position in output is a tensor.

    Returns:
        tuple: Merge of tensors and non-tensors
    """
    merged_objects = []
    tensor_idx = 0
    non_tensor_idx = 0

    real_tensor_flags = None
    # remove the flags that are assigned to the size of the flattened tensors
    if PARTITION_ACTIVATIONS:
        # the saved args are interleaved as (tensor, size, ..., tensor, size)
        real_tensor_flags = []
        previous_flag = False
        for flag in tensor_flags:
            if previous_flag:
                previous_flag = False
                continue
            previous_flag = flag
            real_tensor_flags.append(flag)
    else:
        real_tensor_flags = tensor_flags

    for is_tensor in real_tensor_flags:
        if is_tensor:
            merged_objects.append(tensor_objects[tensor_idx])
            tensor_idx += 1
        else:
            merged_objects.append(non_tensor_objects[non_tensor_idx])
            non_tensor_idx += 1

    return tuple(merged_objects)
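A quick round-trip illustrates the contract between the two helpers. This is a standalone sketch, not part of the DeepSpeed source, and assumes the two functions above are in scope with PARTITION_ACTIVATIONS left at its default of False:

import torch

mixed = (torch.ones(2), "mask_type", torch.zeros(3), 42)
tensors, non_tensors, flags = extract_tensors(all_objects=mixed)
# tensors == (ones(2), zeros(3)); non_tensors == ("mask_type", 42)
# flags == (True, False, True, False)
restored = merge_tensors(tensor_objects=tensors,
                         non_tensor_objects=non_tensors,
                         tensor_flags=flags)
assert len(restored) == len(mixed)  # original ordering is preserved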
get_partition_start & get_partition_size
# model parallel or tensor parallel
def get_partition_start(item):
    global mp_rank, mp_size, mp_group
    size = item.numel()
    partition_size = size / mp_size
    start = partition_size * mp_rank
    return int(start)

def get_partition_size(item):
    global mp_rank, mp_size, mp_group
    size = item.numel()
    assert size % mp_size == 0, "Doesn't handle if partition activation if item is not divisible by mp size"
    partition_size = size / mp_size
    return int(partition_size)
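As a worked example with hypothetical numbers: a flattened activation of 1024 elements with mp_size = 4 is split into four 256-element slices, and rank 2 owns elements [512, 768):

# numel = 1024, mp_size = 4, mp_rank = 2
partition_size = 1024 // 4      # 256 elements per rank
start = partition_size * 2      # rank 2's slice starts at offset 512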
partition_activations
Partitions the activations across the model-parallel (tensor-parallel) ranks.
def partition_activations(args, cpu_checkpoint: bool, contiguous_checkpoint: bool):
    global contiguous_data_buffers, data_offsets

    inputs = []
    num_non_fp_tensors = 0  # non-floating-point args are passed through unchanged

    for arg_index, item in enumerate(args):
        if not is_activation_to_checkpoint(item):
            inputs.append(item)
            num_non_fp_tensors += 1
            continue

        i = arg_index - num_non_fp_tensors
        partition_size = get_partition_size(item)
        partition = item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), partition_size).clone()

        buffer_device = torch.device('cpu') if cpu_checkpoint else partition.device

        if contiguous_checkpoint:
            # store the partitions together in pre-allocated contiguous buffers
            # (one buffer list per argument position, one slot per layer)
            if i >= len(contiguous_data_buffers):
                tensor_list = [
                    torch.tensor(()).new_empty([partition_size], dtype=partition.dtype, device=buffer_device)
                    for _ in range(num_layers)
                ]
                contiguous_data_buffers.append(tensor_list)
                data_offsets.append(0)
            elif contiguous_data_buffers[i] is None:
                tensor_list = [
                    torch.tensor(()).new_empty([partition_size], dtype=partition.dtype, device=buffer_device)
                    for _ in range(num_layers)
                ]
                contiguous_data_buffers[i] = tensor_list
                data_offsets[i] = 0

            # Because 'new_empty' returns uninitialized pages,
            # the pages need to be populated during the cudaMemcpy time
            # which increases the data copy time. To avoid this, we
            # pre-populate these pages by simply writing 0 ahead of
            # the actual cudaMemcpy operation time. Due to the
            # previously launched GPU kernels, there is a small
            # window of time here for CPUs to populate pages asynchronously.
            contiguous_data_buffers[i][data_offsets[i]].data[range(
                0, contiguous_data_buffers[i][data_offsets[i]].data.shape[0],
                int(mmap.PAGESIZE / contiguous_data_buffers[i][data_offsets[i]].data.element_size()))] = 0

            contiguous_partition = contiguous_data_buffers[i][data_offsets[i]].data.copy_(partition.data)
            data_offsets[i] = data_offsets[i] + 1
            inputs.append(contiguous_partition)
        else:
            partition = partition.cpu() if CPU_CHECKPOINT else partition
            inputs.append(partition)

    return inputs
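Stripped of the buffering logic, the core of this function is a single slicing expression: each rank keeps only its own 1/mp_size slice of the flattened activation. A standalone sketch (partition_one is a hypothetical helper for illustration):

import torch

def partition_one(item: torch.Tensor, mp_rank: int, mp_size: int) -> torch.Tensor:
    # Flatten, slice out this rank's contiguous chunk, and clone so the
    # full-size activation can be freed.
    partition_size = item.numel() // mp_size
    start = partition_size * mp_rank
    return item.detach().contiguous().view(-1).narrow(0, start, partition_size).clone()

x = torch.randn(4, 256)                 # 1024 elements
mine = partition_one(x, mp_rank=2, mp_size=4)
assert mine.numel() == 256              # only 1/4 of the activation is saved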
gather_partitioned_activations
Gathers the partitioned activations back into the full activations.
def is_activation_to_checkpoint(item):
    """
    Is an activation to be checkpointed
    """
    global mp_size
    return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size

# tensors: [item, size, ..., item, size]
def gather_partitioned_activations(tensors, device=None):
    global mp_rank, mp_size, mp_group
    assert len(tensors) % 2 == 0, f'Expected even count of tensors, instead got {len(tensors)}'
    inputs = []
    num_args = int(len(tensors) / 2)
    for i in range(num_args):
        item = tensors[2 * i]
        size = tensors[2 * i + 1]

        if not is_activation_to_checkpoint(item):
            inputs.append(item)
            continue

        # don't need to do all_gather if model parallel is not enabled
        if mp_group is None or mp_size == 1:
            item = item.view(list(size.numpy()))
            inputs.append(item)
            continue

        partition_size = item.numel()
        tensor_size = partition_size * mp_size
        if device is not None:
            flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=device)
        else:
            flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=item.device)
        partitions = []
        for i in range(mp_size):
            part_i = flat_tensor.narrow(0, partition_size * i, partition_size)
            if i == mp_rank:
                part_i.copy_(item)
            partitions.append(part_i)
        dist.all_gather(partitions, partitions[mp_rank], group=mp_group)
        input_tensor = flat_tensor.view(list(size.numpy()))
        item.data = input_tensor.data
        inputs.append(item)
    return tuple(inputs)
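Conceptually the gather is the inverse of the partition step above. Here is a single-process simulation of the buffer assembly, with dist.all_gather replaced by explicit copies so the sketch runs without a process group; all numbers are illustrative:

import torch

mp_size = 4
full = torch.randn(1024)
# Pretend each rank holds one 256-element partition of `full`.
parts = list(full.view(mp_size, -1))

# What all_gather produces on every rank: a flat buffer whose i-th slice
# is rank i's partition, then reshaped back to the original size.
flat_tensor = torch.zeros(1024)
for i in range(mp_size):
    flat_tensor.narrow(0, 256 * i, 256).copy_(parts[i])
assert torch.equal(flat_tensor, full)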
get_cpu_activations_for_backward
def get_cpu_activations_for_backward(args, inputs):
    new_args = []
    for i, (arg, inp) in enumerate(zip(args, inputs)):
        if not is_activation_to_checkpoint(arg):
            new_args.append(arg)
            continue

        # point arg.data at the offloaded copy in inputs
        arg.data = inp.data
        new_args.append(arg)
    return new_args
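The `arg.data = inp.data` trick swaps the tensor's storage without creating a new Python object, so everything that already references `arg` now sees the offloaded data. A standalone illustration:

import torch

arg = torch.randn(4)        # stands in for the original activation tensor
inp = arg.clone()           # stands in for the checkpointed (e.g. CPU) copy
arg.data = inp.data         # rebind arg's storage in place
assert arg.data_ptr() == inp.data_ptr()   # both now share one buffer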
get_partitioned_activations_for_backward
def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint: bool):
    global contiguous_size_buffers, size_offsets

    # saved layout: (tensor, size, ..., tensor, size)
    new_args = []
    num_non_fp_tensors = 0

    for arg_index, (arg, inp) in enumerate(zip(args, inputs)):
        size = torch.tensor(arg.size()) if torch.is_tensor(arg) else None
        if not is_activation_to_checkpoint(arg):
            new_args.append(arg)
            new_args.append(size)
            num_non_fp_tensors += 1
            continue

        arg.data = inp.data
        new_args.append(arg)
        i = arg_index - num_non_fp_tensors

        if contiguous_checkpoint:
            numel = size.numel()
            if i >= len(contiguous_size_buffers):
                tmp = torch.tensor(())
                contiguous_size_buffers.append(
                    tmp.new_empty([numel * num_layers], dtype=size.dtype, device=size.device))
                size_offsets.append(0)
            elif contiguous_size_buffers[i] is None:
                tmp = torch.tensor(())
                contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers], dtype=size.dtype, device=size.device)
                size_offsets[i] = 0

            contiguous_size = contiguous_size_buffers[i].narrow(0, size_offsets[i], numel).data.copy_(size.data)
            contiguous_size = contiguous_size.view_as(size)
            size_offsets[i] = size_offsets[i] + numel
            new_args.append(contiguous_size)
        else:
            new_args.append(size)

    return new_args
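The resulting new_args interleaves each saved partition with its original shape, which is exactly the [item, size, ..., item, size] layout that gather_partitioned_activations consumes in backward. A hypothetical single-argument illustration:

import torch

act = torch.randn(2, 3)
partition = act.view(-1)[:3].clone()            # this rank's slice (mp_size=2)
new_args = [partition, torch.tensor(act.size())]
# gather_partitioned_activations later all-gathers `partition` and reshapes
# it with `size`; merge_tensors then drops the size entries via its
# PARTITION_ACTIVATIONS branch before recomputation.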
CheckpointFunction
The core of activation checkpointing: an autograd.Function that checkpoints activations during the model's forward pass, then during backward loads the checkpointed activations, recomputes the forward, and backpropagates.
class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
       two main changes:
       1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`  # ignore-cuda
       2) the states in the model parallel tracker are also properly
          tracked/set/reset.
       3) Performance activation partitioning, contiguous memory optimization
       4) CPU Checkpointing
       5) Profile forward and backward functions
    """
Forward Process
@staticmethod
def forward(ctx, run_function, all_outputs, *args):
    global mpu, timers, SYNCHRONIZE, PROFILE_TIME

    def save_args_for_backward(*all_args):
        # split all_args into tensors, non-tensors and flags, and stash them
        # on ctx for the backward pass
        tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
        ctx.deepspeed_saved_tensors = tensor_args
        ctx.non_tensor_args = non_tensor_args
        ctx.tensor_flags = tensor_flags

    if SYNCHRONIZE:
        get_accelerator().synchronize()

    if timers is None and PROFILE_TIME:
        timers = Timers()

    if PROFILE_TIME:
        timers('forward').start()

    ctx.run_function = run_function
    global num_layers
    global mp_rank, mp_size, mp_group
    global contiguous_data_buffers, contiguous_size_buffers
    global data_offsets, size_offsets

    # get tensor (model) parallel info
    if mp_rank is None:
        if mpu is not None:
            if hasattr(mpu, 'get_tensor_model_parallel_rank'):
                mp_rank = mpu.get_tensor_model_parallel_rank()
                mp_size = mpu.get_tensor_model_parallel_world_size()
                mp_group = mpu.get_tensor_model_parallel_group()
            else:
                mp_rank = mpu.get_model_parallel_rank()
                mp_size = mpu.get_model_parallel_world_size()
                mp_group = mpu.get_model_parallel_group()
        else:
            mp_rank = 0
            mp_size = 1
            mp_group = None

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

    if cuda_device is None:
        see_memory_usage("First Forward Beginning", force=False)
        if dist.get_rank() == 0:
            logger.info(f"Activation Checkpointing Information")
            logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}")
            logger.info(
                f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
            logger.info(f"----Synchronization {SYNCHRONIZE}")
            logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}")

        cuda_device = get_accelerator().current_device_name()
        transport_stream = get_accelerator().Stream(device=cuda_device)

    if PARTITION_ACTIVATIONS:
        # tensor parallelism for the inputs: each rank keeps one partition
        inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
    elif CPU_CHECKPOINT:
        # offload all checkpointed activations to CPU
        inputs = copy_to_device(args, device=torch.device('cpu'), criterion_func=is_activation_to_checkpoint)

    # just in case something funky is happening such as reuse of inputs
    inputs_cuda = copy_to_device(args, device=cuda_device, criterion_func=is_activation_to_checkpoint)

    # Copy the rng states.
    ctx.fwd_cpu_rng_state = torch.get_rng_state()
    ctx.fwd_cuda_rng_state = get_accelerator().get_rng_state()
    ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    # run the forward function under no_grad and collect the outputs
    see_memory_usage("Before running forward on the layer", force=False)
    # ctx.save_for_backward(*args)
    with torch.no_grad():
        outputs = run_function(*inputs_cuda)

    see_memory_usage("After running forward on the layer", force=False)
    del inputs_cuda

    # save the input args for backward
    if PARTITION_ACTIVATIONS:
        new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)
        assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
        save_args_for_backward(*new_args)
    elif CPU_CHECKPOINT:
        new_args = get_cpu_activations_for_backward(args, inputs)
        save_args_for_backward(*new_args)
    else:
        save_args_for_backward(*args)

    if PROFILE_TIME:
        timers('forward').stop()
        timers.log(['forward'])
    if SYNCHRONIZE:
        get_accelerator().synchronize()

    # Tensors returned from forward() may not be differentiable;
    # non-floating-point tensors cannot carry gradients.
    if torch.is_tensor(outputs):
        non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
    else:
        non_grad_outputs = [o for o in outputs if torch.is_tensor(o) and not o.is_floating_point()]
    ctx.mark_non_differentiable(*non_grad_outputs)

    if torch.is_tensor(outputs):
        all_outputs += [outputs]
        return outputs
    else:
        all_outputs += outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)
        return tuple(outputs)
Backward Process
@staticmethod
def backward(ctx, *grads):
    global timers
    see_memory_usage("In backward", force=False)
    # removing pointers to the contiguous buffer memory
    # so that they can be garbage collected once the checkpoints
    # have been used
    if SYNCHRONIZE:
        get_accelerator().synchronize()
    if PROFILE_TIME:
        timers('backward').start()

    if CONTIGUOUS_CHECKPOINTING:
        global data_offsets, size_offsets
        global contiguous_data_buffers, contiguous_size_buffers

        for buffers in contiguous_data_buffers:
            buffers = []

        # frees up all the pointers to the checkpoints except for the ones
        # stored by save for backward
        contiguous_data_buffers = []
        contiguous_size_buffers = []
        data_offsets = []
        size_offsets = []

    see_memory_usage("In backward checkpointing code", force=False)
    if not torch.autograd._is_checkpoint_valid():
        raise RuntimeError("Checkpointing is not compatible with .grad(), "
                           "please use .backward() if possible")

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS

    # restore the saved inputs: gather partitions, or fetch back from CPU
    if PARTITION_ACTIVATIONS:
        # with get_accelerator().stream(transport_stream):
        inputs = gather_partitioned_activations(ctx.deepspeed_saved_tensors,
                                                device=cuda_device if CPU_CHECKPOINT else None)
        detached_inputs = detach_variable(inputs)
    elif CPU_CHECKPOINT:
        inputs = move_to_device(ctx.deepspeed_saved_tensors, cuda_device, is_activation_to_checkpoint)
        detached_inputs = detach_variable(inputs)
    else:
        inputs = ctx.deepspeed_saved_tensors
        detached_inputs = detach_variable(inputs)

    # Add non tensor input args
    detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                    non_tensor_objects=ctx.non_tensor_args,
                                    tensor_flags=ctx.tensor_flags)

    # Store the current states.
    bwd_cpu_rng_state = torch.get_rng_state()
    bwd_cuda_rng_state = get_accelerator().get_rng_state()
    bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    # Set the states to what it used to be before the forward pass.
    torch.set_rng_state(ctx.fwd_cpu_rng_state)
    _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
    get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

    # if PARTITION_ACTIVATIONS:
    #     current_stream = get_accelerator().current_stream()
    #     current_stream.wait_stream(transport_stream)

    see_memory_usage("In backward checkpointing code before forward", force=False)
    # recompute the activations
    with torch.enable_grad():
        outputs = ctx.run_function(*detached_inputs)
    see_memory_usage("In backward checkpointing code after forward", force=False)

    # Set the states back to what it was at the start of this function.
    torch.set_rng_state(bwd_cpu_rng_state)
    _set_cuda_rng_state(bwd_cuda_rng_state)
    get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

    if isinstance(outputs, torch.Tensor):
        outputs = (outputs, )

    # Filter out non tensor outputs
    outputs, _, _ = extract_tensors(all_objects=outputs)

    # Construct arguments to autograd.backward().
    # This is usually just outputs and grads, but forward() can return tensors that
    # are not differentiable.
    output_tensors = []
    grad_tensors = []
    for out, grad in zip(outputs, grads):
        if out.requires_grad:
            output_tensors.append(out)
            grad_tensors.append(grad)

    see_memory_usage("In backward checkpointing code before backward", force=False)
    torch.autograd.backward(output_tensors, grad_tensors)

    # Force clear our stashed tensors to prevent a memory leak in certain scenarios
    ctx.deepspeed_saved_tensors = None
    ctx.non_tensor_args = None
    ctx.tensor_flags = None

    see_memory_usage("After backward checkpointing code after backward", force=False)

    if PROFILE_TIME:
        timers('backward').stop()
        timers.log(['backward'])
    if SYNCHRONIZE:
        get_accelerator().synchronize()

    ret_list = [None, None]  # first None for ctx
    for inp in detached_inputs:
        if torch.is_tensor(inp):
            ret_list.append(inp.grad)
        else:
            ret_list.append(None)

    return tuple(ret_list)
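The RNG bookkeeping above is what makes recomputation exact in the presence of dropout: the forward-time RNG state is restored before the re-forward, and the backward-time state is put back afterwards. A minimal single-device sketch of the same save/restore pattern (CPU RNG only, for illustration):

import torch

x = torch.randn(4, 8)
fwd_rng_state = torch.get_rng_state()       # saved in forward (ctx.fwd_cpu_rng_state)
y1 = torch.nn.functional.dropout(x, p=0.5)  # original forward

# ... later, in backward: rewind the RNG, recompute, then restore.
bwd_rng_state = torch.get_rng_state()
torch.set_rng_state(fwd_rng_state)
y2 = torch.nn.functional.dropout(x, p=0.5)  # recomputed forward
torch.set_rng_state(bwd_rng_state)

assert torch.equal(y1, y2)  # identical dropout mask => identical activations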
checkpoint
The entry point for activation checkpointing.
def checkpoint(function, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint. """

    all_outputs = []
    CheckpointFunction.apply(function, all_outputs, *args)
    if len(all_outputs) == 1:
        return all_outputs[0]
    else:
        return tuple(all_outputs)
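Typical usage wraps each transformer layer's forward in this entry point so that only the layer inputs (possibly partitioned or offloaded) are kept alive. A sketch following the Megatron-style pattern; `layers`, `custom_forward`, and the chunking policy are assumptions for illustration, not DeepSpeed API:

import deepspeed

def custom_forward(start, end, layers):
    # a re-runnable closure over one chunk of layers
    def forward(hidden_states):
        for layer in layers[start:end]:
            hidden_states = layer(hidden_states)
        return hidden_states
    return forward

def checkpointed_forward(hidden_states, layers, chunk=1):
    # checkpoint one chunk of layers at a time; each call saves only
    # `hidden_states` and recomputes the chunk during backward
    for start in range(0, len(layers), chunk):
        hidden_states = deepspeed.checkpointing.checkpoint(
            custom_forward(start, start + chunk, layers), hidden_states)
    return hidden_states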
test_activation_checkpointing
ckpt = deepspeed.checkpointing.checkpoint

def _compute(module, *inputs, do_checkpoint=False):
    if do_checkpoint:
        outputs = ckpt(module, *inputs)
    else:
        outputs = module(*inputs)

    if torch.is_tensor(outputs):
        outputs = (outputs, )

    sum(o.sum() for o in outputs if torch.is_tensor(o) and o.requires_grad).backward()

    grads = [p.grad for p in module.parameters()]
    input_grads = [inp.grad for inp in inputs if torch.is_tensor(inp)]

    return {
        'outputs': outputs,
        'module_grads': grads,
        'input_grads': input_grads,
    }

def _test_activation_checkpoint(module, *inputs):
    # Move to device
    module.to(get_accelerator().device_name())

    # Get rid of dropouts until we fork the RNG between tests.
    module.eval()

    module_ = deepcopy(module)
    inputs_ = _prep_inputs(*inputs)
    base = _compute(module_, *inputs_, do_checkpoint=False)

    module_ = deepcopy(module)
    inputs_ = _prep_inputs(*inputs)
    test = _compute(module_, *inputs_, do_checkpoint=True)

    for group in base.keys():
        for b, t in zip(base[group], test[group]):
            _match_outputs(b, t)
References
- https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/activation_checkpointing/checkpointing.py
- https://github.com/microsoft/DeepSpeed/blob/master/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py
- https://deepspeed.readthedocs.io/en/latest/activation-checkpointing.html
- https://www.deepspeed.ai/tutorials/megatron/#deepspeed-activation-checkpoints-optional
- https://zhuanlan.zhihu.com/p/373662730