From f57b68121a526baea43b2e93f4540d8a2995f633 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 29 Apr 2024 15:15:24 +0800
Subject: [PATCH] batch

---
 funasr/models/emotion2vec/model.py |   61 ++++++++++++++++--------------
 1 file changed, 32 insertions(+), 29 deletions(-)

diff --git a/funasr/models/emotion2vec/model.py b/funasr/models/emotion2vec/model.py
index 58b9f39..48b8716 100644
--- a/funasr/models/emotion2vec/model.py
+++ b/funasr/models/emotion2vec/model.py
@@ -38,6 +38,7 @@
     emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation
     https://arxiv.org/abs/2312.15185
     """
+
     def __init__(self, **kwargs):
         super().__init__()
         # import pdb; pdb.set_trace()
@@ -75,7 +76,7 @@
             cfg.get("layer_norm_first"),
             self.alibi_biases,
         )
-        self.modality_encoders['AUDIO'] = enc
+        self.modality_encoders["AUDIO"] = enc
 
         self.ema = None
 
@@ -85,7 +86,9 @@
 
         self.dropout_input = torch.nn.Dropout(cfg.get("dropout_input"))
 
-        dpr = np.linspace(cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth"))
+        dpr = np.linspace(
+            cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth")
+        )
 
         self.blocks = torch.nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])
 
@@ -97,7 +100,6 @@
         self.proj = None
         if vocab_size > 0:
             self.proj = torch.nn.Linear(cfg.get("embed_dim"), vocab_size)
-
 
     def forward(
         self,
@@ -114,7 +116,7 @@
         **kwargs,
     ):
 
-        feature_extractor = self.modality_encoders['AUDIO']
+        feature_extractor = self.modality_encoders["AUDIO"]
 
         mask_seeds = None
 
@@ -146,11 +148,7 @@
             ):
                 ab = masked_alibi_bias
                 if ab is not None and alibi_scale is not None:
-                    scale = (
-                        alibi_scale[i]
-                        if alibi_scale.size(0) > 1
-                        else alibi_scale.squeeze(0)
-                    )
+                    scale = alibi_scale[i] if alibi_scale.size(0) > 1 else alibi_scale.squeeze(0)
                     ab = ab * scale.type_as(ab)
 
                 x, lr = blk(
@@ -192,15 +190,16 @@
         )
         return res
 
-    def inference(self,
-                 data_in,
-                 data_lengths=None,
-                 key: list = None,
-                 tokenizer=None,
-                 frontend=None,
-                 **kwargs,
-                 ):
-    
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+
         # if source_file.endswith('.wav'):
         #     wav, sr = sf.read(source_file)
         #     channel = sf.info(source_file).channels
@@ -213,12 +212,17 @@
         meta_data = {}
         # extract fbank feats
         time1 = time.perf_counter()
-        audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000),
-                                                        data_type=kwargs.get("data_type", "sound"), tokenizer=tokenizer)
+        audio_sample_list = load_audio_text_image_video(
+            data_in,
+            fs=16000,
+            audio_fs=kwargs.get("fs", 16000),
+            data_type=kwargs.get("data_type", "sound"),
+            tokenizer=tokenizer,
+        )
         time2 = time.perf_counter()
         meta_data["load_data"] = f"{time2 - time1:0.3f}"
-        meta_data["batch_data_time"] = len(audio_sample_list[0])/kwargs.get("fs", 16000)
-        
+        meta_data["batch_data_time"] = len(audio_sample_list[0]) / kwargs.get("fs", 16000)
+
         results = []
         output_dir = kwargs.get("output_dir")
         if output_dir:
@@ -230,13 +234,13 @@
             source = source.view(1, -1)
 
             feats = self.extract_features(source, padding_mask=None)
-            x = feats['x']
-            feats = feats['x'].squeeze(0).cpu().numpy()
-            if granularity == 'frame':
+            x = feats["x"]
+            feats = feats["x"].squeeze(0).cpu().numpy()
+            if granularity == "frame":
                 feats = feats
-            elif granularity == 'utterance':
+            elif granularity == "utterance":
                 feats = np.mean(feats, axis=0)
-                
+
             if output_dir and extract_embedding:
                 np.save(os.path.join(output_dir, "{}.npy".format(key[i])), feats)
 
@@ -253,5 +257,4 @@
                 result_i["feats"] = feats
             results.append(result_i)
 
-            
-        return results, meta_data
\ No newline at end of file
+        return results, meta_data

--
Gitblit v1.9.1