shixian.shi
2024-01-15 55c09aeaa25b4bb88a50e09ba68fa6ff00a6d676
funasr/models/emotion2vec/model.py
@@ -1,27 +1,43 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
# Modified from https://github.com/ddlBoJack/emotion2vec/tree/main
import logging
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.emotion2vec.modules import AltBlock
from funasr.models.emotion2vec.audio import AudioEncoder
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from omegaconf import OmegaConf
import os
import time
logger = logging.getLogger(__name__)
import torch
import logging
import numpy as np
from functools import partial
from omegaconf import OmegaConf
import torch.nn.functional as F
from contextlib import contextmanager
from distutils.version import LooseVersion
from funasr.register import tables
from funasr.models.emotion2vec.modules import AltBlock
from funasr.models.emotion2vec.audio import AudioEncoder
from funasr.utils.load_utils import load_audio_text_image_video
logger = logging.getLogger(__name__)
def _torch_at_least(minimum):
    """Return True if torch.__version__ is >= `minimum` (a tuple of ints).

    Parses the version manually because distutils.version.LooseVersion is
    deprecated (PEP 632) and distutils was removed from the stdlib in
    Python 3.12. Local/build suffixes such as "+cu118" and any trailing
    non-numeric components (e.g. "dev", "rc1") are ignored.
    """
    numeric = []
    for part in torch.__version__.split("+")[0].split("."):
        if not part.isdigit():
            break
        numeric.append(int(part))
    return tuple(numeric) >= minimum


# torch.cuda.amp.autocast exists only from torch 1.6.0 onward; on older
# versions bind a no-op stand-in so `with autocast(...):` always works.
if _torch_at_least((1, 6)):
    from torch.cuda.amp import autocast
else:
    @contextmanager
    def autocast(enabled=True):
        # Fallback: mixed precision unavailable on torch<1.6 — do nothing.
        yield
@tables.register("model_classes", "Emotion2vec")
class Emotion2vec(nn.Module):
class Emotion2vec(torch.nn.Module):
    """
    Author: Ziyang Ma, Zhisheng Zheng, Jiaxin Ye, Jinchao Li, Zhifu Gao, Shiliang Zhang, Xie Chen
    emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation
    https://arxiv.org/abs/2312.15185
    """
    def __init__(self, **kwargs):
        super().__init__()
        # import pdb; pdb.set_trace()
@@ -29,7 +45,7 @@
        self.cfg = cfg
        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
            torch.nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
        )
        def make_block(drop_path, dim=None, heads=None):
@@ -49,7 +65,7 @@
            )
        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()
        self.modality_encoders = torch.nn.ModuleDict()
        enc = AudioEncoder(
            cfg.modalities.audio,
@@ -67,11 +83,11 @@
        self.loss_beta = cfg.get("loss_beta")
        self.loss_scale = cfg.get("loss_scale")
        self.dropout_input = nn.Dropout(cfg.get("dropout_input"))
        self.dropout_input = torch.nn.Dropout(cfg.get("dropout_input"))
        dpr = np.linspace(cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth"))
        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])
        self.blocks = torch.nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])
        self.norm = None
        if cfg.get("layer_norm_first"):
@@ -173,7 +189,7 @@
        )
        return res
    def generate(self,
    def inference(self,
                 data_in,
                 data_lengths=None,
                 key: list = None,
@@ -196,6 +212,9 @@
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        results = []
        output_dir = kwargs.get("output_dir")
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        for i, wav in enumerate(audio_sample_list):
            source = wav.to(device=kwargs["device"])
            if self.cfg.normalize:
@@ -211,5 +230,7 @@
            
            result_i = {"key": key[i], "feats": feats}
            results.append(result_i)
            if output_dir:
                np.save(os.path.join(output_dir, "{}.npy".format(key[i])), feats)
            
        return results, meta_data