From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid naming conflict with line 365
---
funasr/models/emotion2vec/base.py | 70 +++++++++++++----------------------
 1 file changed, 26 insertions(+), 44 deletions(-)
diff --git a/funasr/models/emotion2vec/base.py b/funasr/models/emotion2vec/base.py
index cd87a99..1b4301e 100644
--- a/funasr/models/emotion2vec/base.py
+++ b/funasr/models/emotion2vec/base.py
@@ -22,7 +22,6 @@
logger = logging.getLogger(__name__)
-
MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
@@ -69,13 +68,17 @@
self.alibi_scale = nn.Parameter(
torch.full(
(
- (modality_cfg.prenet_depth + modality_cfg.model_depth)
- if modality_cfg.learned_alibi_scale_per_layer
- else 1,
+ (
+ (modality_cfg.prenet_depth + modality_cfg.model_depth)
+ if modality_cfg.learned_alibi_scale_per_layer
+ else 1
+ ),
1,
- self.modality_cfg.num_alibi_heads
- if modality_cfg.learned_alibi_scale_per_head
- else 1,
+ (
+ self.modality_cfg.num_alibi_heads
+ if modality_cfg.learned_alibi_scale_per_head
+ else 1
+ ),
1,
1,
),
@@ -96,9 +99,7 @@
device="cpu",
)
self.alibi_bias = nn.Parameter(alibi_bias)
- self.get_alibi_bias = partial(
- _learned_alibi_bias, alibi_bias=self.alibi_bias
- )
+ self.get_alibi_bias = partial(_learned_alibi_bias, alibi_bias=self.alibi_bias)
def upgrade_state_dict_named(self, state_dict, name):
k = f"{name}.alibi_scale"
@@ -147,9 +148,7 @@
if self.local_grad_mult == 1.0:
x = self.local_encoder(features)
else:
- x = GradMultiply.apply(
- self.local_encoder(features), self.local_grad_mult
- )
+ x = GradMultiply.apply(self.local_encoder(features), self.local_grad_mult)
else:
with torch.no_grad():
x = self.local_encoder(features)
@@ -188,8 +187,7 @@
x = x.repeat_interleave(clone_batch, 0)
if mask_seeds is not None:
clone_hash = [
- int(hash((mask_seeds.seed, ind)) % 1e10)
- for ind in range(clone_batch - 1)
+ int(hash((mask_seeds.seed, ind)) % 1e10) for ind in range(clone_batch - 1)
]
clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1)
@@ -197,9 +195,7 @@
id = id.repeat_interleave(clone_batch, 0)
id = id.view(-1, clone_batch) + clone_hash.to(id)
id = id.view(-1)
- mask_seeds = MaskSeed(
- seed=mask_seeds.seed, update=mask_seeds.update, ids=id
- )
+ mask_seeds = MaskSeed(seed=mask_seeds.seed, update=mask_seeds.update, ids=id)
if padding_mask is not None:
padding_mask = padding_mask.repeat_interleave(clone_batch, 0)
@@ -268,9 +264,7 @@
x,
masked_padding_mask,
alibi_bias,
- alibi_scale[: self.modality_cfg.prenet_depth]
- if alibi_scale is not None
- else None,
+ alibi_scale[: self.modality_cfg.prenet_depth] if alibi_scale is not None else None,
)
return {
@@ -278,9 +272,11 @@
"local_features": local_features,
"padding_mask": masked_padding_mask,
"alibi_bias": alibi_bias,
- "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
- if alibi_scale is not None and alibi_scale.size(0) > 1
- else alibi_scale,
+ "alibi_scale": (
+ alibi_scale[self.modality_cfg.prenet_depth :]
+ if alibi_scale is not None and alibi_scale.size(0) > 1
+ else alibi_scale
+ ),
"encoder_mask": mask_info,
}
@@ -405,9 +401,7 @@
x = x * (1 - mask.type_as(x).unsqueeze(-1))
else:
num_masks = mask.sum().item()
- masks = x.new_empty(num_masks, x.size(-1)).normal_(
- 0, cfg.mask_noise_std
- )
+ masks = x.new_empty(num_masks, x.size(-1)).normal_(0, cfg.mask_noise_std)
x = index_put(x, mask, masks)
if cfg.mask_channel_prob > 0:
mask_channel = compute_mask_indices(
@@ -417,10 +411,7 @@
cfg.mask_channel_length,
)
mask_channel = (
- torch.from_numpy(mask_channel)
- .to(x.device)
- .unsqueeze(1)
- .expand(-1, T, -1)
+ torch.from_numpy(mask_channel).to(x.device).unsqueeze(1).expand(-1, T, -1)
)
x = index_put(x, mask_channel, 0)
return x
@@ -445,9 +436,7 @@
generator = None
if mask_seed is not None:
- seed = int(
- hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6
- )
+ seed = int(hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6)
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
@@ -470,9 +459,7 @@
ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D)
- return MaskInfo(
- x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep
- )
+ return MaskInfo(x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep)
def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
@@ -525,10 +512,7 @@
# autoregressive model so we want a symmetric mask with 0 on the
# diagonal and other wise linear decreasing valuees
pos_bias = (
- torch.abs(
- torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
- )
- * -1
+ torch.abs(torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)) * -1
)
elif dims == 2:
if distance == "manhattan":
@@ -553,9 +537,7 @@
else:
raise Exception(f"unsupported number of alibi dims: {dims}")
- alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
- attn_heads, -1, -1
- )
+ alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(attn_heads, -1, -1)
return alibi_bias
--
Gitblit v1.9.1