From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' at line 514 to avoid a naming conflict with line 365

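For reference, a minimal usage sketch of the refactored inference path. This assumes
the FunASR AutoModel wrapper and an illustrative checkpoint id; neither is part of
this patch:

    from funasr import AutoModel

    # Load an emotion2vec checkpoint through the AutoModel wrapper (model id is
    # illustrative). granularity and extract_embedding are forwarded as kwargs
    # down to Emotion2vec.inference().
    model = AutoModel(model="iic/emotion2vec_base_finetuned")
    res = model.generate("example.wav", granularity="utterance", extract_embedding=False)
    print(res[0]["labels"], res[0]["scores"])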
---
 funasr/models/emotion2vec/model.py |  150 +++++++++++++++++++++++++++++++++----------------
 1 file changed, 101 insertions(+), 49 deletions(-)

diff --git a/funasr/models/emotion2vec/model.py b/funasr/models/emotion2vec/model.py
index e882b6e..d18e184 100644
--- a/funasr/models/emotion2vec/model.py
+++ b/funasr/models/emotion2vec/model.py
@@ -1,26 +1,43 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+# Modified from https://github.com/ddlBoJack/emotion2vec/tree/main
 
-import logging
-from functools import partial
-import numpy as np
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-from funasr.models.emotion2vec.modules import AltBlock
-from funasr.models.emotion2vec.audio import AudioEncoder
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-
-from omegaconf import OmegaConf
+import os
 import time
-
-logger = logging.getLogger(__name__)
+import torch
+import logging
+import numpy as np
+from functools import partial
+from omegaconf import OmegaConf
+import torch.nn.functional as F
+from contextlib import contextmanager
+from distutils.version import LooseVersion
 
 from funasr.register import tables
+from funasr.models.emotion2vec.modules import AltBlock
+from funasr.models.emotion2vec.audio import AudioEncoder
+from funasr.utils.load_utils import load_audio_text_image_video
+
+
+logger = logging.getLogger(__name__)
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
 
 @tables.register("model_classes", "Emotion2vec")
-class Emotion2vec(nn.Module):
+class Emotion2vec(torch.nn.Module):
+    """
+    Author: Ziyang Ma, Zhisheng Zheng, Jiaxin Ye, Jinchao Li, Zhifu Gao, Shiliang Zhang, Xie Chen
+    emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation
+    https://arxiv.org/abs/2312.15185
+    """
 
     def __init__(self, **kwargs):
         super().__init__()
@@ -29,7 +46,7 @@
         self.cfg = cfg
 
         make_layer_norm = partial(
-            nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
+            torch.nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
         )
 
         def make_block(drop_path, dim=None, heads=None):
@@ -49,7 +66,7 @@
             )
 
         self.alibi_biases = {}
-        self.modality_encoders = nn.ModuleDict()
+        self.modality_encoders = torch.nn.ModuleDict()
 
         enc = AudioEncoder(
             cfg.modalities.audio,
@@ -59,7 +76,7 @@
             cfg.get("layer_norm_first"),
             self.alibi_biases,
         )
-        self.modality_encoders['AUDIO'] = enc
+        self.modality_encoders["AUDIO"] = enc
 
         self.ema = None
 
@@ -67,18 +84,22 @@
         self.loss_beta = cfg.get("loss_beta")
         self.loss_scale = cfg.get("loss_scale")
 
-        self.dropout_input = nn.Dropout(cfg.get("dropout_input"))
+        self.dropout_input = torch.nn.Dropout(cfg.get("dropout_input"))
 
-        dpr = np.linspace(cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth"))
+        dpr = np.linspace(
+            cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth")
+        )
 
-        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])
+        self.blocks = torch.nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])
 
         self.norm = None
         if cfg.get("layer_norm_first"):
             self.norm = make_layer_norm(cfg.get("embed_dim"))
 
-
-
+        vocab_size = kwargs.get("vocab_size", -1)
+        self.proj = None
+        if vocab_size > 0:
+            self.proj = torch.nn.Linear(cfg.get("embed_dim"), vocab_size)
 
     def forward(
         self,
@@ -95,7 +116,7 @@
         **kwargs,
     ):
 
-        feature_extractor = self.modality_encoders['AUDIO']
+        feature_extractor = self.modality_encoders["AUDIO"]
 
         mask_seeds = None
 
@@ -127,11 +148,7 @@
             ):
                 ab = masked_alibi_bias
                 if ab is not None and alibi_scale is not None:
-                    scale = (
-                        alibi_scale[i]
-                        if alibi_scale.size(0) > 1
-                        else alibi_scale.squeeze(0)
-                    )
+                    scale = alibi_scale[i] if alibi_scale.size(0) > 1 else alibi_scale.squeeze(0)
                     ab = ab * scale.type_as(ab)
 
                 x, lr = blk(
@@ -173,29 +190,43 @@
         )
         return res
 
-    def generate(self,
-                 data_in,
-                 data_lengths=None,
-                 key: list = None,
-                 tokenizer=None,
-                 frontend=None,
-                 **kwargs,
-                 ):
-    
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+
         # if source_file.endswith('.wav'):
         #     wav, sr = sf.read(source_file)
         #     channel = sf.info(source_file).channels
         #     assert sr == 16e3, "Sample rate should be 16kHz, but got {}in file {}".format(sr, source_file)
         #     assert channel == 1, "Channel should be 1, but got {} in file {}".format(channel, source_file)
         granularity = kwargs.get("granularity", "utterance")
+        extract_embedding = kwargs.get("extract_embedding", True)
+        if self.proj is None:
+            extract_embedding = True
         meta_data = {}
         # extract fbank feats
         time1 = time.perf_counter()
-        audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000),
-                                                        data_type=kwargs.get("data_type", "sound"), tokenizer=tokenizer)
+        audio_sample_list = load_audio_text_image_video(
+            data_in,
+            fs=16000,
+            audio_fs=kwargs.get("fs", 16000),
+            data_type=kwargs.get("data_type", "sound"),
+            tokenizer=tokenizer,
+        )
         time2 = time.perf_counter()
         meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        meta_data["batch_data_time"] = len(audio_sample_list[0]) / kwargs.get("fs", 16000)
+
         results = []
+        output_dir = kwargs.get("output_dir")
+        if output_dir:
+            os.makedirs(output_dir, exist_ok=True)
         for i, wav in enumerate(audio_sample_list):
             source = wav.to(device=kwargs["device"])
             if self.cfg.normalize:
@@ -203,13 +234,34 @@
             source = source.view(1, -1)
 
             feats = self.extract_features(source, padding_mask=None)
-            feats = feats['x'].squeeze(0).cpu().numpy()
-            if granularity == 'frame':
+            x = feats["x"]
+            feats = feats["x"].squeeze(0).cpu().numpy()
+            if granularity == "frame":
                 feats = feats
-            elif granularity == 'utterance':
+            elif granularity == "utterance":
                 feats = np.mean(feats, axis=0)
-            
-            result_i = {"key": key[i], "feats": feats}
+
+            if output_dir and extract_embedding:
+                np.save(os.path.join(output_dir, "{}.npy".format(key[i])), feats)
+
+            labels = tokenizer.token_list if tokenizer is not None else []
+            scores = []
+            if self.proj:
+                x = x.mean(dim=1)
+                x = self.proj(x)
+                for idx, lab in enumerate(labels):
+                    x[:, idx] = -np.inf if lab.startswith("unuse") else x[:, idx]
+                x = torch.softmax(x, dim=-1)
+                scores = x[0].tolist()
+
+            select_label = [lb for lb in labels if not lb.startswith("unuse")]
+            select_score = [scores[idx] for idx, lb in enumerate(labels) if not lb.startswith("unuse")]
+
+            # result_i = {"key": key[i], "labels": labels, "scores": scores}
+            result_i = {"key": key[i], "labels": select_label, "scores": select_score}
+
+            if extract_embedding:
+                result_i["feats"] = feats
             results.append(result_i)
-            
-        return results, meta_data
\ No newline at end of file
+
+        return results, meta_data

--
Gitblit v1.9.1