From cf7f9a06c8067033a8f113591f9f8d96a3fbc3dd Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 30 一月 2024 18:57:19 +0800
Subject: [PATCH] funasr1.0.4 emotion2vec finetuned
---
funasr/models/emotion2vec/model.py | 31 ++++++++++++++++++++++++++-----
1 files changed, 26 insertions(+), 5 deletions(-)
diff --git a/funasr/models/emotion2vec/model.py b/funasr/models/emotion2vec/model.py
index de8113c..58b9f39 100644
--- a/funasr/models/emotion2vec/model.py
+++ b/funasr/models/emotion2vec/model.py
@@ -93,7 +93,10 @@
if cfg.get("layer_norm_first"):
self.norm = make_layer_norm(cfg.get("embed_dim"))
-
+ vocab_size = kwargs.get("vocab_size", -1)
+ self.proj = None
+ if vocab_size > 0:
+ self.proj = torch.nn.Linear(cfg.get("embed_dim"), vocab_size)
def forward(
@@ -204,6 +207,9 @@
# assert sr == 16e3, "Sample rate should be 16kHz, but got {}in file {}".format(sr, source_file)
# assert channel == 1, "Channel should be 1, but got {} in file {}".format(channel, source_file)
granularity = kwargs.get("granularity", "utterance")
+ extract_embedding = kwargs.get("extract_embedding", True)
+ if self.proj is None:
+ extract_embedding = True
meta_data = {}
# extract fbank feats
time1 = time.perf_counter()
@@ -211,6 +217,8 @@
data_type=kwargs.get("data_type", "sound"), tokenizer=tokenizer)
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ meta_data["batch_data_time"] = len(audio_sample_list[0])/kwargs.get("fs", 16000)
+
results = []
output_dir = kwargs.get("output_dir")
if output_dir:
@@ -222,15 +230,28 @@
source = source.view(1, -1)
feats = self.extract_features(source, padding_mask=None)
+ x = feats['x']
feats = feats['x'].squeeze(0).cpu().numpy()
if granularity == 'frame':
feats = feats
elif granularity == 'utterance':
feats = np.mean(feats, axis=0)
-
- result_i = {"key": key[i], "feats": feats}
- results.append(result_i)
- if output_dir:
+
+ if output_dir and extract_embedding:
np.save(os.path.join(output_dir, "{}.npy".format(key[i])), feats)
+
+ labels = tokenizer.token_list if tokenizer is not None else []
+ scores = []
+ if self.proj:
+ x = x.mean(dim=1)
+ x = self.proj(x)
+ x = torch.softmax(x, dim=-1)
+ scores = x[0].tolist()
+
+ result_i = {"key": key[i], "labels": labels, "scores": scores}
+ if extract_embedding:
+ result_i["feats"] = feats
+ results.append(result_i)
+
return results, meta_data
\ No newline at end of file
--
Gitblit v1.9.1