From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365
---
funasr/models/emotion2vec/model.py | 91 ++++++++++++++++++++++++++++++---------------
 1 file changed, 61 insertions(+), 30 deletions(-)
diff --git a/funasr/models/emotion2vec/model.py b/funasr/models/emotion2vec/model.py
index de8113c..d18e184 100644
--- a/funasr/models/emotion2vec/model.py
+++ b/funasr/models/emotion2vec/model.py
@@ -38,6 +38,7 @@
emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation
https://arxiv.org/abs/2312.15185
"""
+
def __init__(self, **kwargs):
super().__init__()
# import pdb; pdb.set_trace()
@@ -75,7 +76,7 @@
cfg.get("layer_norm_first"),
self.alibi_biases,
)
- self.modality_encoders['AUDIO'] = enc
+ self.modality_encoders["AUDIO"] = enc
self.ema = None
@@ -85,7 +86,9 @@
self.dropout_input = torch.nn.Dropout(cfg.get("dropout_input"))
- dpr = np.linspace(cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth"))
+ dpr = np.linspace(
+ cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth")
+ )
self.blocks = torch.nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])
@@ -93,8 +96,10 @@
if cfg.get("layer_norm_first"):
self.norm = make_layer_norm(cfg.get("embed_dim"))
-
-
+ vocab_size = kwargs.get("vocab_size", -1)
+ self.proj = None
+ if vocab_size > 0:
+ self.proj = torch.nn.Linear(cfg.get("embed_dim"), vocab_size)
def forward(
self,
@@ -111,7 +116,7 @@
**kwargs,
):
- feature_extractor = self.modality_encoders['AUDIO']
+ feature_extractor = self.modality_encoders["AUDIO"]
mask_seeds = None
@@ -143,11 +148,7 @@
):
ab = masked_alibi_bias
if ab is not None and alibi_scale is not None:
- scale = (
- alibi_scale[i]
- if alibi_scale.size(0) > 1
- else alibi_scale.squeeze(0)
- )
+ scale = alibi_scale[i] if alibi_scale.size(0) > 1 else alibi_scale.squeeze(0)
ab = ab * scale.type_as(ab)
x, lr = blk(
@@ -189,28 +190,39 @@
)
return res
- def inference(self,
- data_in,
- data_lengths=None,
- key: list = None,
- tokenizer=None,
- frontend=None,
- **kwargs,
- ):
-
+ def inference(
+ self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+
# if source_file.endswith('.wav'):
# wav, sr = sf.read(source_file)
# channel = sf.info(source_file).channels
# assert sr == 16e3, "Sample rate should be 16kHz, but got {}in file {}".format(sr, source_file)
# assert channel == 1, "Channel should be 1, but got {} in file {}".format(channel, source_file)
granularity = kwargs.get("granularity", "utterance")
+ extract_embedding = kwargs.get("extract_embedding", True)
+ if self.proj is None:
+ extract_embedding = True
meta_data = {}
# extract fbank feats
time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000),
- data_type=kwargs.get("data_type", "sound"), tokenizer=tokenizer)
+ audio_sample_list = load_audio_text_image_video(
+ data_in,
+ fs=16000,
+ audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer,
+ )
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ meta_data["batch_data_time"] = len(audio_sample_list[0]) / kwargs.get("fs", 16000)
+
results = []
output_dir = kwargs.get("output_dir")
if output_dir:
@@ -222,15 +234,34 @@
source = source.view(1, -1)
feats = self.extract_features(source, padding_mask=None)
- feats = feats['x'].squeeze(0).cpu().numpy()
- if granularity == 'frame':
+ x = feats["x"]
+ feats = feats["x"].squeeze(0).cpu().numpy()
+ if granularity == "frame":
feats = feats
- elif granularity == 'utterance':
+ elif granularity == "utterance":
feats = np.mean(feats, axis=0)
-
- result_i = {"key": key[i], "feats": feats}
- results.append(result_i)
- if output_dir:
+
+ if output_dir and extract_embedding:
np.save(os.path.join(output_dir, "{}.npy".format(key[i])), feats)
-
- return results, meta_data
\ No newline at end of file
+
+ labels = tokenizer.token_list if tokenizer is not None else []
+ scores = []
+ if self.proj:
+ x = x.mean(dim=1)
+ x = self.proj(x)
+ for idx, lab in enumerate(labels):
+ x[:,idx] = -np.inf if lab.startswith("unuse") else x[:,idx]
+ x = torch.softmax(x, dim=-1)
+ scores = x[0].tolist()
+
+ select_label = [lb for lb in labels if not lb.startswith("unuse")]
+ select_score = [scores[idx] for idx, lb in enumerate(labels) if not lb.startswith("unuse")]
+
+ # result_i = {"key": key[i], "labels": labels, "scores": scores}
+ result_i = {"key": key[i], "labels": select_label, "scores": select_score}
+
+ if extract_embedding:
+ result_i["feats"] = feats
+ results.append(result_i)
+
+ return results, meta_data
--
Gitblit v1.9.1