From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/emotion2vec/base.py |   70 +++++++++++++----------------------
 1 files changed, 26 insertions(+), 44 deletions(-)

diff --git a/funasr/models/emotion2vec/base.py b/funasr/models/emotion2vec/base.py
index cd87a99..1b4301e 100644
--- a/funasr/models/emotion2vec/base.py
+++ b/funasr/models/emotion2vec/base.py
@@ -22,7 +22,6 @@
 logger = logging.getLogger(__name__)
 
 
-
 MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
 MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
 
@@ -69,13 +68,17 @@
             self.alibi_scale = nn.Parameter(
                 torch.full(
                     (
-                        (modality_cfg.prenet_depth + modality_cfg.model_depth)
-                        if modality_cfg.learned_alibi_scale_per_layer
-                        else 1,
+                        (
+                            (modality_cfg.prenet_depth + modality_cfg.model_depth)
+                            if modality_cfg.learned_alibi_scale_per_layer
+                            else 1
+                        ),
                         1,
-                        self.modality_cfg.num_alibi_heads
-                        if modality_cfg.learned_alibi_scale_per_head
-                        else 1,
+                        (
+                            self.modality_cfg.num_alibi_heads
+                            if modality_cfg.learned_alibi_scale_per_head
+                            else 1
+                        ),
                         1,
                         1,
                     ),
@@ -96,9 +99,7 @@
                 device="cpu",
             )
             self.alibi_bias = nn.Parameter(alibi_bias)
-            self.get_alibi_bias = partial(
-                _learned_alibi_bias, alibi_bias=self.alibi_bias
-            )
+            self.get_alibi_bias = partial(_learned_alibi_bias, alibi_bias=self.alibi_bias)
 
     def upgrade_state_dict_named(self, state_dict, name):
         k = f"{name}.alibi_scale"
@@ -147,9 +148,7 @@
             if self.local_grad_mult == 1.0:
                 x = self.local_encoder(features)
             else:
-                x = GradMultiply.apply(
-                    self.local_encoder(features), self.local_grad_mult
-                )
+                x = GradMultiply.apply(self.local_encoder(features), self.local_grad_mult)
         else:
             with torch.no_grad():
                 x = self.local_encoder(features)
@@ -188,8 +187,7 @@
                 x = x.repeat_interleave(clone_batch, 0)
                 if mask_seeds is not None:
                     clone_hash = [
-                        int(hash((mask_seeds.seed, ind)) % 1e10)
-                        for ind in range(clone_batch - 1)
+                        int(hash((mask_seeds.seed, ind)) % 1e10) for ind in range(clone_batch - 1)
                     ]
                     clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1)
 
@@ -197,9 +195,7 @@
                     id = id.repeat_interleave(clone_batch, 0)
                     id = id.view(-1, clone_batch) + clone_hash.to(id)
                     id = id.view(-1)
-                    mask_seeds = MaskSeed(
-                        seed=mask_seeds.seed, update=mask_seeds.update, ids=id
-                    )
+                    mask_seeds = MaskSeed(seed=mask_seeds.seed, update=mask_seeds.update, ids=id)
                 if padding_mask is not None:
                     padding_mask = padding_mask.repeat_interleave(clone_batch, 0)
 
@@ -268,9 +264,7 @@
             x,
             masked_padding_mask,
             alibi_bias,
-            alibi_scale[: self.modality_cfg.prenet_depth]
-            if alibi_scale is not None
-            else None,
+            alibi_scale[: self.modality_cfg.prenet_depth] if alibi_scale is not None else None,
         )
 
         return {
@@ -278,9 +272,11 @@
             "local_features": local_features,
             "padding_mask": masked_padding_mask,
             "alibi_bias": alibi_bias,
-            "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
-            if alibi_scale is not None and alibi_scale.size(0) > 1
-            else alibi_scale,
+            "alibi_scale": (
+                alibi_scale[self.modality_cfg.prenet_depth :]
+                if alibi_scale is not None and alibi_scale.size(0) > 1
+                else alibi_scale
+            ),
             "encoder_mask": mask_info,
         }
 
@@ -405,9 +401,7 @@
                 x = x * (1 - mask.type_as(x).unsqueeze(-1))
             else:
                 num_masks = mask.sum().item()
-                masks = x.new_empty(num_masks, x.size(-1)).normal_(
-                    0, cfg.mask_noise_std
-                )
+                masks = x.new_empty(num_masks, x.size(-1)).normal_(0, cfg.mask_noise_std)
                 x = index_put(x, mask, masks)
         if cfg.mask_channel_prob > 0:
             mask_channel = compute_mask_indices(
@@ -417,10 +411,7 @@
                 cfg.mask_channel_length,
             )
             mask_channel = (
-                torch.from_numpy(mask_channel)
-                .to(x.device)
-                .unsqueeze(1)
-                .expand(-1, T, -1)
+                torch.from_numpy(mask_channel).to(x.device).unsqueeze(1).expand(-1, T, -1)
             )
             x = index_put(x, mask_channel, 0)
         return x
@@ -445,9 +436,7 @@
 
     generator = None
     if mask_seed is not None:
-        seed = int(
-            hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6
-        )
+        seed = int(hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6)
         generator = torch.Generator(device=x.device)
         generator.manual_seed(seed)
 
@@ -470,9 +459,7 @@
 
     ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D)
 
-    return MaskInfo(
-        x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep
-    )
+    return MaskInfo(x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep)
 
 
 def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
@@ -525,10 +512,7 @@
         # autoregressive model so we want a symmetric mask with 0 on the
         # diagonal and other wise linear decreasing valuees
         pos_bias = (
-            torch.abs(
-                torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
-            )
-            * -1
+            torch.abs(torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)) * -1
         )
     elif dims == 2:
         if distance == "manhattan":
@@ -553,9 +537,7 @@
     else:
         raise Exception(f"unsupported number of alibi dims: {dims}")
 
-    alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
-        attn_heads, -1, -1
-    )
+    alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(attn_heads, -1, -1)
 
     return alibi_bias
 

--
Gitblit v1.9.1