From 94de39dde2e616a01683c518023d0fab72b4e103 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 19 Feb 2024 22:21:50 +0800
Subject: [PATCH] emotion2vec: inference interface with emotion labels/scores
---
funasr/models/emotion2vec/model.py | 94 ++++++++++++++++++++++++++++++++++-------------
 1 file changed, 68 insertions(+), 26 deletions(-)
diff --git a/funasr/models/emotion2vec/model.py b/funasr/models/emotion2vec/model.py
index e882b6e..58b9f39 100644
--- a/funasr/models/emotion2vec/model.py
+++ b/funasr/models/emotion2vec/model.py
@@ -1,27 +1,43 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+# Modified from https://github.com/ddlBoJack/emotion2vec/tree/main
-import logging
-from functools import partial
-import numpy as np
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-from funasr.models.emotion2vec.modules import AltBlock
-from funasr.models.emotion2vec.audio import AudioEncoder
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-
-from omegaconf import OmegaConf
+import os
import time
-
-logger = logging.getLogger(__name__)
+import torch
+import logging
+import numpy as np
+from functools import partial
+from omegaconf import OmegaConf
+import torch.nn.functional as F
+from contextlib import contextmanager
+from distutils.version import LooseVersion
from funasr.register import tables
+from funasr.models.emotion2vec.modules import AltBlock
+from funasr.models.emotion2vec.audio import AudioEncoder
+from funasr.utils.load_utils import load_audio_text_image_video
+
+
+logger = logging.getLogger(__name__)
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
@tables.register("model_classes", "Emotion2vec")
-class Emotion2vec(nn.Module):
-
+class Emotion2vec(torch.nn.Module):
+ """
+ Author: Ziyang Ma, Zhisheng Zheng, Jiaxin Ye, Jinchao Li, Zhifu Gao, Shiliang Zhang, Xie Chen
+ emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation
+ https://arxiv.org/abs/2312.15185
+ """
def __init__(self, **kwargs):
super().__init__()
# import pdb; pdb.set_trace()
@@ -29,7 +45,7 @@
self.cfg = cfg
make_layer_norm = partial(
- nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
+ torch.nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
)
def make_block(drop_path, dim=None, heads=None):
@@ -49,7 +65,7 @@
)
self.alibi_biases = {}
- self.modality_encoders = nn.ModuleDict()
+ self.modality_encoders = torch.nn.ModuleDict()
enc = AudioEncoder(
cfg.modalities.audio,
@@ -67,17 +83,20 @@
self.loss_beta = cfg.get("loss_beta")
self.loss_scale = cfg.get("loss_scale")
- self.dropout_input = nn.Dropout(cfg.get("dropout_input"))
+ self.dropout_input = torch.nn.Dropout(cfg.get("dropout_input"))
dpr = np.linspace(cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth"))
- self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])
+ self.blocks = torch.nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])
self.norm = None
if cfg.get("layer_norm_first"):
self.norm = make_layer_norm(cfg.get("embed_dim"))
-
+ vocab_size = kwargs.get("vocab_size", -1)
+ self.proj = None
+ if vocab_size > 0:
+ self.proj = torch.nn.Linear(cfg.get("embed_dim"), vocab_size)
def forward(
@@ -173,7 +192,7 @@
)
return res
- def generate(self,
+ def inference(self,
data_in,
data_lengths=None,
key: list = None,
@@ -188,6 +207,9 @@
# assert sr == 16e3, "Sample rate should be 16kHz, but got {}in file {}".format(sr, source_file)
# assert channel == 1, "Channel should be 1, but got {} in file {}".format(channel, source_file)
granularity = kwargs.get("granularity", "utterance")
+ extract_embedding = kwargs.get("extract_embedding", True)
+ if self.proj is None:
+ extract_embedding = True
meta_data = {}
# extract fbank feats
time1 = time.perf_counter()
@@ -195,7 +217,12 @@
data_type=kwargs.get("data_type", "sound"), tokenizer=tokenizer)
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ meta_data["batch_data_time"] = len(audio_sample_list[0])/kwargs.get("fs", 16000)
+
results = []
+ output_dir = kwargs.get("output_dir")
+ if output_dir:
+ os.makedirs(output_dir, exist_ok=True)
for i, wav in enumerate(audio_sample_list):
source = wav.to(device=kwargs["device"])
if self.cfg.normalize:
@@ -203,13 +230,28 @@
source = source.view(1, -1)
feats = self.extract_features(source, padding_mask=None)
+ x = feats['x']
feats = feats['x'].squeeze(0).cpu().numpy()
if granularity == 'frame':
feats = feats
elif granularity == 'utterance':
feats = np.mean(feats, axis=0)
-
- result_i = {"key": key[i], "feats": feats}
+
+ if output_dir and extract_embedding:
+ np.save(os.path.join(output_dir, "{}.npy".format(key[i])), feats)
+
+ labels = tokenizer.token_list if tokenizer is not None else []
+ scores = []
+ if self.proj:
+ x = x.mean(dim=1)
+ x = self.proj(x)
+ x = torch.softmax(x, dim=-1)
+ scores = x[0].tolist()
+
+ result_i = {"key": key[i], "labels": labels, "scores": scores}
+ if extract_embedding:
+ result_i["feats"] = feats
results.append(result_i)
+
return results, meta_data
\ No newline at end of file
--
Gitblit v1.9.1
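
Usage sketch for the patched inference() path, assuming FunASR's AutoModel wrapper dispatches to it; the model id "iic/emotion2vec_base_finetuned" and the wav path are placeholders, not taken from this patch:

    from funasr import AutoModel

    # Placeholder checkpoint; any emotion2vec model with a classification head
    # exercises both the "labels"/"scores" and the "feats" outputs.
    model = AutoModel(model="iic/emotion2vec_base_finetuned")

    # granularity, extract_embedding and output_dir are the kwargs read by the
    # patched inference(): utterance-level embeddings are saved as <key>.npy
    # when output_dir is set, and labels/scores come from the optional proj head.
    res = model.generate(
        "test.wav",
        granularity="utterance",
        extract_embedding=True,
        output_dir="./outputs",
    )
    print(res[0]["labels"], res[0]["scores"])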