experts
class Experts(torch.nn.Module):
    def __init__(self, expert, num_local_experts=1, expert_group_name=None):
        super(Experts, self).__init__()

        self.deepspeed_experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])
        self.num_local_experts = num_local_experts

        # TODO: revisit allreduce for moe.gate...
        for expert in self.deepspeed_experts:
            # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
            for name, param in expert.named_parameters():
                param.allreduce = False
                param.group_name = expert_group_name

    def forward(self, inputs):
        chunks = inputs.chunk(self.num_local_experts, dim=1)
        expert_outputs = []
        for chunk, expert in zip(chunks, self.deepspeed_experts):
            out = expert(chunk)
            if type(out) is tuple:
                out = out[0]  # Ignore the bias term for now
            expert_outputs += [out]
        expert_output = torch.cat(expert_outputs, dim=1)
        return expert_output
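For intuition, here is a minimal single-process sketch (assuming only torch, copy, and the Experts class above): each local expert receives one slice of the input along dim=1 and the outputs are concatenated back in the same order.

import copy
import torch

expert = torch.nn.Linear(8, 8)
experts = Experts(expert, num_local_experts=2)
# [ep_size, num_local_experts, capacity, d_model] -- the shape MOELayer feeds in
x = torch.randn(1, 2, 4, 8)
y = experts(x)
assert y.shape == x.shape  # each expert transformed only its own dim=1 slice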
gather_tokens
All-gathers the tokens across the tensor model parallel (MP) ranks.
def gather_tokens(input_, dim=0):
    mpu = deepspeed.utils.groups.mpu
    if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1:
        # no tensor parallelism for non-experts
        return input_
    return _GatherTokens.apply(input_, dim)

class _GatherTokens(torch.autograd.Function):
    """All gather tokens among the tensor parallel ranks"""

    @staticmethod
    def symbolic(graph, input_, dim):
        return _gather_tokens(input_, dim)

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        return _gather_tokens(input_, dim)

    @staticmethod
    def backward(ctx, grad_output):
        return _drop_tokens(grad_output, ctx.dim), None

def _gather_tokens(input_, dim=0):
    """Gather tensors and concatenate them along a dimension"""
    mpu = deepspeed.utils.groups.mpu
    input_ = input_.contiguous()
    # Size and dimension.
    rank = mpu.get_tensor_model_parallel_rank()

    tensor_list = [torch.empty_like(input_) for _ in range(mpu.get_tensor_model_parallel_world_size())]
    tensor_list[rank] = input_
    deepspeed.comm.all_gather(tensor_list, input_, group=mpu.get_tensor_model_parallel_group())

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=dim).contiguous()

    return output

def _drop_tokens(input_, dim=0):
    """Divide a tensor among the tensor parallel ranks"""
    mpu = deepspeed.utils.groups.mpu
    total_chunks = mpu.get_tensor_model_parallel_world_size()
    this_chunk = mpu.get_tensor_model_parallel_rank()
    assert input_.shape[dim] % total_chunks == 0, \
        f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
    chunk_size = input_.shape[dim] // total_chunks

    return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
drop_tokens
Partitions the tokens across the tensor model parallel (MP) ranks.
def drop_tokens(input_, dim=0):
    mpu = deepspeed.utils.groups.mpu
    if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1:
        # no tensor parallelism for non-experts
        return input_
    return _DropTokens.apply(input_, dim)

class _DropTokens(torch.autograd.Function):
    "Divide tokens equally among the tensor parallel ranks"

    @staticmethod
    def symbolic(graph, input_, dim):
        return _drop_tokens(input_, dim)

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        return _drop_tokens(input_, dim)

    @staticmethod
    def backward(ctx, input_):
        return _gather_tokens(input_, ctx.dim), None
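drop_tokens and gather_tokens are exact inverses along the chosen dimension. A minimal single-process sketch of the underlying tensor math (using plain torch.narrow and torch.cat instead of the distributed helpers above, with a hypothetical tensor-parallel world size of 2):

import torch

world_size, dim = 2, 1
x = torch.randn(4, 6, 8)
chunk = x.shape[dim] // world_size
# what _drop_tokens keeps on each rank
shards = [torch.narrow(x, dim, r * chunk, chunk) for r in range(world_size)]
# what _gather_tokens reconstructs after the all-gather
assert torch.equal(torch.cat(shards, dim=dim), x)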
shared_moe
uniform_map: Dict[torch.device, Callable] = {}
gumbel_map: Dict[torch.device, Callable] = {}
exp_selection_uniform_map: Dict[torch.device, Callable] = {}
multiplicative_jitter
def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
    """
    Modified from the Switch Transformer paper and the Mesh TensorFlow transformer implementation.
    Multiply values by a random number between 1-epsilon and 1+epsilon.
    Makes models more resilient to rounding errors introduced by bfloat16.
    This seems particularly important for logits.
    Args:
        x: a torch.tensor
        device: torch.device
        epsilon: a floating point value
    Returns:
        a jittered x.
    """
    if epsilon == 0:
        return x
    uniform = uniform_map.get(device)
    if uniform is None:
        uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - epsilon, device=device),
                                                      high=torch.tensor(1.0 + epsilon, device=device)).rsample  # type: ignore
        uniform_map[device] = uniform
    return x * uniform(x.shape)
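A minimal sketch of the effect (assuming only torch and the multiplicative_jitter function plus uniform_map above): every element is scaled by an independent factor drawn from [1-epsilon, 1+epsilon].

import torch

x = torch.rand(4, 8) + 1.0                         # strictly positive, so the ratio is well defined
y = multiplicative_jitter(x, device=x.device, epsilon=1e-2)
print((y / x).min().item(), (y / x).max().item())  # both within roughly 1% of 1.0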
gumbel_rsample
def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
    gumbel = gumbel_map.get(device)
    if gumbel is None:
        one = torch.tensor(1.0, device=device)
        zero = torch.tensor(0.0, device=device)
        gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample  # type: ignore
        gumbel_map[device] = gumbel
    return gumbel(shape)
_AllToAll
# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):

    @staticmethod
    def forward(
            ctx: Any,
            # TODO: replace with DS process group
            group: torch.distributed.ProcessGroup,
            input: Tensor) -> Tensor:  # type: ignore
        ctx.group = group
        input = input.contiguous()
        output = torch.empty_like(input)
        dist.all_to_all_single(output, input, group=group)
        return output

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
        return (None, _AllToAll.apply(ctx.group, *grad_output))
Each process splits its input tensor and scatters the pieces to every process in the group, then concatenates the tensors it receives from all processes into a single output tensor:
>>> input = torch.arange(4) + rank * 4
>>> input
tensor([0, 1, 2, 3]) # Rank 0
tensor([4, 5, 6, 7]) # Rank 1
tensor([8, 9, 10, 11]) # Rank 2
tensor([12, 13, 14, 15]) # Rank 3
>>> output = torch.empty([4], dtype=torch.int64)
>>> dist.all_to_all_single(output, input)
>>> output
tensor([0, 4, 8, 12]) # Rank 0
tensor([1, 5, 9, 13]) # Rank 1
tensor([2, 6, 10, 14]) # Rank 2
tensor([3, 7, 11, 15]) # Rank 3
einsum
# einsum rewrites are on par or more performant
# switch can be bubbled up in future
USE_EINSUM = True
# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.
def einsum(rule, a, b):
    if USE_EINSUM:
        return torch.einsum(rule, a, b)
    elif rule == 's,se->se':
        return a.reshape(a.shape[0], -1) * b
    elif rule == 'se,sc->sec':
        return a.unsqueeze(2) * b.unsqueeze(1)
    elif rule == 'se,se->s':
        return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
    elif rule == 'sec,sm->ecm':
        s = a.shape[0]
        e = a.shape[1]
        c = a.shape[2]
        m = b.shape[1]
        return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
    elif rule == 'sec,ecm->sm':
        return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
    elif rule == 'ks,ksm->sm':
        k = b.shape[0]
        s = b.shape[1]
        m = b.shape[2]
        # [k, s] -> [s, k] -> [s, 1, k]
        a = a.t().unsqueeze(1)
        # [k, s, m] -> [k, sm] -> [sm, k] -> [s, m, k]
        b = b.reshape(k, -1).t().reshape(s, m, k)
        # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
        return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
    else:
        return torch.einsum(rule, a, b)
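As a quick sanity check (a minimal sketch assuming only torch), the manual rewrite for the 'se,sc->sec' rule produces the same tensor as torch.einsum, which is what the USE_EINSUM=True branch calls:

import torch

a = torch.randn(5, 3)   # [s, e] gate values
b = torch.randn(5, 7)   # [s, c] capacity one-hots
manual = a.unsqueeze(2) * b.unsqueeze(1)        # the rewrite above
reference = torch.einsum('se,sc->sec', a, b)    # the default path
assert torch.allclose(manual, reference)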
top1gating
@torch.jit.script
def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
    # gates has shape of SE
    num_tokens = gates.shape[0]
    num_experts = gates.shape[1]
    # to(torch.int64) works around a bug in torch.onnx.export:
    # it should cast k to int64 when converting torch.topk but it doesn't.
    capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
    if capacity < min_capacity:
        capacity = min_capacity.to(torch.int64)
    return capacity
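Concretely (a minimal sketch assuming only torch and the _capacity function above): with 8 tokens, 4 experts, and capacity_factor 1.0 the raw capacity is ceil(8 / 4 * 1.0) = 2, which is then raised to min_capacity.

import torch

gates = torch.randn(8, 4).softmax(dim=1)  # 8 tokens, 4 experts
cap = _capacity(gates, torch.tensor(1.0), torch.tensor(4))
assert cap.item() == 4  # ceil(8/4 * 1.0) = 2, bumped up to min_capacity = 4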
@torch.jit.script
def _top_idx(source, k):
    return torch.topk(source, k=k, dim=0)[1]
def top1gating(logits: Tensor,
               capacity_factor: float,
               min_capacity: int,
               used_token: Tensor = None,
               noisy_gate_policy: Optional[str] = None,
               drop_tokens: bool = True,
               use_rts: bool = True,
               use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top1Gating on logits."""
    if noisy_gate_policy == 'RSample':
        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))

    # Create a mask for 1st's expert per token
    # noisy gating
    indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # mask only used tokens
    if used_token is not None:
        mask1 = einsum("s,se->se", used_token, mask1)

    # gating decisions
    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

    # if we don't want to drop any tokens
    if not drop_tokens:
        new_capacity = torch.max(exp_counts).to(logits.device)
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
        capacity = new_capacity

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.sum(me * ce) * num_experts

    # Random Token Selection
    if use_rts:
        uniform = exp_selection_uniform_map.get(logits.device)
        if uniform is None:
            uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=logits.device),
                                                          high=torch.tensor(1.0, device=logits.device)).rsample
            exp_selection_uniform_map[logits.device] = uniform

        mask1_rand = mask1 * uniform(mask1.shape)
    else:
        mask1_rand = mask1

    assert logits.shape[0] >= min_capacity, \
        "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."

    top_idx = _top_idx(mask1_rand, capacity)

    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
    mask1 = new_mask1

    if use_tutel:
        # Tutel doesn't support index values masked with zero
        # so we need to replace masked indices with -1
        indices_mask = mask1.sum(dim=1) * num_experts - 1
        indices1_s = torch.min(indices1_s, indices_mask)

    # Compute locations in capacity buffer
    if use_tutel:
        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
    else:
        locations1 = torch.cumsum(mask1, dim=0) - 1

    if use_tutel:
        gates1_s = (gates * mask1).sum(dim=1)
        locations1_s = torch.sum(locations1 * mask1, dim=1)
        return l_aux, capacity, num_experts, [
            indices1_s,
        ], [
            locations1_s,
        ], [
            gates1_s,
        ], exp_counts

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)

    # Normalize gate probabilities
    mask1_float = mask1.float()
    gates = gates * mask1_float

    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    combine_weights = einsum("se,sc->sec", gates, locations1_sc)

    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts
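A minimal single-process sketch of calling it (assuming the module's other helpers, e.g. _one_hot_to_float, are in scope, and leaving drop_tokens at its default so no collective communication is needed):

import torch

logits = torch.randn(16, 4)  # 16 tokens, 4 experts
l_aux, combine_weights, dispatch_mask, exp_counts = top1gating(
    logits, capacity_factor=1.0, min_capacity=4)
print(combine_weights.shape)  # [16, 4, capacity] with capacity = max(ceil(16/4), 4) = 4
print(exp_counts)             # per-expert token counts; sums to 16 since every token picks one expert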
top2gating
def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits."""
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))

    # Create a mask for 1st's expert per token
    indices1_s = torch.argmax(gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # Create a mask for 2nd's expert per token using Gumbel-max trick
    # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
    logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    # Replace top-expert with min value
    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
    indices2_s = torch.argmax(logits_except1, dim=1)
    mask2 = F.one_hot(indices2_s, num_classes=num_experts)

    # Compute locations in capacity buffer
    locations1 = torch.cumsum(mask1, dim=0) - 1
    locations2 = torch.cumsum(mask2, dim=0) - 1
    # Update 2nd's location by accounting for locations of 1st
    locations2 += torch.sum(mask1, dim=0, keepdim=True)

    # gating decisions
    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.mean(me * ce) * num_experts * num_experts

    # Remove locations outside capacity from mask
    mask1 *= torch.lt(locations1, capacity)
    mask2 *= torch.lt(locations2, capacity)

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)
    locations2_s = torch.sum(locations2 * mask2, dim=1)

    # Normalize gate probabilities
    mask1_float = mask1.float()
    mask2_float = mask2.float()
    gates1_s = einsum("se,se->s", gates, mask1_float)
    gates2_s = einsum("se,se->s", gates, mask2_float)
    denom_s = gates1_s + gates2_s
    # Avoid divide-by-zero
    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
    gates1_s /= denom_s
    gates2_s /= denom_s

    # Calculate combine_weights and dispatch_mask
    gates1 = einsum("s,se->se", gates1_s, mask1_float)
    gates2 = einsum("s,se->se", gates2_s, mask2_float)
    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    locations2_sc = _one_hot_to_float(locations2_s, capacity)
    combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
    combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
    combine_weights = combine1_sec + combine2_sec
    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts
TopKGate
class TopKGate(Module):
    """Gate module which implements Top2Gating as described in Gshard_.
    ::

        gate = TopKGate(model_dim, num_experts)
        l_aux, combine_weights, dispatch_mask = gate(input)

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        model_dim (int):
            size of model embedding dimension
        num_experts (int):
            number of experts in model
    """
    wg: torch.nn.Linear

    def __init__(self,
                 model_dim: int,
                 num_experts: int,
                 k: int = 1,
                 capacity_factor: float = 1.0,
                 eval_capacity_factor: float = 1.0,
                 min_capacity: int = 8,
                 noisy_gate_policy: Optional[str] = None,
                 drop_tokens: bool = True,
                 use_rts: bool = True) -> None:
        super().__init__()

        # Only top-1 and top-2 are supported at the moment.
        if k != 1 and k != 2:
            raise ValueError('Only top-1 and top-2 gatings are supported.')
        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
        self.k = k
        self.capacity_factor = capacity_factor
        self.eval_capacity_factor = eval_capacity_factor
        self.min_capacity = min_capacity
        self.noisy_gate_policy = noisy_gate_policy
        self.timers = SynchronizedWallClockTimer()
        self.wall_clock_breakdown = False
        self.gate_time = 0.0
        self.drop_tokens = drop_tokens
        self.use_rts = use_rts
    def forward(self,
                input: torch.Tensor,
                used_token: torch.Tensor = None,
                use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore

        if self.wall_clock_breakdown:
            self.timers('TopKGate').start()

        if self.wg.weight.dtype != torch.float32:
            self.wg = self.wg.float()
        input_fp32 = input.float()
        # input jittering
        if self.noisy_gate_policy == 'Jitter' and self.training:
            input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
        logits = self.wg(input_fp32)

        if self.k == 1:
            gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
                                     self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
                                     self.drop_tokens, self.use_rts, use_tutel)
        else:
            gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
                                     self.min_capacity)

        if self.wall_clock_breakdown:
            self.timers('TopKGate').stop()
            self.gate_time = self.timers('TopKGate').elapsed(reset=False)

        return gate_output
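A minimal single-process sketch of using the gate (assuming a working DeepSpeed install so that SynchronizedWallClockTimer and the gating helpers above resolve; the hidden size and expert count here are arbitrary):

import torch

gate = TopKGate(model_dim=16, num_experts=4, k=1, min_capacity=4)
tokens = torch.randn(32, 16)  # 32 tokens of width 16
l_aux, combine_weights, dispatch_mask, exp_counts = gate(tokens)
print(l_aux.item(), combine_weights.shape)  # combine_weights is [32, 4, capacity]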
MOELayer
class MOELayer(Base):
    """MOELayer module which implements MixtureOfExperts as described in Gshard_.
    ::

        gate = TopKGate(model_dim, num_experts)
        moe = MOELayer(gate, expert)
        output = moe(input)
        l_aux = moe.l_aux

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        gate (torch.nn.Module):
            gate network
        expert (torch.nn.Module):
            expert network
    """
    def __init__(self,
                 gate: Module,
                 experts: Module,
                 ep_group_name,
                 ep_size,
                 num_local_experts: int,
                 use_tutel: bool = False) -> None:
        super().__init__()

        self.gate = gate
        self.experts = experts
        self.ep_group = None
        self.ep_size = ep_size
        self.ep_group_name = ep_group_name
        self.num_local_experts = num_local_experts
        self.time_falltoall = 0.0
        self.time_salltoall = 0.0
        self.time_moe = 0.0
        self.timers = SynchronizedWallClockTimer()
        self.wall_clock_breakdown = False

        self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1

        if self.use_tutel:
            logger.info('Using Tutel optimizations.')
        elif use_tutel and not TUTEL_INSTALLED:
            logger.warning("Tutel optimization requested but not installed. "
                           "Proceeding without Tutel.")
        elif use_tutel and TUTEL_INSTALLED and gate.k != 1:
            logger.warning("To enable Tutel optimization, use top-1 instead of top-2 gate. "
                           "Proceeding without Tutel.")

    def _set_ep_group(self, ep_group):
        self.ep_group = ep_group
    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:

        if self.wall_clock_breakdown:
            self.timers(MOE_TIMER).start()

        # Implement Algorithm 2 from GShard paper.
        d_model = input[0].shape[-1]

        # Initial implementation -> Reshape into S tokens by dropping sequence dimension.
        # Reshape into G groups so that each group can distribute tokens equally
        # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
        reshaped_input = input[0].reshape(-1, d_model)

        if self.use_tutel:
            self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(reshaped_input, input[1], True)
            S, M = reshaped_input.size(0), reshaped_input.size(1)

            if not hasattr(self, '_tutel_dispatcher'):
                self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
            self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
            dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
        else:
            self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
            # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
            dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)

        if self.wall_clock_breakdown:
            self.timers(FIRST_ALLTOALL_TIMER).start()

        if groups._get_expert_model_parallel_world_size() == 1:
            # If the non-expert is tensor-parallel, it will create
            # duplicate tokens on the tensor-parallel ranks.
            # Since our experts are not tensor-parallel, these duplicates
            # need to be dropped to ensure correctness.
            # this also doubles up as a communication optimization as we are
            # reducing the all-to-all communication volume.
            dispatched_input = drop_tokens(dispatched_input, dim=1)

        dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)

        if self.wall_clock_breakdown:
            self.timers(FIRST_ALLTOALL_TIMER).stop()
            self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)

        # Re-shape after all-to-all: ecm -> gecm
        dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)

        expert_output = self.experts(dispatched_input)

        if self.wall_clock_breakdown:
            self.timers(SECOND_ALLTOALL_TIMER).start()

        expert_output = _AllToAll.apply(self.ep_group, expert_output)

        if self.wall_clock_breakdown:
            self.timers(SECOND_ALLTOALL_TIMER).stop()
            self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)

        # Re-shape back: gecm -> ecm
        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)

        if groups._get_expert_model_parallel_world_size() == 1:
            # the dropped duplicate tokens need to be gathered on each
            # tensor parallel rank again for the tensor-parallel
            # non-expert of the next layer.
            expert_output = gather_tokens(expert_output, dim=1)

        if self.use_tutel:
            combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
        else:
            combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)

        a = combined_output.reshape(input[0].shape)

        if self.wall_clock_breakdown:
            self.timers(MOE_TIMER).stop()
            self.time_moe = self.timers(MOE_TIMER).elapsed(reset=False)

        return a
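To see the data flow without any distributed machinery, here is a minimal single-process sketch of just the dispatch and combine einsums (the two all-to-alls and the token drop/gather are omitted; shapes follow the comment convention s=tokens, e=experts, c=capacity, m=model dim, and the gate outputs are replaced by random stand-ins):

import torch

s, e, c, m = 8, 4, 2, 16
reshaped_input = torch.randn(s, m)
dispatch_mask = torch.rand(s, e, c) > 0.5   # stand-in for the gate's dispatch_mask
combine_weights = torch.rand(s, e, c)       # stand-in for the gate's combine_weights

dispatched = torch.einsum("sec,sm->ecm", dispatch_mask.float(), reshaped_input)
assert dispatched.shape == (e, c, m)        # one capacity buffer per expert

expert_output = dispatched                  # pretend the experts are the identity
combined = torch.einsum("sec,ecm->sm", combine_weights, expert_output)
assert combined.shape == (s, m)             # back to one row per token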
MoE
class MoE(torch.nn.Module):
    """Initialize an MoE layer.

    Arguments:
        hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension.
        expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear).
        num_experts (int, optional): default=1, the total number of experts per layer.
        ep_size (int, optional): default=1, number of ranks in the expert parallel world or group.
        k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
        capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
        eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
        min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
        use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
        noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
        drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
        use_rts (bool, optional): default=True, whether to use Random Token Selection.
        use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
        enable_expert_tensor_parallelism (bool, optional): default=False, whether to use tensor parallelism for experts.
    """
    def __init__(self,
                 hidden_size,
                 expert,
                 num_experts=1,
                 ep_size=1,
                 k=1,
                 capacity_factor=1.,
                 eval_capacity_factor=1.,
                 min_capacity=4,
                 use_residual=False,
                 noisy_gate_policy: typing.Optional[str] = None,
                 drop_tokens: bool = True,
                 use_rts=True,
                 use_tutel: bool = False,
                 enable_expert_tensor_parallelism: bool = False):

        super(MoE, self).__init__()

        self.use_residual = use_residual
        self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism
        assert num_experts % ep_size == 0, f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"
        self.ep_size = ep_size
        self.expert_group_name = f"ep_size_{self.ep_size}"
        self.num_experts = num_experts
        self.num_local_experts = num_experts // self.ep_size

        log_dist(
            f'Creating MoE layer with num_experts: {num_experts} | num_local_experts: {self.num_local_experts} | expert_parallel_size: {self.ep_size}',
            [0])

        assert noisy_gate_policy is None or noisy_gate_policy in ['None', 'Jitter', 'RSample'], \
            'Unsupported noisy_gate_policy: ' + noisy_gate_policy

        experts = Experts(expert, self.num_local_experts, self.expert_group_name)
        self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
                                               min_capacity, noisy_gate_policy, drop_tokens, use_rts),
                                      experts,
                                      self.expert_group_name,
                                      self.ep_size,
                                      self.num_local_experts,
                                      use_tutel=use_tutel)
        if self.use_residual:
            self.mlp = expert
            # coefficient is used for weighted sum of the output of expert and mlp
            self.coefficient = torch.nn.Linear(hidden_size, 2)
    def set_deepspeed_parallelism(self):
        self._create_process_groups()

    def _create_process_groups(self):
        # Create process group for a layer if needed
        if self.expert_group_name not in groups._get_expert_parallel_group_dict():
            print(f"No existing process group found, creating a new group named: {self.expert_group_name}")
            if (groups.mpu is None) or (not self.enable_expert_tensor_parallelism):
                # Condition 1 - no groups.mpu means no tensor parallelism
                # Condition 2 - disabling expert tensor parallelism on purpose
                groups._create_expert_and_data_parallel(self.ep_size)
            else:
                # expert tensor parallelism is enabled
                groups._create_expert_data_and_model_parallel(self.ep_size, mpu=groups.mpu)
        # Set the group handle for the MOELayer (deepspeed_moe) object
        self.deepspeed_moe._set_ep_group(groups._get_expert_parallel_group(self.expert_group_name))
    def forward(self, hidden_states, used_token=None):
        """ MoE forward

        Arguments:
            hidden_states (Tensor): input to the layer
            used_token (Tensor, optional): default: None, mask only used tokens

        Returns:
            A tuple including output, gate loss, and expert count.

            * output (Tensor): output of the model
            * l_aux (Tensor): gate loss value
            * exp_counts (Tensor): per-expert token counts
        """
        output = self.deepspeed_moe(hidden_states, used_token)
        if self.use_residual:
            # Residual MoE
            output_mlp = self.mlp(hidden_states)
            if type(output_mlp) is tuple:
                output_mlp = output_mlp[0]  # Ignore the bias term for now
            coef = self.coefficient(hidden_states)
            coef = torch.nn.functional.softmax(coef, dim=-1)
            output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
        return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts
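Finally, a minimal end-to-end sketch (assumptions: DeepSpeed is installed, the script is launched so that torch.distributed can initialize, and ep_size=1 so a single process forms its own expert-parallel group; the expert definition here is an arbitrary MLP):

import torch
import deepspeed
from deepspeed.moe.layer import MoE

deepspeed.init_distributed()        # a single process is enough for ep_size=1

hidden = 512
expert = torch.nn.Sequential(torch.nn.Linear(hidden, 4 * hidden),
                             torch.nn.ReLU(),
                             torch.nn.Linear(4 * hidden, hidden))
moe = MoE(hidden_size=hidden, expert=expert, num_experts=4, ep_size=1, k=1)
moe.set_deepspeed_parallelism()     # creates the expert-parallel process group

x = torch.randn(2, 16, hidden)      # [batch, seq, hidden]
out, l_aux, exp_counts = moe(x)
assert out.shape == x.shape         # MoE preserves the hidden dimension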