| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | # Modified from https://github.com/ddlBoJack/emotion2vec/tree/main |
| | | |
| | | import logging |
| | | from functools import partial |
| | | import numpy as np |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | import torch.nn.functional as F |
| | | |
| | | |
| | | from funasr.models.emotion2vec.modules import AltBlock |
| | | from funasr.models.emotion2vec.audio import AudioEncoder |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | |
| | | from omegaconf import OmegaConf |
| | | import os |
| | | import time |
| | | |
| | | logger = logging.getLogger(__name__) |
| | | import torch |
| | | import logging |
| | | import numpy as np |
| | | from functools import partial |
| | | from omegaconf import OmegaConf |
| | | import torch.nn.functional as F |
| | | from contextlib import contextmanager |
| | | from distutils.version import LooseVersion |
| | | |
| | | from funasr.register import tables |
| | | from funasr.models.emotion2vec.modules import AltBlock |
| | | from funasr.models.emotion2vec.audio import AudioEncoder |
| | | from funasr.utils.load_utils import load_audio_text_image_video |
| | | |
| | | |
| | | logger = logging.getLogger(__name__) |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | | else: |
| | | # Nothing to do if torch<1.6.0 |
| | | @contextmanager |
| | | def autocast(enabled=True): |
| | | yield |
| | | |
| | | |
@tables.register("model_classes", "Emotion2vec")
# NOTE(review): the block below looks like an unresolved diff/merge render —
# the class statement appears twice (once extending `nn.Module`, once
# `torch.nn.Module`), several statements are duplicated in both spellings,
# and multiple chunks of code are missing from this excerpt. Exactly one
# variant of each duplicated pair should be kept when resolving; confirm
# the missing pieces against the upstream FunASR / emotion2vec sources.
class Emotion2vec(nn.Module):

class Emotion2vec(torch.nn.Module):
    """
    Author: Ziyang Ma, Zhisheng Zheng, Jiaxin Ye, Jinchao Li, Zhifu Gao, Shiliang Zhang, Xie Chen
    emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation
    https://arxiv.org/abs/2312.15185
    """

    def __init__(self, **kwargs):
        # Build the emotion2vec encoder stack from configuration.
        super().__init__()
        # import pdb; pdb.set_trace()

        # NOTE(review): `cfg` is referenced below but never assigned in the
        # visible code — presumably constructed from `kwargs` (e.g. via
        # OmegaConf, which is imported above) in lines elided from this
        # excerpt; confirm before relying on this.
        self.cfg = cfg

        # Factory for LayerNorm layers that all share the configured
        # eps / elementwise-affine settings.
        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
            torch.nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
        )

        def make_block(drop_path, dim=None, heads=None):
            # NOTE(review): the body of make_block is missing from this
            # excerpt — judging by the AltBlock import above it presumably
            # constructs an AltBlock; the stray `)` below is the tail of
            # that elided call.

        )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()
        self.modality_encoders = torch.nn.ModuleDict()

        # Audio front-end encoder, configured from the `modalities.audio`
        # section of the config.
        enc = AudioEncoder(
            cfg.modalities.audio,
            # NOTE(review): the remainder of the AudioEncoder(...) call and
            # some following code are missing from this excerpt.

        self.loss_beta = cfg.get("loss_beta")
        self.loss_scale = cfg.get("loss_scale")

        self.dropout_input = nn.Dropout(cfg.get("dropout_input"))
        self.dropout_input = torch.nn.Dropout(cfg.get("dropout_input"))

        # Stochastic-depth (drop-path) rates increase linearly from
        # start_drop_path_rate to end_drop_path_rate across the `depth` blocks.
        dpr = np.linspace(cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth"))

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])
        self.blocks = torch.nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])

        self.norm = None
        if cfg.get("layer_norm_first"):
            # NOTE(review): the branch body and the rest of __init__ are
            # missing; the stray `)` / `return res` below belong to a later,
            # partially elided method.

        )
        return res

    def generate(self,
    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
        # NOTE(review): the remainder of the signature and the start of the
        # method body (including the definitions of `time1`, `meta_data`,
        # `audio_sample_list`, `kwargs`, `feats`) are missing from this
        # excerpt.

        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        results = []
        output_dir = kwargs.get("output_dir")
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        # Extract features for each waveform; optionally dump each result to
        # `<output_dir>/<key>.npy`.
        for i, wav in enumerate(audio_sample_list):
            source = wav.to(device=kwargs["device"])
            if self.cfg.normalize:
                # NOTE(review): the normalization / feature-extraction code
                # that produces `feats` is missing from this excerpt.

            result_i = {"key": key[i], "feats": feats}
            results.append(result_i)
            if output_dir:
                np.save(os.path.join(output_dir, "{}.npy".format(key[i])), feats)

        return results, meta_data