游雁
2024-02-19 94de39dde2e616a01683c518023d0fab72b4e103
funasr/models/emotion2vec/model.py
@@ -1,27 +1,43 @@
 #!/usr/bin/env python3
 # -*- encoding: utf-8 -*-
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 # Modified from https://github.com/ddlBoJack/emotion2vec/tree/main
-import logging
-from functools import partial
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from funasr.models.emotion2vec.modules import AltBlock
-from funasr.models.emotion2vec.audio import AudioEncoder
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from omegaconf import OmegaConf
-import os
-import time
-logger = logging.getLogger(__name__)
+import os
+import time
+import torch
+import logging
+import numpy as np
+from functools import partial
+from omegaconf import OmegaConf
+import torch.nn.functional as F
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from funasr.register import tables
+from funasr.models.emotion2vec.modules import AltBlock
+from funasr.models.emotion2vec.audio import AudioEncoder
+from funasr.utils.load_utils import load_audio_text_image_video
+logger = logging.getLogger(__name__)
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
-class Emotion2vec(nn.Module):
+@tables.register("model_classes", "Emotion2vec")
+class Emotion2vec(torch.nn.Module):
    """
    Author: Ziyang Ma, Zhisheng Zheng, Jiaxin Ye, Jinchao Li, Zhifu Gao, Shiliang Zhang, Xie Chen
    emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation
    https://arxiv.org/abs/2312.15185
    """
    def __init__(self, **kwargs):
        super().__init__()
        # import pdb; pdb.set_trace()
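Note on the version gate added in this hunk: it keeps imports working on older torch and substitutes a no-op autocast below 1.6.0. Since distutils is deprecated (removed in Python 3.12), the same gate can be written with packaging; this is an alternative sketch, not what the commit does:

    from contextlib import contextmanager
    from packaging.version import Version
    import torch

    if Version(torch.__version__) >= Version("1.6.0"):
        from torch.cuda.amp import autocast
    else:
        @contextmanager
        def autocast(enabled=True):
            yield  # no-op stand-in when native AMP is unavailable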
@@ -29,7 +45,7 @@
         self.cfg = cfg
         make_layer_norm = partial(
-            nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
+            torch.nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
         )
         def make_block(drop_path, dim=None, heads=None):
@@ -49,7 +65,7 @@
             )
         self.alibi_biases = {}
-        self.modality_encoders = nn.ModuleDict()
+        self.modality_encoders = torch.nn.ModuleDict()
         enc = AudioEncoder(
             cfg.modalities.audio,
@@ -67,17 +83,20 @@
         self.loss_beta = cfg.get("loss_beta")
         self.loss_scale = cfg.get("loss_scale")
-        self.dropout_input = nn.Dropout(cfg.get("dropout_input"))
+        self.dropout_input = torch.nn.Dropout(cfg.get("dropout_input"))
         dpr = np.linspace(cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth"))
-        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])
+        self.blocks = torch.nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])
         self.norm = None
         if cfg.get("layer_norm_first"):
             self.norm = make_layer_norm(cfg.get("embed_dim"))
+        vocab_size = kwargs.get("vocab_size", -1)
+        self.proj = None
+        if vocab_size > 0:
+            self.proj = torch.nn.Linear(cfg.get("embed_dim"), vocab_size)
     def forward(
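The hunk above adds an optional classification head: when the config carries a vocab_size, pooled features are projected from embed_dim to one logit per emotion label. A minimal sketch of that path, with illustrative shapes and sizes that are not from this commit:

    import torch

    embed_dim, vocab_size = 768, 9                 # 9 labels is a made-up example
    proj = torch.nn.Linear(embed_dim, vocab_size)
    x = torch.randn(1, 120, embed_dim)             # (batch, frames, dim), as extract_features returns
    scores = torch.softmax(proj(x.mean(dim=1)), dim=-1)  # pool over frames, then per-label probabilities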
@@ -173,7 +192,7 @@
         )
         return res
-    def generate(self,
+    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
@@ -188,6 +207,9 @@
         #     assert sr == 16e3, "Sample rate should be 16kHz, but got {} in file {}".format(sr, source_file)
         #     assert channel == 1, "Channel should be 1, but got {} in file {}".format(channel, source_file)
         granularity = kwargs.get("granularity", "utterance")
+        extract_embedding = kwargs.get("extract_embedding", True)
+        if self.proj is None:
+            extract_embedding = True
         meta_data = {}
         # extract fbank feats
         time1 = time.perf_counter()
@@ -195,7 +217,12 @@
                                                         data_type=kwargs.get("data_type", "sound"), tokenizer=tokenizer)
         time2 = time.perf_counter()
         meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        meta_data["batch_data_time"] = len(audio_sample_list[0])/kwargs.get("fs", 16000)
         results = []
+        output_dir = kwargs.get("output_dir")
+        if output_dir:
+            os.makedirs(output_dir, exist_ok=True)
         for i, wav in enumerate(audio_sample_list):
             source = wav.to(device=kwargs["device"])
             if self.cfg.normalize:
@@ -203,13 +230,28 @@
             source = source.view(1, -1)
             feats = self.extract_features(source, padding_mask=None)
+            x = feats['x']
             feats = feats['x'].squeeze(0).cpu().numpy()
             if granularity == 'frame':
                 feats = feats
             elif granularity == 'utterance':
                 feats = np.mean(feats, axis=0)
-            result_i = {"key": key[i], "feats": feats}
+            if output_dir and extract_embedding:
+                np.save(os.path.join(output_dir, "{}.npy".format(key[i])), feats)
+            labels = tokenizer.token_list if tokenizer is not None else []
+            scores = []
+            if self.proj:
+                x = x.mean(dim=1)
+                x = self.proj(x)
+                x = torch.softmax(x, dim=-1)
+                scores = x[0].tolist()
+            result_i = {"key": key[i], "labels": labels, "scores": scores}
+            if extract_embedding:
+                result_i["feats"] = feats
             results.append(result_i)

         return results, meta_data
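For reference, a hedged sketch of how this entrypoint is reached: FunASR's AutoModel exposes generate() and dispatches to the model's inference(), which is what the rename in this commit lines up with. The model id and wav path below are illustrative:

    from funasr import AutoModel

    model = AutoModel(model="iic/emotion2vec_base_finetuned")
    res = model.generate(
        "example.wav",
        granularity="utterance",   # or "frame" for per-frame features
        extract_embedding=True,    # include the pooled embedding as res[0]["feats"]
        output_dir="./outputs",    # optional: also saves "{key}.npy" per utterance
    )
    print(res[0]["labels"], res[0]["scores"])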