liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/emotion2vec/base.py
@@ -22,7 +22,6 @@
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Seed bundle for reproducible masking: elsewhere in this file the `seed`,
# `update` and the sum of `ids` are hashed together to seed a torch.Generator.
MaskSeed = namedtuple(
    "MaskSeed",
    ("seed", "update", "ids"),
)

# Result record of the masking step: the kept (unmasked) input, the mask
# itself, and the index tensors for restoring / selecting positions.
MaskInfo = namedtuple(
    "MaskInfo",
    ("x_unmasked", "mask", "ids_restore", "ids_keep"),
)
@@ -69,13 +68,17 @@
            self.alibi_scale = nn.Parameter(
                torch.full(
                    (
                        (modality_cfg.prenet_depth + modality_cfg.model_depth)
                        if modality_cfg.learned_alibi_scale_per_layer
                        else 1,
                        (
                            (modality_cfg.prenet_depth + modality_cfg.model_depth)
                            if modality_cfg.learned_alibi_scale_per_layer
                            else 1
                        ),
                        1,
                        self.modality_cfg.num_alibi_heads
                        if modality_cfg.learned_alibi_scale_per_head
                        else 1,
                        (
                            self.modality_cfg.num_alibi_heads
                            if modality_cfg.learned_alibi_scale_per_head
                            else 1
                        ),
                        1,
                        1,
                    ),
@@ -96,9 +99,7 @@
                device="cpu",
            )
            self.alibi_bias = nn.Parameter(alibi_bias)
            self.get_alibi_bias = partial(
                _learned_alibi_bias, alibi_bias=self.alibi_bias
            )
            self.get_alibi_bias = partial(_learned_alibi_bias, alibi_bias=self.alibi_bias)
    def upgrade_state_dict_named(self, state_dict, name):
        k = f"{name}.alibi_scale"
@@ -147,9 +148,7 @@
            if self.local_grad_mult == 1.0:
                x = self.local_encoder(features)
            else:
                x = GradMultiply.apply(
                    self.local_encoder(features), self.local_grad_mult
                )
                x = GradMultiply.apply(self.local_encoder(features), self.local_grad_mult)
        else:
            with torch.no_grad():
                x = self.local_encoder(features)
@@ -188,8 +187,7 @@
                x = x.repeat_interleave(clone_batch, 0)
                if mask_seeds is not None:
                    clone_hash = [
                        int(hash((mask_seeds.seed, ind)) % 1e10)
                        for ind in range(clone_batch - 1)
                        int(hash((mask_seeds.seed, ind)) % 1e10) for ind in range(clone_batch - 1)
                    ]
                    clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1)
@@ -197,9 +195,7 @@
                    id = id.repeat_interleave(clone_batch, 0)
                    id = id.view(-1, clone_batch) + clone_hash.to(id)
                    id = id.view(-1)
                    mask_seeds = MaskSeed(
                        seed=mask_seeds.seed, update=mask_seeds.update, ids=id
                    )
                    mask_seeds = MaskSeed(seed=mask_seeds.seed, update=mask_seeds.update, ids=id)
                if padding_mask is not None:
                    padding_mask = padding_mask.repeat_interleave(clone_batch, 0)
@@ -268,9 +264,7 @@
            x,
            masked_padding_mask,
            alibi_bias,
            alibi_scale[: self.modality_cfg.prenet_depth]
            if alibi_scale is not None
            else None,
            alibi_scale[: self.modality_cfg.prenet_depth] if alibi_scale is not None else None,
        )
        return {
@@ -278,9 +272,11 @@
            "local_features": local_features,
            "padding_mask": masked_padding_mask,
            "alibi_bias": alibi_bias,
            "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
            if alibi_scale is not None and alibi_scale.size(0) > 1
            else alibi_scale,
            "alibi_scale": (
                alibi_scale[self.modality_cfg.prenet_depth :]
                if alibi_scale is not None and alibi_scale.size(0) > 1
                else alibi_scale
            ),
            "encoder_mask": mask_info,
        }
@@ -405,9 +401,7 @@
                x = x * (1 - mask.type_as(x).unsqueeze(-1))
            else:
                num_masks = mask.sum().item()
                masks = x.new_empty(num_masks, x.size(-1)).normal_(
                    0, cfg.mask_noise_std
                )
                masks = x.new_empty(num_masks, x.size(-1)).normal_(0, cfg.mask_noise_std)
                x = index_put(x, mask, masks)
        if cfg.mask_channel_prob > 0:
            mask_channel = compute_mask_indices(
@@ -417,10 +411,7 @@
                cfg.mask_channel_length,
            )
            mask_channel = (
                torch.from_numpy(mask_channel)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
                torch.from_numpy(mask_channel).to(x.device).unsqueeze(1).expand(-1, T, -1)
            )
            x = index_put(x, mask_channel, 0)
        return x
@@ -445,9 +436,7 @@
    generator = None
    if mask_seed is not None:
        seed = int(
            hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6
        )
        seed = int(hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6)
        generator = torch.Generator(device=x.device)
        generator.manual_seed(seed)
@@ -470,9 +459,7 @@
    ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D)
    return MaskInfo(
        x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep
    )
    return MaskInfo(x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep)
def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
@@ -525,10 +512,7 @@
        # autoregressive model so we want a symmetric mask with 0 on the
        # diagonal and otherwise linearly decreasing values
        pos_bias = (
            torch.abs(
                torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
            )
            * -1
            torch.abs(torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)) * -1
        )
    elif dims == 2:
        if distance == "manhattan":
@@ -553,9 +537,7 @@
    else:
        raise Exception(f"unsupported number of alibi dims: {dims}")
    alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
        attn_heads, -1, -1
    )
    alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(attn_heads, -1, -1)
    return alibi_bias