DeepSpeed Mixture-of-Experts

experts

  
class Experts(torch.nn.Module):  
  
    def __init__(self, expert, num_local_experts=1, expert_group_name=None):
        super(Experts, self).__init__()  
  
        self.deepspeed_experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])  
        self.num_local_experts = num_local_experts  
  
        # TODO: revisit allreduce for moe.gate...  
        for expert in self.deepspeed_experts:  
            # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
            for name, param in expert.named_parameters():  
                param.allreduce = False  
                param.group_name = expert_group_name  
  
    def forward(self, inputs):  
        chunks = inputs.chunk(self.num_local_experts, dim=1)  
        expert_outputs = []  
        for chunk, expert in zip(chunks, self.deepspeed_experts):  
            out = expert(chunk)  
            if type(out) is tuple:  
                out = out[0]  # Ignore the bias term for now  
            expert_outputs += [out]  
  
        expert_output = torch.cat(expert_outputs, dim=1)  
        return expert_output  
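
Each rank holds num_local_experts replicas of the expert module; the forward pass assumes the local experts are laid out along dim=1, chunks the input into one slice per local expert, runs each slice through its own expert, and concatenates the results. A minimal standalone sketch of that chunk-and-concat pattern (plain PyTorch, toy Linear expert, shapes are illustrative assumptions):

import copy
import torch

expert = torch.nn.Linear(8, 8)  # stand-in for whatever expert module the user supplies
num_local_experts = 2
local_experts = torch.nn.ModuleList([copy.deepcopy(expert) for _ in range(num_local_experts)])

# assumed layout (ep_size, num_local_experts, capacity, d_model), matching the reshape done in MOELayer
inputs = torch.randn(1, num_local_experts, 4, 8)

chunks = inputs.chunk(num_local_experts, dim=1)          # one (1, 1, 4, 8) slice per local expert
outputs = [e(c) for c, e in zip(chunks, local_experts)]  # nn.Linear acts on the last dimension
out = torch.cat(outputs, dim=1)                          # back to (1, num_local_experts, 4, 8)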

gather_tokens

All-gather tokens across the tensor model parallel (MP) ranks.

  
def gather_tokens(input_, dim=0):
    mpu = deepspeed.utils.groups.mpu  
    if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1:  
        # no tensor parallelism for non-experts  
        return input_  
    return _GatherTokens.apply(input_, dim)  

  
class _GatherTokens(torch.autograd.Function):
    """All gather tokens among the tensor parallel ranks"""  
  
    @staticmethod  
    def symbolic(graph, input_, dim):
        return _gather_tokens(input_, dim)  
  
    @staticmethod  
    def forward(ctx, input_, dim):
        ctx.dim = dim  
        return _gather_tokens(input_, dim)  
  
    @staticmethod  
    def backward(ctx, grad_output):
        return _drop_tokens(grad_output, ctx.dim), None  

  
def _gather_tokens(input_, dim=0):
    """Gather tensors and concatenate them along a dimension"""  
    mpu = deepspeed.utils.groups.mpu  
  
    input_ = input_.contiguous()  
    # Size and dimension.  
    rank = mpu.get_tensor_model_parallel_rank()  
  
    tensor_list = [torch.empty_like(input_) for _ in range(mpu.get_tensor_model_parallel_world_size())]  
    tensor_list[rank] = input_  
    deepspeed.comm.all_gather(tensor_list, input_, group=mpu.get_tensor_model_parallel_group())  
  
    # Note: torch.cat already creates a contiguous tensor.  
    output = torch.cat(tensor_list, dim=dim).contiguous()  
  
    return output  

  
def _drop_tokens(input_, dim=0):
    """Divide a tensor among the tensor parallel ranks"""  
    mpu = deepspeed.utils.groups.mpu  
  
    total_chunks = mpu.get_tensor_model_parallel_world_size()  
    this_chunk = mpu.get_tensor_model_parallel_rank()  
    assert input_.shape[
        dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
    chunk_size = input_.shape[dim] // total_chunks  
  
    return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)  

drop_tokens

Partition tokens among the tensor model parallel (MP) ranks.

  
def drop_tokens(input_, dim=0):
    mpu = deepspeed.utils.groups.mpu  
    if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1:  
        # no tensor parallelism for non-experts  
        return input_  
    return _DropTokens.apply(input_, dim)  

  
class _DropTokens(torch.autograd.Function):
    "Divide tokens equally among the tensor parallel ranks"  
  
    @staticmethod  
    def symbolic(graph, input_, dim):
        return _drop_tokens(input_, dim)  
  
    @staticmethod  
    def forward(ctx, input_, dim):
        ctx.dim = dim  
        return _drop_tokens(input_, dim)  
  
    @staticmethod  
    def backward(ctx, input_):
        return _gather_tokens(input_, ctx.dim), None  
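
_drop_tokens keeps only this rank's slice of the chosen dimension via torch.narrow, and _gather_tokens undoes it with an all-gather followed by torch.cat, so the two are inverses along that dimension. A single-process sketch of the indexing logic (no real communication; the tensor parallel world size of 2 is an assumption for illustration):

import torch

tp_world_size = 2                     # assumed tensor parallel world size
x = torch.arange(8).reshape(4, 2)     # full tensor; dim 0 is divisible by tp_world_size

# the slice each rank would keep after _drop_tokens
chunk = x.shape[0] // tp_world_size
slices = [torch.narrow(x, 0, r * chunk, chunk) for r in range(tp_world_size)]

# what _gather_tokens reconstructs after the all-gather + concat
gathered = torch.cat(slices, dim=0)
assert torch.equal(gathered, x)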

sharded_moe

  
uniform_map: Dict[torch.device, Callable] = {}  
gumbel_map: Dict[torch.device, Callable] = {}  
exp_selection_uniform_map: Dict[torch.device, Callable] = {}  

multiplicative_jitter

  
def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
    """  
    Modified from switch transformer paper. mesh transformers  
    Multiply values by a random number between 1-epsilon and 1+epsilon.  
    Makes models more resilient to rounding errors introduced by bfloat16.  
    This seems particularly important for logits.  
    Args:  
        x: a torch.tensor  
        device: torch.device  
        epsilon: a floating point value  
    Returns:  
        a jittered x.  
    """  
    if epsilon == 0:  
        return x  
    uniform = uniform_map.get(device)  
    if uniform is None:  
        uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - epsilon, device=device),  
                                                      high=torch.tensor(1.0 + epsilon,  
                                                                        device=device)).rsample  # type: ignore  
        uniform_map[device] = uniform  
    return x * uniform(x.shape)  
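
A quick standalone illustration of the jitter (plain PyTorch, epsilon picked arbitrarily): every element is scaled by an independent factor drawn from U(1 - epsilon, 1 + epsilon), so the output stays within epsilon of the input in relative terms.

import torch

epsilon = 1e-2
x = torch.ones(4)
uniform = torch.distributions.uniform.Uniform(1.0 - epsilon, 1.0 + epsilon).rsample
jittered = x * uniform(x.shape)
assert ((jittered - x).abs() <= epsilon).all()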

gumbel_rsample

  
def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
    gumbel = gumbel_map.get(device)  
    if gumbel is None:  
        one = torch.tensor(1.0, device=device)  
        zero = torch.tensor(0.0, device=device)  
        gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample  # type: ignore  
        gumbel_map[device] = gumbel  
    return gumbel(shape)  

_AllToAll

  
# Based on https://github.com/pytorch/pytorch/pull/40762  
class _AllToAll(torch.autograd.Function):
  
    @staticmethod  
    def forward(  
            ctx: Any,  
            # TODO: replace with DS process group  
            group: torch.distributed.ProcessGroup,  
            input: Tensor) -> Tensor:  # type: ignore  
        ctx.group = group  
        input = input.contiguous()  
        output = torch.empty_like(input)  
        dist.all_to_all_single(output, input, group=group)  
        return output  
  
    @staticmethod  
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
        return (None, _AllToAll.apply(ctx.group, *grad_output))  

Each process splits its input tensor and scatters the pieces to all processes in the group; it then concatenates the tensors received from every process in the group into a single output tensor.

  
>>> input = torch.arange(4) + rank * 4  
>>> input  
tensor([0, 1, 2, 3])     # Rank 0  
tensor([4, 5, 6, 7])     # Rank 1  
tensor([8, 9, 10, 11])   # Rank 2  
tensor([12, 13, 14, 15]) # Rank 3  
>>> output = torch.empty([4], dtype=torch.int64)  
>>> dist.all_to_all_single(output, input)  
>>> output  
tensor([0, 4, 8, 12])    # Rank 0  
tensor([1, 5, 9, 13])    # Rank 1  
tensor([2, 6, 10, 14])   # Rank 2  
tensor([3, 7, 11, 15])   # Rank 3  

einsum

  
# einsum rewrites are on par or more performant  
# switch can be bubbled up in future  
USE_EINSUM = True  
  
# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity  
# See https://arxiv.org/pdf/2006.16668.pdf for details.  
def einsum(rule, a, b):  
    if USE_EINSUM:  
        return torch.einsum(rule, a, b)  
    elif rule == 's,se->se':  
        return a.reshape(a.shape[0], -1) * b  
    elif rule == 'se,sc->sec':  
        return a.unsqueeze(2) * b.unsqueeze(1)  
    elif rule == 'se,se->s':  
        return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)  
    elif rule == 'sec,sm->ecm':  
        s = a.shape[0]  
        e = a.shape[1]  
        c = a.shape[2]  
        m = b.shape[1]  
        return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)  
    elif rule == 'sec,ecm->sm':  
        return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))  
    elif rule == 'ks,ksm->sm':  
        k = b.shape[0]  
        s = b.shape[1]  
        m = b.shape[2]  
        # [k, s] -> [s, k] -> [s, 1, k]  
        a = a.t().unsqueeze(1)  
        # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]  
        b = b.reshape(k, -1).t().reshape(s, m, k)  
        # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]  
        return torch.bmm(a, b.transpose(1, 2)).squeeze(2)  
    else:  
        return torch.einsum(rule, a, b)  
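
The handwritten branches are intended as drop-in equivalents of the corresponding torch.einsum rules. A small sanity check for the 'sec,sm->ecm' dispatch rule (shapes are arbitrary illustrative values):

import torch

s, e, c, m = 6, 3, 2, 5
a = torch.randn(s, e, c)   # per-(token, expert, capacity-slot) weights
b = torch.randn(s, m)      # token representations

ref = torch.einsum('sec,sm->ecm', a, b)
alt = torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
assert torch.allclose(ref, alt, atol=1e-6)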

top1gating

  
@torch.jit.script  
def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
    # gates has shape of SE  
    num_tokens = gates.shape[0]  
    num_experts = gates.shape[1]  
    # to(torch.int64) works around a bug in torch.onnx.export:  
    # it should cast k to int64 when converting torch.topk but it doesn't.  
    capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)  
    if capacity < min_capacity:  
        capacity = min_capacity.to(torch.int64)  
    return capacity  
  
  
@torch.jit.script  
def _top_idx(source, k):
    return torch.topk(source, k=k, dim=0)[1]  
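
Capacity is simply ceil(tokens_per_expert * capacity_factor), floored at min_capacity. For example, 16 tokens routed to 4 experts with capacity_factor=1.0 gives ceil(16/4 * 1.0) = 4 slots per expert, which min_capacity=8 then lifts to 8. A plain-Python sketch of the same arithmetic:

import math

num_tokens, num_experts = 16, 4
capacity_factor, min_capacity = 1.0, 8

capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
capacity = max(capacity, min_capacity)
print(capacity)  # 8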

  
def top1gating(logits: Tensor,  
               capacity_factor: float,
               min_capacity: int,
               used_token: Tensor = None,
               noisy_gate_policy: Optional[str] = None,
               drop_tokens: bool = True,
               use_rts: bool = True,
               use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top1Gating on logits."""  
    if noisy_gate_policy == 'RSample':  
        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)  
    # everything is in fp32 in this function  
    gates = F.softmax(logits, dim=1)  
  
    capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))  
  
    # Create a mask for 1st's expert per token  
    # noisy gating  
    indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1)  
    num_experts = int(gates.shape[1])  
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)  
  
    # mask only used tokens  
    if used_token is not None:  
        mask1 = einsum("s,se->se", used_token, mask1)  
  
    # gating decisions  
    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')  
  
    # if we don't want to drop any tokens  
    if not drop_tokens:  
        new_capacity = torch.max(exp_counts).to(logits.device)  
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())  
        capacity = new_capacity  
  
    # Compute l_aux
    me = torch.mean(gates, dim=0)  
    ce = torch.mean(mask1.float(), dim=0)  
    l_aux = torch.sum(me * ce) * num_experts  
  
    # Random Token Selection  
    if use_rts:  
        uniform = exp_selection_uniform_map.get(logits.device)  
        if uniform is None:  
            uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=logits.device),  
                                                          high=torch.tensor(1.0, device=logits.device)).rsample  
            exp_selection_uniform_map[logits.device] = uniform  
  
        mask1_rand = mask1 * uniform(mask1.shape)  
    else:  
        mask1_rand = mask1  
  
    assert logits.shape[
        0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."
  
    top_idx = _top_idx(mask1_rand, capacity)  
  
    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)  
    mask1 = new_mask1  
  
    if use_tutel:  
        # Tutel doesn't support index values masked with zero  
        # so we need to replace masked indices with -1  
        indices_mask = mask1.sum(dim=1) * num_experts - 1  
        indices1_s = torch.min(indices1_s, indices_mask)  
  
    # Compute locations in capacity buffer  
    if use_tutel:  
        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)  
    else:  
        locations1 = torch.cumsum(mask1, dim=0) - 1  
  
    if use_tutel:  
        gates1_s = (gates * mask1).sum(dim=1)  
        locations1_s = torch.sum(locations1 * mask1, dim=1)  
        return l_aux, capacity, num_experts, [  
            indices1_s,  
        ], [  
            locations1_s,  
        ], [  
            gates1_s,  
        ], exp_counts  
  
    # Store the capacity location for each token  
    locations1_s = torch.sum(locations1 * mask1, dim=1)  
  
    # Normalize gate probabilities  
    mask1_float = mask1.float()  
    gates = gates * mask1_float  
  
    locations1_sc = _one_hot_to_float(locations1_s, capacity)  
    combine_weights = einsum("se,sc->sec", gates, locations1_sc)  
  
    dispatch_mask = combine_weights.bool()  
  
    return l_aux, combine_weights, dispatch_mask, exp_counts  
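
To see the routing shapes without the surrounding DeepSpeed machinery, here is a stripped-down single-process replication of the core top-1 steps (softmax gate, argmax expert choice, capacity truncation, auxiliary loss). The noisy-gate, Random Token Selection and tutel branches are omitted, and the cumsum-based truncation below is a simplification of the _top_idx/scatter path:

import torch
import torch.nn.functional as F

num_tokens, num_experts, capacity = 8, 4, 2
logits = torch.randn(num_tokens, num_experts)

gates = F.softmax(logits, dim=1)                       # [s, e]
indices1_s = torch.argmax(gates, dim=1)                # chosen expert per token
mask1 = F.one_hot(indices1_s, num_classes=num_experts)

# keep at most `capacity` tokens per expert
locations1 = torch.cumsum(mask1, dim=0) - 1
mask1 = mask1 * torch.lt(locations1, capacity)

# load-balancing auxiliary loss, as in the code above
l_aux = torch.sum(gates.mean(dim=0) * mask1.float().mean(dim=0)) * num_experts
print(mask1.sum(dim=0), l_aux)                         # tokens kept per expert, aux loss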

top2gating

  
def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits."""  
    # everything is in fp32 in this function  
    gates = F.softmax(logits, dim=1)  
  
    capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))  
  
    # Create a mask for 1st's expert per token  
    indices1_s = torch.argmax(gates, dim=1)  
    num_experts = int(gates.shape[1])  
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)  
  
    # Create a mask for 2nd's expert per token using Gumbel-max trick  
    # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/  
    logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)  
    # Replace top-expert with min value  
    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))  
    indices2_s = torch.argmax(logits_except1, dim=1)  
    mask2 = F.one_hot(indices2_s, num_classes=num_experts)  
  
    # Compute locations in capacity buffer  
    locations1 = torch.cumsum(mask1, dim=0) - 1  
    locations2 = torch.cumsum(mask2, dim=0) - 1  
    # Update 2nd's location by accounting for locations of 1st  
    locations2 += torch.sum(mask1, dim=0, keepdim=True)  
  
    # gating decisions  
    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')  
  
    # Compute l_aux
    me = torch.mean(gates, dim=0)  
    ce = torch.mean(mask1.float(), dim=0)  
    l_aux = torch.mean(me * ce) * num_experts * num_experts  
  
    # Remove locations outside capacity from mask  
    mask1 *= torch.lt(locations1, capacity)  
    mask2 *= torch.lt(locations2, capacity)  
  
    # Store the capacity location for each token  
    locations1_s = torch.sum(locations1 * mask1, dim=1)  
    locations2_s = torch.sum(locations2 * mask2, dim=1)  
  
    # Normalize gate probabilities  
    mask1_float = mask1.float()  
    mask2_float = mask2.float()  
    gates1_s = einsum("se,se->s", gates, mask1_float)  
    gates2_s = einsum("se,se->s", gates, mask2_float)  
    denom_s = gates1_s + gates2_s  
    # Avoid divide-by-zero  
    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)  
    gates1_s /= denom_s  
    gates2_s /= denom_s  
  
    # Calculate combine_weights and dispatch_mask
    gates1 = einsum("s,se->se", gates1_s, mask1_float)  
    gates2 = einsum("s,se->se", gates2_s, mask2_float)  
    locations1_sc = _one_hot_to_float(locations1_s, capacity)  
    locations2_sc = _one_hot_to_float(locations2_s, capacity)  
    combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)  
    combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)  
    combine_weights = combine1_sec + combine2_sec  
    dispatch_mask = combine_weights.bool()  
  
    return l_aux, combine_weights, dispatch_mask, exp_counts  
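
After capacity filtering, the probabilities of each token's two selected experts are renormalized so they sum to one. A tiny numeric sketch of that step (gate values invented for illustration):

import torch

gates1_s = torch.tensor([0.50, 0.40])   # probability of each token's 1st expert
gates2_s = torch.tensor([0.30, 0.10])   # probability of each token's 2nd expert

denom_s = torch.clamp(gates1_s + gates2_s, min=torch.finfo(gates1_s.dtype).eps)
print(gates1_s / denom_s)  # tensor([0.6250, 0.8000])
print(gates2_s / denom_s)  # tensor([0.3750, 0.2000])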

TopKGate

  
class TopKGate(Module):  
    """Gate module which implements Top2Gating as described in Gshard\_.  
    ::  
  
        gate = TopKGate(model\_dim, num\_experts)  
        l\_aux, combine\_weights, dispatch\_mask = gate(input)  
  
    .. Gshard\_: https://arxiv.org/pdf/2006.16668.pdf  
  
    Args:  
        model\_dim (int):  
            size of model embedding dimension  
        num\_experts (ints):  
            number of experts in model  
    """  
  
    wg: torch.nn.Linear  
  
    def __init__(self,
                 model_dim: int,
                 num_experts: int,
                 k: int = 1,
                 capacity_factor: float = 1.0,
                 eval_capacity_factor: float = 1.0,
                 min_capacity: int = 8,
                 noisy_gate_policy: Optional[str] = None,
                 drop_tokens: bool = True,
                 use_rts: bool = True) -> None:
        super().__init__()  
  
        # Only top-1 and top-2 are supported at the moment.  
        if k != 1 and k != 2:  
            raise ValueError('Only top-1 and top-2 gatings are supported.')  
        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()  
        self.k = k  
        self.capacity_factor = capacity_factor  
        self.eval_capacity_factor = eval_capacity_factor  
        self.min_capacity = min_capacity  
        self.noisy_gate_policy = noisy_gate_policy  
        self.timers = SynchronizedWallClockTimer()  
        self.wall_clock_breakdown = False  
        self.gate_time = 0.0  
        self.drop_tokens = drop_tokens  
        self.use_rts = use_rts  
  
    def forward(self,  
                input: torch.Tensor,  
                used_token: torch.Tensor = None,
                use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
  
        if self.wall_clock_breakdown:  
            self.timers('TopKGate').start()  
  
        if self.wg.weight.dtype != torch.float32:  
            self.wg = self.wg.float()  
        input_fp32 = input.float()  
        # input jittering  
        if self.noisy_gate_policy == 'Jitter' and self.training:  
            input_fp32 = multiplicative_jitter(input_fp32, device=input.device)  
        logits = self.wg(input_fp32)  
  
        if self.k == 1:  
            gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,  
                                     self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,  
                                     self.drop_tokens, self.use_rts, use_tutel)  
  
        else:  
            gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,  
                                     self.min_capacity)  
  
        if self.wall_clock_breakdown:  
            self.timers('TopKGate').stop()  
            self.gate_time = self.timers('TopKGate').elapsed(reset=False)  
  
        return gate_output  
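
Assuming DeepSpeed is installed and TopKGate is imported from deepspeed.moe.sharded_moe (the module these snippets are taken from), a single-process call looks roughly like this; with the default drop_tokens=True no process group is required:

import torch
from deepspeed.moe.sharded_moe import TopKGate  # assumed import path

gate = TopKGate(model_dim=16, num_experts=4, k=1, min_capacity=1)
tokens = torch.randn(8, 16)                       # [num_tokens, model_dim]
l_aux, combine_weights, dispatch_mask, exp_counts = gate(tokens)
print(combine_weights.shape)                      # [8, 4, capacity]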

MOELayer


  
class MOELayer(Base):  
    """MOELayer module which implements MixtureOfExperts as described in Gshard\_.  
    ::  
  
        gate = TopKGate(model\_dim, num\_experts)  
        moe = MOELayer(gate, expert)  
        output = moe(input)  
        l\_aux = moe.l\_aux  
  
    .. Gshard\_: https://arxiv.org/pdf/2006.16668.pdf  
  
    Args:  
        gate (torch.nn.Module):  
            gate network  
        expert (torch.nn.Module):  
            expert network  
    """  
  
    def __init__(self,
                 gate: Module,
                 experts: Module,
                 ep_group_name,
                 ep_size,
                 num_local_experts: int,
                 use_tutel: bool = False) -> None:
        super().__init__()  
        self.gate = gate  
        self.experts = experts  
        self.ep_group = None  
        self.ep_size = ep_size  
        self.ep_group_name = ep_group_name  
        self.num_local_experts = num_local_experts  
        self.time_falltoall = 0.0  
        self.time_salltoall = 0.0  
        self.time_moe = 0.0  
        self.timers = SynchronizedWallClockTimer()  
        self.wall_clock_breakdown = False  
  
        self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1  
  
        if self.use_tutel:  
            logger.info('Using Tutel optimizations.')  
        elif use_tutel and not TUTEL_INSTALLED:  
            logger.warning("Tutel optimization requested but not installed. "  
                           "Proceeding without Tutel.")  
        elif use_tutel and TUTEL_INSTALLED and gate.k != 1:  
            logger.warning("To enable Tutel optimization, use top-1 instead of top-2 gate. "  
                           "Proceeding without Tutel.")  
  
    def _set_ep_group(self, ep_group):
        self.ep_group = ep_group  
  
    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:  
  
        if self.wall_clock_breakdown:  
            self.timers(MOE_TIMER).start()  
  
        # Implement Algorithm 2 from GShard paper.  
        d_model = input[0].shape[-1]  
  
        # Initial implementation -> Reshape into S tokens by dropping sequence dimension.  
        # Reshape into G groups so that each group can distribute tokens equally  
        # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
        reshaped_input = input[0].reshape(-1, d_model)  
  
        if self.use_tutel:  
            self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(reshaped_input, input[1], True)  
            S, M = reshaped_input.size(0), reshaped_input.size(1)  
  
            if not hasattr(self, '_tutel_dispatcher'):
                self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)  
            self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)  
            dispatched_input = self._tutel_dispatcher.encode(reshaped_input)  
        else:  
            self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])  
            # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity  
            dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)  
  
        if self.wall_clock_breakdown:  
            self.timers(FIRST_ALLTOALL_TIMER).start()  
  
        if groups._get_expert_model_parallel_world_size() == 1:  
            # If the non-expert is tensor-parallel, it will create  
            # duplicate tokens on the tensor-parallel ranks.  
            # Since our experts are not tensor-parallel, these duplicates  
            # need to be dropped to ensure correctness.  
            # this also doubles up as a communication optimization as we are  
            # reducing the all-to-all communication volume.  
            dispatched_input = drop_tokens(dispatched_input, dim=1)  
  
        dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)  
  
        if self.wall_clock_breakdown:  
            self.timers(FIRST_ALLTOALL_TIMER).stop()  
            self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)  
  
        # Re-shape after all-to-all: ecm -> gecm  
        dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)  
  
        expert_output = self.experts(dispatched_input)  
  
        if self.wall_clock_breakdown:  
            self.timers(SECOND_ALLTOALL_TIMER).start()  
  
        expert_output = _AllToAll.apply(self.ep_group, expert_output)  
  
        if self.wall_clock_breakdown:  
            self.timers(SECOND_ALLTOALL_TIMER).stop()  
            self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)  
  
        # Re-shape back: gecm -> ecm  
        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)  
  
        if groups._get_expert_model_parallel_world_size() == 1:  
            # the dropped duplicate tokens need to be gathered on each  
            # tensor parallel rank again for the tensor-parallel  
            # non-expert of the next layer.  
            expert_output = gather_tokens(expert_output, dim=1)  
  
        if self.use_tutel:  
            combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))  
        else:  
            combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)  
  
        a = combined_output.reshape(input[0].shape)  
  
        if self.wall_clock_breakdown:  
            self.timers(MOE_TIMER).stop()  
            self.time_moe = self.timers(MOE_TIMER).elapsed(reset=False)  
  
        return a  
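
The dispatch and combine around the experts are the two GShard einsums. A single-process shape trace with ep_size=1 (so the all-to-all would be an identity and no process group is needed; shapes are illustrative and the experts are replaced by an identity):

import torch

s, e, c, m = 8, 4, 2, 16                 # tokens, experts, capacity, d_model
reshaped_input = torch.randn(s, m)
combine_weights = torch.rand(s, e, c)    # stand-in for the gate output
dispatch_mask = combine_weights.bool()

# dispatch: route tokens into per-expert capacity buffers
dispatched = torch.einsum('sec,sm->ecm', dispatch_mask.float(), reshaped_input)  # [e, c, m]

# reshape before the experts: [ep_size, num_local_experts, c, m]
dispatched = dispatched.reshape(1, e, c, m)
expert_output = dispatched.reshape(e, c, m)       # experts treated as identity here

# combine: weighted sum of expert outputs back into token order
combined = torch.einsum('sec,ecm->sm', combine_weights, expert_output)           # [s, m]
print(combined.shape)                             # torch.Size([8, 16])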

MoE

  
class MoE(torch.nn.Module):  
    """Initialize an MoE layer.  
  
    Arguments:  
        hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension.
        expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear).
        num_experts (int, optional): default=1, the total number of experts per layer.
        ep_size (int, optional): default=1, number of ranks in the expert parallel world or group.
        k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
        capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
        eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
        min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
        use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.
        noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
        drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
        use_rts (bool, optional): default=True, whether to use Random Token Selection.
        use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
        enable_expert_tensor_parallelism (bool, optional): default=False, whether to use tensor parallelism for experts
    """  
  
    def __init__(self,
                 hidden_size,
                 expert,
                 num_experts=1,
                 ep_size=1,
                 k=1,
                 capacity_factor=1.,
                 eval_capacity_factor=1.,
                 min_capacity=4,
                 use_residual=False,
                 noisy_gate_policy: typing.Optional[str] = None,
                 drop_tokens: bool = True,
                 use_rts=True,
                 use_tutel: bool = False,
                 enable_expert_tensor_parallelism: bool = False):
  
        super(MoE, self).__init__()  
  
        self.use_residual = use_residual  
        self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism  
        assert num_experts % ep_size == 0, f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"
        self.ep_size = ep_size  
        self.expert_group_name = f"ep\_size\_{self.ep\_size}"  
        self.num_experts = num_experts  
        self.num_local_experts = num_experts // self.ep_size  
  
        log_dist(
            f'Creating MoE layer with num_experts: {num_experts} | num_local_experts: {self.num_local_experts} | expert_parallel_size: {self.ep_size}',
            [0])
  
        assert noisy_gate_policy is None or noisy_gate_policy in ['None', 'Jitter', 'RSample'], \
            'Unsupported noisy_gate_policy: ' + noisy_gate_policy
  
        experts = Experts(expert, self.num_local_experts, self.expert_group_name)  
        self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,  
                                               min_capacity, noisy_gate_policy, drop_tokens, use_rts),  
                                      experts,  
                                      self.expert_group_name,  
                                      self.ep_size,  
                                      self.num_local_experts,  
                                      use_tutel=use_tutel)  
        if self.use_residual:  
            self.mlp = expert  
            # coefficient is used for weighted sum of the output of expert and mlp  
            self.coefficient = torch.nn.Linear(hidden_size, 2)  
  
    def set_deepspeed_parallelism(self):
        self._create_process_groups()  
  
    def _create_process_groups(self):
        # Create process group for a layer if needed  
        if self.expert_group_name not in groups._get_expert_parallel_group_dict():  
            print(f"No existing process group found, creating a new group named: {self.expert\_group\_name}")  
            if (groups.mpu is None) or (not self.enable_expert_tensor_parallelism):  
                # Condition 1 - no groups.mpu means no tensor parallelism  
                # Condition 2 - disabling expert tensor parallelism on purpose  
                groups._create_expert_and_data_parallel(self.ep_size)  
            else:  
                # expert tensor parallelism is enabled  
                groups._create_expert_data_and_model_parallel(self.ep_size, mpu=groups.mpu)  
        # Set the group handle for the MOELayer (deepspeed_moe) object
        self.deepspeed_moe._set_ep_group(groups._get_expert_parallel_group(self.expert_group_name))  
  
    def forward(self, hidden_states, used_token=None):
        """ MoE forward  
  
        Arguments:  
            hidden_states (Tensor): input to the layer
            used_token (Tensor, optional): default: None, mask only used tokens
  
        Returns:  
            A tuple including output, gate loss, and expert count.  
  
            * output (Tensor): output of the model  
  
            * l_aux (Tensor): gate loss value

            * exp_counts (int): expert count
        """  
        output = self.deepspeed_moe(hidden_states, used_token)  
        if self.use_residual:  
            # Residual MoE  
            output_mlp = self.mlp(hidden_states)  
            if type(output_mlp) is tuple:  
                output_mlp = output_mlp[0]  # Ignore the bias term for now  
            coef = self.coefficient(hidden_states)  
            coef = torch.nn.functional.softmax(coef, dim=-1)  
            output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]  
        return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts  
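
Putting it together, a typical construction looks like the sketch below. This is a hedged example: it assumes DeepSpeed is installed, deepspeed.init_distributed() has already been called, the import path deepspeed.moe.layer.MoE, and that the layer is embedded in a model later handed to deepspeed.initialize:

import torch
from deepspeed.moe.layer import MoE  # assumed import path

hidden = 512
expert = torch.nn.Sequential(        # the expert is any user-defined module, e.g. an MLP
    torch.nn.Linear(hidden, 4 * hidden),
    torch.nn.ReLU(),
    torch.nn.Linear(4 * hidden, hidden),
)

moe = MoE(hidden_size=hidden, expert=expert, num_experts=8, ep_size=1, k=1)
moe.set_deepspeed_parallelism()      # creates/binds the expert parallel process group

x = torch.randn(4, 32, hidden)       # [batch, seq, hidden]
output, l_aux, exp_counts = moe(x)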
