From 97d648c255316ec1fff5d82e46749076faabdd2d Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期一, 15 一月 2024 15:41:25 +0800
Subject: [PATCH] code optimize, model update, scripts

---
 funasr/models/emotion2vec/model.py |   50 ++++++++++++++++++++++++++++----------------------
 1 files changed, 28 insertions(+), 22 deletions(-)

diff --git a/funasr/models/emotion2vec/model.py b/funasr/models/emotion2vec/model.py
index 315c1cc..de8113c 100644
--- a/funasr/models/emotion2vec/model.py
+++ b/funasr/models/emotion2vec/model.py
@@ -4,29 +4,35 @@
 #  MIT License  (https://opensource.org/licenses/MIT)
 # Modified from https://github.com/ddlBoJack/emotion2vec/tree/main
 
-import logging
 import os
-from functools import partial
-import numpy as np
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-from funasr.models.emotion2vec.modules import AltBlock
-from funasr.models.emotion2vec.audio import AudioEncoder
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-
-from omegaconf import OmegaConf
 import time
-
-logger = logging.getLogger(__name__)
+import torch
+import logging
+import numpy as np
+from functools import partial
+from omegaconf import OmegaConf
+import torch.nn.functional as F
+from contextlib import contextmanager
+from distutils.version import LooseVersion
 
 from funasr.register import tables
+from funasr.models.emotion2vec.modules import AltBlock
+from funasr.models.emotion2vec.audio import AudioEncoder
+from funasr.utils.load_utils import load_audio_text_image_video
+
+
+logger = logging.getLogger(__name__)
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
 
 @tables.register("model_classes", "Emotion2vec")
-class Emotion2vec(nn.Module):
+class Emotion2vec(torch.nn.Module):
     """
     Author: Ziyang Ma, Zhisheng Zheng, Jiaxin Ye, Jinchao Li, Zhifu Gao, Shiliang Zhang, Xie Chen
     emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation
@@ -39,7 +45,7 @@
         self.cfg = cfg
 
         make_layer_norm = partial(
-            nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
+            torch.nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
         )
 
         def make_block(drop_path, dim=None, heads=None):
@@ -59,7 +65,7 @@
             )
 
         self.alibi_biases = {}
-        self.modality_encoders = nn.ModuleDict()
+        self.modality_encoders = torch.nn.ModuleDict()
 
         enc = AudioEncoder(
             cfg.modalities.audio,
@@ -77,11 +83,11 @@
         self.loss_beta = cfg.get("loss_beta")
         self.loss_scale = cfg.get("loss_scale")
 
-        self.dropout_input = nn.Dropout(cfg.get("dropout_input"))
+        self.dropout_input = torch.nn.Dropout(cfg.get("dropout_input"))
 
         dpr = np.linspace(cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth"))
 
-        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])
+        self.blocks = torch.nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])
 
         self.norm = None
         if cfg.get("layer_norm_first"):
@@ -183,7 +189,7 @@
         )
         return res
 
-    def generate(self,
+    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,

--
Gitblit v1.9.1