From cf7f9a06c8067033a8f113591f9f8d96a3fbc3dd Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 30 一月 2024 18:57:19 +0800
Subject: [PATCH] funasr1.0.4 emotion2vec finetuned

---
 examples/industrial_data_pretraining/emotion2vec/demo.py  |    5 +++--
 examples/industrial_data_pretraining/emotion2vec/infer.sh |    4 +++-
 funasr/models/emotion2vec/model.py                        |   31 ++++++++++++++++++++++++++-----
 3 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/examples/industrial_data_pretraining/emotion2vec/demo.py b/examples/industrial_data_pretraining/emotion2vec/demo.py
index 3841089..66c25cf 100644
--- a/examples/industrial_data_pretraining/emotion2vec/demo.py
+++ b/examples/industrial_data_pretraining/emotion2vec/demo.py
@@ -5,8 +5,9 @@
 
 from funasr import AutoModel
 
-model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.4")
+# model="damo/emotion2vec_base"
+model = AutoModel(model="iic/emotion2vec_base_finetuned", model_revision="v2.0.4")
 
 wav_file = f"{model.model_path}/example/test.wav"
-res = model.generate(wav_file, output_dir="./outputs", granularity="utterance")
+res = model.generate(wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False)
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/emotion2vec/infer.sh b/examples/industrial_data_pretraining/emotion2vec/infer.sh
index c46b819..df990b9 100644
--- a/examples/industrial_data_pretraining/emotion2vec/infer.sh
+++ b/examples/industrial_data_pretraining/emotion2vec/infer.sh
@@ -1,5 +1,6 @@
 
-model="damo/emotion2vec_base"
+#model="damo/emotion2vec_base"
+model="iic/emotion2vec_base_finetuned"
 model_revision="v2.0.4"
 
 python funasr/bin/inference.py \
@@ -7,4 +8,5 @@
 +model_revision=${model_revision} \
 +input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
 +output_dir="./outputs/debug" \
++extract_embedding=False \
 +device="cpu" \
diff --git a/funasr/models/emotion2vec/model.py b/funasr/models/emotion2vec/model.py
index de8113c..58b9f39 100644
--- a/funasr/models/emotion2vec/model.py
+++ b/funasr/models/emotion2vec/model.py
@@ -93,7 +93,10 @@
         if cfg.get("layer_norm_first"):
             self.norm = make_layer_norm(cfg.get("embed_dim"))
 
-
+        vocab_size = kwargs.get("vocab_size", -1)
+        self.proj = None
+        if vocab_size > 0:
+            self.proj = torch.nn.Linear(cfg.get("embed_dim"), vocab_size)
 
 
     def forward(
@@ -204,6 +207,9 @@
         #     assert sr == 16e3, "Sample rate should be 16kHz, but got {}in file {}".format(sr, source_file)
         #     assert channel == 1, "Channel should be 1, but got {} in file {}".format(channel, source_file)
         granularity = kwargs.get("granularity", "utterance")
+        extract_embedding = kwargs.get("extract_embedding", True)
+        if self.proj is None:
+            extract_embedding = True
         meta_data = {}
         # extract fbank feats
         time1 = time.perf_counter()
@@ -211,6 +217,8 @@
                                                         data_type=kwargs.get("data_type", "sound"), tokenizer=tokenizer)
         time2 = time.perf_counter()
         meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        meta_data["batch_data_time"] = len(audio_sample_list[0])/kwargs.get("fs", 16000)
+        
         results = []
         output_dir = kwargs.get("output_dir")
         if output_dir:
@@ -222,15 +230,28 @@
             source = source.view(1, -1)
 
             feats = self.extract_features(source, padding_mask=None)
+            x = feats['x']
             feats = feats['x'].squeeze(0).cpu().numpy()
             if granularity == 'frame':
                 feats = feats
             elif granularity == 'utterance':
                 feats = np.mean(feats, axis=0)
-            
-            result_i = {"key": key[i], "feats": feats}
-            results.append(result_i)
-            if output_dir:
+                
+            if output_dir and extract_embedding:
                 np.save(os.path.join(output_dir, "{}.npy".format(key[i])), feats)
+
+            labels = tokenizer.token_list if tokenizer is not None else []
+            scores = []
+            if self.proj:
+                x = x.mean(dim=1)
+                x = self.proj(x)
+                x = torch.softmax(x, dim=-1)
+                scores = x[0].tolist()
+
+            result_i = {"key": key[i], "labels": labels, "scores": scores}
+            if extract_embedding:
+                result_i["feats"] = feats
+            results.append(result_i)
+
             
         return results, meta_data
\ No newline at end of file

--
Gitblit v1.9.1