kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/emotion2vec/fairseq_modules.py
@@ -4,142 +4,146 @@
from typing import Optional, Tuple, List
import numpy as np
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    # Thin wrapper around torch.nn.LayerNorm; `export` is accepted for fairseq API
    # compatibility but is not used here.
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
class SamePad(nn.Module):
    """Trim the trailing frames introduced by 'same' padding on a 1D convolution.

    With causal=True the last `kernel_size - 1` frames are removed; otherwise a single
    frame is removed only when the kernel size is even.
    """
    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0
    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x
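# Usage sketch (illustrative, not part of the original module): SamePad is typically placed
# after a Conv1d padded with kernel_size // 2, so even kernel sizes do not leave an extra
# frame on the right:
#
#     conv = nn.Conv1d(16, 16, kernel_size=4, padding=4 // 2)   # output has T + 1 frames
#     trim = SamePad(kernel_size=4)                              # drops the extra frame
#     y = trim(conv(torch.randn(2, 16, 100)))                    # y.shape == (2, 16, 100)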
class TransposeLast(nn.Module):
    """Swap the last two dimensions of the input, optionally indexing into it first."""
    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx
    def forward(self, x):
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(-2, -1)
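# Usage sketch (illustrative): wrapping a LayerNorm between two TransposeLast modules lets it
# normalise over the channel dimension of BxCxT convolutional features:
#
#     norm = nn.Sequential(TransposeLast(), Fp32LayerNorm(512, elementwise_affine=True), TransposeLast())
#     y = norm(torch.randn(4, 512, 50))                          # y.shape == (4, 512, 50)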
class Fp32LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and casts the result back to the input dtype."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)
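# Usage sketch (illustrative): the float32 round-trip keeps half-precision activations
# numerically stable while preserving the caller's dtype:
#
#     ln = Fp32LayerNorm(512, elementwise_affine=True)
#     y = ln(torch.randn(2, 10, 512, dtype=torch.float16))       # y.dtype == torch.float16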
class Fp32GroupNorm(nn.GroupNorm):
    """GroupNorm that computes in float32 and casts the result back to the input dtype."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def forward(self, input):
        output = F.group_norm(
            input.float(),
            self.num_groups,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)
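# Note (illustrative): in this file Fp32GroupNorm is instantiated as Fp32GroupNorm(dim, dim, affine=True),
# i.e. one group per channel, so the first conv block is effectively instance-normalised:
#
#     gn = Fp32GroupNorm(512, 512, affine=True)
#     y = gn(torch.randn(2, 512, 100))                            # y.shape == (2, 512, 100)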
class ConvFeatureExtractionModel(nn.Module):
    """Stack of strided 1D convolutions mapping a raw waveform (B, T) to features (B, C, T').

    Each entry of `conv_layers` is (dim, kernel_size, stride). In "default" mode only the
    first block is group-normalised; in "layer_norm" mode every block applies a LayerNorm
    over the channel dimension (via TransposeLast).
    """
    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        super().__init__()
        assert mode in {"default", "layer_norm"}
        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv
            assert (
                is_layer_norm and is_group_norm
            ) == False, "layer norm and group norm are exclusive"
            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(dim, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(dim, dim, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
        in_d = 1
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl
            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim
    def forward(self, x):
        # BxT -> BxCxT
        x = x.unsqueeze(1)
        for conv in self.conv_layers:
            x = conv(x)
        return x
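# Usage sketch (illustrative; the layer spec below is the common wav2vec 2.0-style setting,
# an assumption rather than something defined in this file):
#
#     feature_enc_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
#     extractor = ConvFeatureExtractionModel(conv_layers=feature_enc_layers, mode="default")
#     wav = torch.randn(2, 16000)                                 # two 1-second 16 kHz waveforms
#     feats = extractor(wav)                                      # feats.shape == (2, 512, 49)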
def compute_mask_indices(
    shape: Tuple[int, int],
@@ -254,11 +258,7 @@
            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
            mask_idc = np.asarray(
                [mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]
            )
        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
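        # Illustrative note (hypothetical numbers): with mask_idc = [3, 10] and lengths = [2, 3],
        # the comprehension above expands each start index into a contiguous span, giving
        # [3, 4, 10, 11, 12] before de-duplication and clipping to the sequence length.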
@@ -269,9 +269,7 @@
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        if mask_dropout > 0:
            num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
            mask_idc = np.random.choice(mask_idc, len(mask_idc) - num_holes, replace=False)
        mask[i, mask_idc] = True
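        # Illustrative note (hypothetical numbers): with 10 masked indices and mask_dropout = 0.2,
        # num_holes = round(10 * 0.2) = 2, so 8 indices are kept and remain masked.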
@@ -288,8 +286,8 @@
    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None
def is_xla_tensor(tensor):
    return torch.is_tensor(tensor) and tensor.device.type == "xla"
@@ -303,4 +301,4 @@
        tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
    else:
        tensor[indices] = value
    return tensor
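# Illustrative note: the masked-multiply branch above avoids in-place advanced indexing
# (presumably the XLA path, cf. is_xla_tensor); it is an out-of-place equivalent of
# `tensor[indices] = value`, e.g. (hypothetical values):
#
#     t = torch.zeros(4); m = torch.tensor([True, False, True, False])
#     torch.mul(t, ~m) + torch.mul(torch.ones(4), m)              # tensor([1., 0., 1., 0.])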