From fb176404cfeb40c053f4f42d01eb45c185d21ce2 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 8 Jan 2024 16:20:45 +0800
Subject: [PATCH] funasr1.0 emotion2vec

---
 funasr/register.py                                        |   11 
 funasr/download/download_from_hub.py                      |   61 +
 funasr/models/emotion2vec/modules.py                      |  323 +++++++++
 funasr/models/emotion2vec/template.yaml                   |  113 +++
 funasr/models/emotion2vec/__init__.py                     |    0 
 funasr/train_utils/load_pretrained_model.py               |    1 
 funasr/bin/inference.py                                   |    3 
 funasr/models/emotion2vec/fairseq_modules.py              |  306 +++++++++
 examples/industrial_data_pretraining/emotion2vec/demo.py  |   11 
 funasr/models/emotion2vec/audio.py                        |  167 +++++
 funasr/models/emotion2vec/base.py                         |  639 +++++++++++++++++++
 examples/industrial_data_pretraining/emotion2vec/infer.sh |   14 
 funasr/models/emotion2vec/model.py                        |  215 ++++++
 funasr/models/emotion2vec/timm_modules.py                 |  100 +++
 14 files changed, 1936 insertions(+), 28 deletions(-)

diff --git a/examples/industrial_data_pretraining/emotion2vec/demo.py b/examples/industrial_data_pretraining/emotion2vec/demo.py
new file mode 100644
index 0000000..b267e2b
--- /dev/null
+++ b/examples/industrial_data_pretraining/emotion2vec/demo.py
@@ -0,0 +1,11 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)

"""Minimal emotion2vec inference example.

Usage:
    python demo.py [wav_path]

The model id is resolved (downloaded and cached) by AutoModel; pass a local
model directory instead if you have already cloned the model.
"""

import sys

from funasr import AutoModel

if __name__ == "__main__":
    # Public ModelScope id instead of a developer-local absolute path, so the
    # example runs on any machine.
    model = AutoModel(model="damo/emotion2vec_base")

    # First CLI argument overrides the default test utterance (the model repo
    # ships one under example/test.wav).
    wav_path = sys.argv[1] if len(sys.argv) > 1 else "example/test.wav"
    res = model(input=wav_path)
    print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/emotion2vec/infer.sh b/examples/industrial_data_pretraining/emotion2vec/infer.sh
new file mode 100644
index 0000000..99600ca
--- /dev/null
+++ b/examples/industrial_data_pretraining/emotion2vec/infer.sh
@@ -0,0 +1,14 @@
+
+# download model
+local_path_root=../modelscope_models
+mkdir -p ${local_path_root}
+local_path=${local_path_root}/emotion2vec_base
+git clone https://www.modelscope.cn/damo/emotion2vec_base.git ${local_path}
+#local_path=/Users/zhifu/Downloads/modelscope_models/emotion2vec_base
+
+python funasr/bin/inference.py \
++model="${local_path}" \
++input="${local_path}/example/test.wav" \
++output_dir="./outputs/debug" \
++device="cpu" \
++debug=true
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index 1676c30..5b58907 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -222,7 +222,8 @@
 				batch["data_lengths"] = input_len
 		
 			time1 = time.perf_counter()
-			results, meta_data = model.generate(**batch, **kwargs)
+			with torch.no_grad():
+				results, meta_data = model.generate(**batch, **kwargs)
 			time2 = time.perf_counter()
 			
 			asr_result_list.extend(results)
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 4f05b42..abf3ba0 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -1,3 +1,4 @@
+import json
 import os
 from omegaconf import OmegaConf
 import torch
@@ -19,23 +20,34 @@
 		model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"))
 	
 	config = os.path.join(model_or_path, "config.yaml")
-	assert os.path.exists(config), "{} is not exist!".format(config)
-	cfg = OmegaConf.load(config)
-	kwargs = OmegaConf.merge(cfg, kwargs)
-	init_param = os.path.join(model_or_path, "model.pb")
-	kwargs["init_param"] = init_param
-	if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
-		kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
-	if os.path.exists(os.path.join(model_or_path, "tokens.json")):
-		kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
-	if os.path.exists(os.path.join(model_or_path, "seg_dict")):
-		kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
-	if os.path.exists(os.path.join(model_or_path, "bpe.model")):
-		kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
-	kwargs["model"] = cfg["model"]
-	if os.path.exists(os.path.join(model_or_path, "am.mvn")):
-		kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
-	
+	if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
+		# config = os.path.join(model_or_path, "config.yaml")
+		# assert os.path.exists(config), "{} is not exist!".format(config)
+		cfg = OmegaConf.load(config)
+		kwargs = OmegaConf.merge(cfg, kwargs)
+		init_param = os.path.join(model_or_path, "model.pb")
+		kwargs["init_param"] = init_param
+		if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
+			kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
+		if os.path.exists(os.path.join(model_or_path, "tokens.json")):
+			kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
+		if os.path.exists(os.path.join(model_or_path, "seg_dict")):
+			kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
+		if os.path.exists(os.path.join(model_or_path, "bpe.model")):
+			kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
+		kwargs["model"] = cfg["model"]
+		if os.path.exists(os.path.join(model_or_path, "am.mvn")):
+			kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+	else:# configuration.json
+		assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
+		with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
+			conf_json = json.load(f)
+			config = os.path.join(model_or_path, conf_json["model"]["model_config"])
+			cfg = OmegaConf.load(config)
+			kwargs = OmegaConf.merge(cfg, kwargs)
+			init_param = os.path.join(model_or_path, conf_json["model"]["model_name"])
+			kwargs["init_param"] = init_param
+		kwargs["model"] = cfg["model"]
 	return OmegaConf.to_container(kwargs, resolve=True)
 
 def get_or_download_model_dir(
@@ -60,12 +72,15 @@
 	if os.path.exists(model):
 		model_cache_dir = model if os.path.isdir(
 			model) else os.path.dirname(model)
-		check_local_model_is_latest(
-			model_cache_dir,
-			user_agent={
-				Invoke.KEY: key,
-				ThirdParty.KEY: "funasr"
-			})
+		try:
+			check_local_model_is_latest(
+				model_cache_dir,
+				user_agent={
+					Invoke.KEY: key,
+					ThirdParty.KEY: "funasr"
+				})
+		except:
+			print("could not check the latest version")
 	else:
 		model_cache_dir = snapshot_download(
 			model,
diff --git a/funasr/models/emotion2vec/__init__.py b/funasr/models/emotion2vec/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/emotion2vec/__init__.py
diff --git a/funasr/models/emotion2vec/audio.py b/funasr/models/emotion2vec/audio.py
new file mode 100644
index 0000000..316d372
--- /dev/null
+++ b/funasr/models/emotion2vec/audio.py
@@ -0,0 +1,167 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Tuple
+from functools import partial
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from typing import Callable, Dict, Optional
+from funasr.models.emotion2vec.fairseq_modules import (
+    LayerNorm,
+    SamePad,
+    TransposeLast,
+    ConvFeatureExtractionModel,
+)
+
+from funasr.models.emotion2vec.base import ModalitySpecificEncoder, get_alibi_bias
+from funasr.models.emotion2vec.modules import Modality, BlockEncoder, Decoder1d
+
+
+
+
class AudioEncoder(ModalitySpecificEncoder):
    """Audio-modality encoder (data2vec 2.0 / emotion2vec).

    Builds the audio-specific sub-modules — a wav2vec2-style stack of 1-D
    convolutions as the local feature extractor, a depthwise-conv relative
    positional encoder, and a prenet of transformer blocks — and hands them
    to ModalitySpecificEncoder, which wires the full pipeline together.
    """

    def __init__(
        self,
        modality_cfg,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
    ):

        # NOTE(review): eval() of a config-supplied string. Spec strings such
        # as "[(512, 10, 5)] + [(512, 3, 2)] * 4" are expressions, not plain
        # literals, so ast.literal_eval cannot replace it — configs must come
        # from trusted sources.
        self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
        # each entry is (dim, kernel, stride); last dim is the conv output size
        feature_embed_dim = self.feature_enc_layers[-1][0]

        local_encoder = ConvFeatureExtractionModel(
            conv_layers=self.feature_enc_layers,
            dropout=0.0,
            mode=modality_cfg.extractor_mode,
            conv_bias=False,
        )

        # conv output (B, C, T) -> (B, T, C), normalize, then project to embed_dim
        project_features = nn.Sequential(
            TransposeLast(),
            nn.LayerNorm(feature_embed_dim),
            nn.Linear(feature_embed_dim, embed_dim),
        )

        num_pos_layers = modality_cfg.conv_pos_depth
        k = max(3, modality_cfg.conv_pos_width // num_pos_layers)

        # depthwise-grouped conv stack providing relative positional information
        positional_encoder = nn.Sequential(
            TransposeLast(),
            *[
                nn.Sequential(
                    nn.Conv1d(
                        embed_dim,
                        embed_dim,
                        kernel_size=k,
                        padding=k // 2,
                        groups=modality_cfg.conv_pos_groups,
                    ),
                    SamePad(k),
                    TransposeLast(),
                    LayerNorm(embed_dim, elementwise_affine=False),
                    TransposeLast(),
                    nn.GELU(),
                )
                for _ in range(num_pos_layers)
            ],
            TransposeLast(),
        )

        if modality_cfg.conv_pos_pre_ln:
            positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)

        # per-block drop-path rates, linearly spaced across the prenet depth
        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )
        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            norm_layer(embed_dim) if not layer_norm_first else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout,
        )

        decoder = (
            Decoder1d(modality_cfg.decoder, embed_dim)
            if modality_cfg.decoder is not None
            else None
        )

        # cache of alibi biases shared across encoders via `alibi_biases`
        alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=project_features,
            fixed_positional_encoder=None,
            relative_positional_encoder=positional_encoder,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def convert_padding_mask(self, x, padding_mask):
        """Downsample a sample-level padding mask to conv-feature frames.

        Returns a boolean B x T' mask aligned with the extractor output,
        where True marks padded frames.
        """
        def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
            """
            Computes the output length of the convolutional layers
            """

            def _conv_out_length(input_length, kernel_size, stride):
                return torch.floor((input_length - kernel_size) / stride + 1)

            for i in range(len(self.feature_enc_layers)):
                input_lengths = _conv_out_length(
                    input_lengths,
                    self.feature_enc_layers[i][1],
                    self.feature_enc_layers[i][2],
                )

            return input_lengths.to(torch.long)

        if padding_mask is not None:
            # number of non-padded samples per batch element
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply conv formula to get real output_lengths
            output_lengths = get_feat_extract_output_lengths(input_lengths)

            if padding_mask.any():
                padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)

                # these two operations makes sure that all values
                # before the output lengths indices are attended to
                padding_mask[
                    (
                        torch.arange(padding_mask.shape[0], device=padding_mask.device),
                        output_lengths - 1,
                    )
                ] = 1
                padding_mask = (
                    1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
                ).bool()
            else:
                padding_mask = torch.zeros(
                    x.shape[:2], dtype=torch.bool, device=x.device
                )

        return padding_mask

    def reset_parameters(self):
        """Re-initialize the projection and (if present) decoder weights."""
        super().reset_parameters()
        for mod in self.project_features.children():
            if isinstance(mod, nn.Linear):
                mod.reset_parameters()
        if self.decoder is not None:
            self.decoder.reset_parameters()
diff --git a/funasr/models/emotion2vec/base.py b/funasr/models/emotion2vec/base.py
new file mode 100644
index 0000000..cd87a99
--- /dev/null
+++ b/funasr/models/emotion2vec/base.py
@@ -0,0 +1,639 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from collections import namedtuple
+from dataclasses import dataclass
+from functools import partial
+from omegaconf import MISSING, II
+from typing import Optional, Callable
+from funasr.models.emotion2vec.fairseq_modules import compute_mask_indices
+from funasr.models.emotion2vec.fairseq_modules import GradMultiply
+from funasr.models.emotion2vec.fairseq_modules import index_put
+
+
+logger = logging.getLogger(__name__)
+
+
+
+MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
+MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
+
+
class ModalitySpecificEncoder(nn.Module):
    """Base encoder for one input modality (data2vec 2.0 style).

    Pipeline: local feature extractor -> feature projection -> (fixed and/or
    relative) positional encoding -> optional masking -> prenet transformer
    blocks (`context_encoder`), with optional ALiBi attention biases and a
    lightweight decoder used during pretraining.
    """

    def __init__(
        self,
        modality_cfg,
        embed_dim: int,
        local_encoder: nn.Module,
        project_features: nn.Module,
        fixed_positional_encoder: Optional[nn.Module],
        relative_positional_encoder: Optional[nn.Module],
        context_encoder: nn.Module,
        decoder: nn.Module,
        get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]],
    ):
        super().__init__()

        self.modality_cfg = modality_cfg
        self.local_encoder = local_encoder
        self.project_features = project_features
        self.fixed_positional_encoder = fixed_positional_encoder
        self.relative_positional_encoder = relative_positional_encoder
        self.context_encoder = context_encoder

        self.decoder = decoder
        self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None

        # gradient multiplier into the local encoder (0 freezes it)
        self.local_grad_mult = self.modality_cfg.local_grad_mult

        # optional learned tokens (e.g. CLS-style) prepended to the sequence
        self.extra_tokens = None
        if modality_cfg.num_extra_tokens > 0:
            self.extra_tokens = nn.Parameter(
                torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim)
            )
            if not modality_cfg.init_extra_token_zero:
                nn.init.normal_(self.extra_tokens)
            elif self.extra_tokens.size(1) > 1:
                # keep the first extra token at zero, randomize the rest
                nn.init.normal_(self.extra_tokens[:, 1:])

        # optional learnable per-layer / per-head scaling of the alibi bias
        self.alibi_scale = None
        if self.get_alibi_bias is not None:
            self.alibi_scale = nn.Parameter(
                torch.full(
                    (
                        (modality_cfg.prenet_depth + modality_cfg.model_depth)
                        if modality_cfg.learned_alibi_scale_per_layer
                        else 1,
                        1,
                        self.modality_cfg.num_alibi_heads
                        if modality_cfg.learned_alibi_scale_per_head
                        else 1,
                        1,
                        1,
                    ),
                    modality_cfg.alibi_scale,
                    dtype=torch.float,
                ),
                requires_grad=modality_cfg.learned_alibi_scale,
            )

        if modality_cfg.learned_alibi and self.get_alibi_bias is not None:
            assert modality_cfg.alibi_max_pos is not None
            # NOTE(review): get_alibi_bias() as defined in this file takes no
            # `scale` argument, so this branch looks untested — confirm before
            # enabling `learned_alibi`.
            alibi_bias = self.get_alibi_bias(
                batch_size=1,
                time_steps=modality_cfg.alibi_max_pos,
                heads=modality_cfg.num_alibi_heads,
                scale=1.0,
                dtype=torch.float,
                device="cpu",
            )
            self.alibi_bias = nn.Parameter(alibi_bias)
            # from here on, biases are cropped/scaled views of the learned parameter
            self.get_alibi_bias = partial(
                _learned_alibi_bias, alibi_bias=self.alibi_bias
            )

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old checkpoints: alibi_scale gained a leading (layer) dim."""
        k = f"{name}.alibi_scale"
        if k in state_dict and state_dict[k].dim() == 4:
            state_dict[k] = state_dict[k].unsqueeze(0)

        return state_dict

    def convert_padding_mask(self, x, padding_mask):
        """Hook for subclasses: map an input-level padding mask to feature frames."""
        return padding_mask

    def decoder_input(self, x, mask_info: MaskInfo):
        """Build the decoder input: strip extra tokens, insert noise tokens at
        masked positions, and unshuffle back to the original time order."""
        inp_drop = self.modality_cfg.decoder.input_dropout
        if inp_drop > 0:
            x = F.dropout(x, inp_drop, training=self.training, inplace=True)

        num_extra = self.modality_cfg.num_extra_tokens

        if mask_info is not None:
            num_masked = mask_info.ids_restore.shape[1] - x.shape[1] + num_extra

            # Gaussian noise tokens stand in for the masked positions
            mask_tokens = x.new_empty(
                x.size(0),
                num_masked,
                x.size(-1),
            ).normal_(0, self.modality_cfg.mask_noise_std)

            x_ = torch.cat([x[:, num_extra:], mask_tokens], dim=1)
            # ids_restore maps the shuffled [kept | noise] layout back to time order
            x = torch.gather(x_, dim=1, index=mask_info.ids_restore)

            if self.modality_cfg.decoder.add_positions_masked:
                assert self.fixed_positional_encoder is not None
                pos = self.fixed_positional_encoder(x, None)
                x = x + (pos * mask_info.mask.unsqueeze(-1))
        else:
            x = x[:, num_extra:]

        if self.modality_cfg.decoder.add_positions_all:
            assert self.fixed_positional_encoder is not None
            x = x + self.fixed_positional_encoder(x, None)

        return x, mask_info

    def local_features(self, features):
        """Run the local extractor (with scaled or no gradient) and project
        the result to the shared embedding dimension."""
        if self.local_grad_mult > 0:
            if self.local_grad_mult == 1.0:
                x = self.local_encoder(features)
            else:
                x = GradMultiply.apply(
                    self.local_encoder(features), self.local_grad_mult
                )
        else:
            with torch.no_grad():
                x = self.local_encoder(features)

        x = self.project_features(x)
        return x

    def contextualized_features(
        self,
        x,
        padding_mask,
        mask,
        remove_masked,
        clone_batch: int = 1,
        mask_seeds: Optional[torch.Tensor] = None,
        precomputed_mask=None,
    ):
        """Add positions, optionally mask (and optionally drop masked steps),
        apply alibi biases, and run the prenet transformer blocks.

        Returns a dict with the encoded `x`, pre-masking `local_features`,
        the (possibly converted) padding mask, alibi tensors for the main
        trunk, and the `encoder_mask` MaskInfo (or None).
        """

        if padding_mask is not None:
            padding_mask = self.convert_padding_mask(x, padding_mask)

        local_features = x
        if mask and clone_batch == 1:
            local_features = local_features.clone()

        orig_B, orig_T, _ = x.shape
        pre_mask_B = orig_B
        mask_info = None

        x_pos = None
        if self.fixed_positional_encoder is not None:
            x = x + self.fixed_positional_encoder(x, padding_mask)

        if mask:
            if clone_batch > 1:
                # replicate each sample so multiple masks are drawn per sample
                x = x.repeat_interleave(clone_batch, 0)
                if mask_seeds is not None:
                    # derive a distinct deterministic seed per clone
                    clone_hash = [
                        int(hash((mask_seeds.seed, ind)) % 1e10)
                        for ind in range(clone_batch - 1)
                    ]
                    clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1)

                    id = mask_seeds.ids
                    id = id.repeat_interleave(clone_batch, 0)
                    id = id.view(-1, clone_batch) + clone_hash.to(id)
                    id = id.view(-1)
                    mask_seeds = MaskSeed(
                        seed=mask_seeds.seed, update=mask_seeds.update, ids=id
                    )
                if padding_mask is not None:
                    padding_mask = padding_mask.repeat_interleave(clone_batch, 0)

            x, mask_info = self.compute_mask(
                x,
                padding_mask,
                mask_seed=mask_seeds,
                apply=self.relative_positional_encoder is not None or not remove_masked,
                precomputed_mask=precomputed_mask,
            )

        if self.relative_positional_encoder is not None:
            x_pos = self.relative_positional_encoder(x)

        masked_padding_mask = padding_mask
        if mask and remove_masked:
            # keep only unmasked timesteps; gather matching positions/masks
            x = mask_info.x_unmasked
            if x_pos is not None:
                x = x + gather_unmasked(x_pos, mask_info)

            if padding_mask is not None and padding_mask.any():
                masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info)
                if not masked_padding_mask.any():
                    masked_padding_mask = None
            else:
                masked_padding_mask = None

        elif x_pos is not None:
            x = x + x_pos

        alibi_bias = None
        alibi_scale = self.alibi_scale

        if self.get_alibi_bias is not None:
            alibi_bias = self.get_alibi_bias(
                batch_size=pre_mask_B,
                time_steps=orig_T,
                heads=self.modality_cfg.num_alibi_heads,
                dtype=torch.float32,
                device=x.device,
            )

            if alibi_scale is not None:
                alibi_scale = alibi_scale.clamp_min(0)
                if alibi_scale.size(0) == 1:
                    # single shared scale: fold it into the bias now
                    alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias)
                    alibi_scale = None

            if clone_batch > 1:
                alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0)

            if mask_info is not None and remove_masked:
                alibi_bias = masked_alibi(alibi_bias, mask_info)

        if self.extra_tokens is not None:
            num = self.extra_tokens.size(1)
            x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1)
            if masked_padding_mask is not None:
                # B x T
                masked_padding_mask = F.pad(masked_padding_mask, (num, 0))
            if alibi_bias is not None:
                # B x H x T x T
                alibi_bias = F.pad(alibi_bias, (num, 0, num, 0))

        x = self.context_encoder(
            x,
            masked_padding_mask,
            alibi_bias,
            alibi_scale[: self.modality_cfg.prenet_depth]
            if alibi_scale is not None
            else None,
        )

        return {
            "x": x,
            "local_features": local_features,
            "padding_mask": masked_padding_mask,
            "alibi_bias": alibi_bias,
            # remaining per-layer scales are handed to the main trunk
            "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
            if alibi_scale is not None and alibi_scale.size(0) > 1
            else alibi_scale,
            "encoder_mask": mask_info,
        }

    def forward(
        self,
        features,
        padding_mask,
        mask: bool,
        remove_masked: bool,
        clone_batch: int = 1,
        mask_seeds: Optional[torch.Tensor] = None,
        precomputed_mask=None,
    ):
        """Full modality encoding: local features then contextualization."""
        x = self.local_features(features)
        return self.contextualized_features(
            x,
            padding_mask,
            mask,
            remove_masked,
            clone_batch,
            mask_seeds,
            precomputed_mask,
        )

    def reset_parameters(self):
        """Hook for subclasses; base class has nothing to re-initialize."""
        pass

    def compute_mask(
        self,
        x,
        padding_mask,
        mask_seed: Optional[MaskSeed],
        apply,
        precomputed_mask,
    ):
        """Draw (or reuse) a time mask and optionally apply it to `x`.

        Returns (x, mask_info); mask_info is None when mask_prob <= 0."""
        if precomputed_mask is not None:
            mask = precomputed_mask
            mask_info = self.make_maskinfo(x, mask)
        else:
            B, T, C = x.shape
            cfg = self.modality_cfg

            mask_prob = cfg.mask_prob

            if (
                cfg.mask_prob_min is not None
                and cfg.mask_prob_min >= 0
                and cfg.mask_prob_min < mask_prob
            ):
                # sample the mask probability uniformly from [min, max)
                mask_prob = np.random.uniform(cfg.mask_prob_min, mask_prob)

            if mask_prob > 0:
                if cfg.mask_length == 1:
                    mask_info = random_masking(x, mask_prob, mask_seed)
                else:
                    if self.modality_cfg.inverse_mask:
                        mask_prob = 1 - mask_prob

                    mask = compute_mask_indices(
                        (B, T),
                        padding_mask,
                        mask_prob,
                        cfg.mask_length,
                        min_masks=1,
                        require_same_masks=True,
                        mask_dropout=cfg.mask_dropout,
                        add_masks=cfg.add_masks,
                        seed=mask_seed.seed if mask_seed is not None else None,
                        epoch=mask_seed.update if mask_seed is not None else None,
                        indices=mask_seed.ids if mask_seed is not None else None,
                    )

                    mask = torch.from_numpy(mask).to(device=x.device)
                    if self.modality_cfg.inverse_mask:
                        mask = 1 - mask
                    mask_info = self.make_maskinfo(x, mask)
            else:
                mask_info = None

        if apply:
            x = self.apply_mask(x, mask_info)

        return x, mask_info

    def make_maskinfo(self, x, mask, shape=None):
        """Convert a binary B x T mask into a MaskInfo with gather indices
        (argsort trick: zeros sort first, so kept positions lead)."""
        if shape is None:
            B, T, D = x.shape
        else:
            B, T, D = shape

        mask = mask.to(torch.uint8)
        ids_shuffle = mask.argsort(dim=1)
        ids_restore = ids_shuffle.argsort(dim=1).unsqueeze(-1).expand(-1, -1, D)

        len_keep = T - mask[0].sum()
        if self.modality_cfg.keep_masked_pct > 0:
            # also keep a fraction of the masked positions
            len_keep += round((T - int(len_keep)) * self.modality_cfg.keep_masked_pct)

        ids_keep = ids_shuffle[:, :len_keep]

        if shape is not None:
            x_unmasked = None
        else:
            ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
            x_unmasked = torch.gather(x, dim=1, index=ids_keep)

        mask_info = MaskInfo(
            x_unmasked=x_unmasked,
            mask=mask,
            ids_restore=ids_restore,
            ids_keep=ids_keep,
        )
        return mask_info

    def apply_mask(self, x, mask_info):
        """Zero or noise-fill masked timesteps; optionally mask channels too."""
        cfg = self.modality_cfg
        B, T, C = x.shape

        if mask_info is not None:
            mask = mask_info.mask
            if cfg.encoder_zero_mask:
                x = x * (1 - mask.type_as(x).unsqueeze(-1))
            else:
                num_masks = mask.sum().item()
                masks = x.new_empty(num_masks, x.size(-1)).normal_(
                    0, cfg.mask_noise_std
                )
                x = index_put(x, mask, masks)
        if cfg.mask_channel_prob > 0:
            mask_channel = compute_mask_indices(
                (B, C),
                None,
                cfg.mask_channel_prob,
                cfg.mask_channel_length,
            )
            mask_channel = (
                torch.from_numpy(mask_channel)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x = index_put(x, mask_channel, 0)
        return x

    def remove_pretraining_modules(self, keep_decoder=False):
        """Drop pretraining-only parts before fine-tuning/inference."""
        if not keep_decoder:
            self.decoder = None
+
+
def get_annealed_rate(start, end, curr_step, total_steps):
    """Linearly anneal a value from `start` to `end` over `total_steps`,
    clamping at `end` once `curr_step` reaches/passes `total_steps`."""
    if curr_step >= total_steps:
        return end
    frac_done = curr_step / total_steps
    return start + (end - start) * frac_done
+
+
# adapted from MAE
def random_masking(x, mask_ratio, mask_seed: Optional[MaskSeed]):
    """MAE-style per-sample random masking.

    Ranks each timestep by iid uniform noise; the `1 - mask_ratio` fraction
    with the lowest noise is kept. Returns a MaskInfo whose `mask` uses
    0 = keep, 1 = remove, in original time order.
    """
    batch, length, dim = x.shape
    len_keep = int(length * (1 - mask_ratio))

    generator = None
    if mask_seed is not None:
        # deterministic masking per (seed, update, ids) triple
        derived = hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item()))
        generator = torch.Generator(device=x.device)
        generator.manual_seed(int(derived % 1e6))

    noise = torch.rand(batch, length, generator=generator, device=x.device)

    # ascending sort: small noise = keep, large = remove
    ids_shuffle = noise.argsort(dim=1)
    ids_restore = ids_shuffle.argsort(dim=1)

    # gather the kept subset
    ids_keep = ids_shuffle[:, :len_keep].unsqueeze(-1).expand(-1, -1, dim)
    x_unmasked = torch.gather(x, dim=1, index=ids_keep)

    # binary mask in shuffled order, then unshuffled back to time order
    mask = torch.ones([batch, length], dtype=x.dtype, device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return MaskInfo(
        x_unmasked=x_unmasked,
        mask=mask,
        ids_restore=ids_restore.unsqueeze(-1).expand(-1, -1, dim),
        ids_keep=ids_keep,
    )
+
+
def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
    """Select the kept (unmasked) timesteps of a B x T x D tensor."""
    return torch.gather(x, dim=1, index=mask_info.ids_keep)
+
+
def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
    """Like gather_unmasked, but for a B x T mask (no feature dimension)."""
    keep = mask_info.ids_keep[..., 0]  # drop the expanded feature dim
    return torch.gather(x, dim=1, index=keep)
+
+
def get_alibi(
    max_positions: int,
    attention_heads: int,
    dims: int = 1,
    distance: str = "manhattan",
):
    """Build an ALiBi attention bias of shape (heads, T, T).

    Each head gets a geometric slope (per the ALiBi paper); the positional
    term is the negated distance between positions — symmetric with zeros on
    the diagonal, since wav2vec2-style models are non-autoregressive.

    Args:
        max_positions: sequence length T (for dims=2, must be a perfect square).
        attention_heads: number of attention heads.
        dims: 1 for sequences, 2 for flattened square grids.
        distance: "manhattan" or "euclidean" (dims=2 only).

    Raises:
        ValueError: unknown `distance` for dims=2.
        Exception: unsupported `dims`.
    """
    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        # In the paper, we only train models that have 2^a heads for some
        # a. This function has some good properties that only occur when
        # the input is a power of 2. To maintain that even when the number
        # of heads is not a power of 2, we use this workaround.
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    maxpos = max_positions
    attn_heads = attention_heads
    slopes = torch.Tensor(get_slopes(attn_heads))

    if dims == 1:
        # prepare alibi position linear bias. Note that wav2vec2 is non
        # autoregressive model so we want a symmetric mask with 0 on the
        # diagonal and other wise linear decreasing valuees
        pos_bias = (
            torch.abs(
                torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
            )
            * -1
        )
    elif dims == 2:
        if distance == "manhattan":
            df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2)
        elif distance == "euclidean":
            df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
        else:
            # bug fix: previously an unknown distance fell through and crashed
            # later with an UnboundLocalError on `df`; fail fast instead.
            raise ValueError(f"unsupported alibi distance: {distance}")

        n = math.sqrt(max_positions)
        assert n.is_integer(), n
        n = int(n)

        pos_bias = torch.zeros((max_positions, max_positions))

        # bias between grid cells (i, j) and (k, l) of the flattened n x n grid
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    for l in range(n):
                        new_x = i * n + j
                        new_y = k * n + l
                        pos_bias[new_x, new_y] = -df(i, j, k, l)

    else:
        raise Exception(f"unsupported number of alibi dims: {dims}")

    alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
        attn_heads, -1, -1
    )

    return alibi_bias


def get_alibi_bias(
    alibi_biases,
    batch_size,
    time_steps,
    heads,
    dtype,
    device,
    dims=1,
    distance="manhattan",
):
    """Return a (batch, heads, T, T) alibi bias, caching the underlying
    (heads*batch, T, T) tensor in `alibi_biases` keyed by (dims, heads,
    distance). The cache entry is regrown whenever the requested batch size,
    length, dtype, or device exceeds/differs from what is stored.
    """
    cache_key = f"{dims}_{heads}_{distance}"

    buffered = alibi_biases.get(cache_key, None)

    target_size = heads * batch_size
    if (
        buffered is None
        or buffered.size(0) < target_size
        or buffered.size(1) < time_steps
        or buffered.dtype != dtype
        or buffered.device != device
    ):
        bt = max(time_steps, buffered.size(1) if buffered is not None else 0)
        bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads

        buffered = (
            get_alibi(bt, heads, dims=dims, distance=distance)
            .to(dtype=dtype, device=device)
            .repeat(bn, 1, 1)
        )

        alibi_biases[cache_key] = buffered

    # crop the cached tensor to the requested batch/length
    b = buffered[:target_size, :time_steps, :time_steps]
    b = b.view(batch_size, heads, time_steps, time_steps)
    return b
+
+
+def _learned_alibi_bias(
+    alibi_bias,
+    batch_size,
+    time_steps,
+    heads,
+    scale,
+    dtype,
+    device,
+):
+    assert alibi_bias.size(1) == heads, alibi_bias.shape
+    assert alibi_bias.dtype == dtype, alibi_bias.dtype
+    assert alibi_bias.device == device, alibi_bias.device
+
+    if alibi_bias.size(-1) < time_steps:
+        psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2)
+        alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate")
+
+    alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale
+    return alibi_bias[..., :time_steps, :time_steps]
+
+
def masked_alibi(alibi_bias, mask_info):
    """Restrict a (B, H, T, T) alibi bias to the kept timesteps, gathering
    the same kept indices along both the query and key axes."""
    num_heads = alibi_bias.size(1)

    # (B, 1, len_keep, 1) index of kept timesteps
    keep = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)

    rows = torch.gather(
        alibi_bias,
        dim=-2,
        index=keep.expand(-1, num_heads, -1, mask_info.ids_restore.size(1)),
    )
    return torch.gather(
        rows,
        dim=-1,
        index=keep.transpose(-1, -2).expand(-1, num_heads, rows.size(-2), -1),
    )
diff --git a/funasr/models/emotion2vec/fairseq_modules.py b/funasr/models/emotion2vec/fairseq_modules.py
new file mode 100644
index 0000000..fa98885
--- /dev/null
+++ b/funasr/models/emotion2vec/fairseq_modules.py
@@ -0,0 +1,306 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple, List
+import numpy as np
+
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Factory for a standard ``torch.nn.LayerNorm``.

    The ``export`` flag is accepted for fairseq API compatibility but ignored.
    """
    return torch.nn.LayerNorm(
        normalized_shape, eps=eps, elementwise_affine=elementwise_affine
    )
+
class SamePad(nn.Module):
    """Trim trailing frames so a padded Conv1d output matches its input length.

    Even kernels with symmetric padding produce one surplus frame, which is
    removed; causal convolutions remove all ``kernel_size - 1`` trailing frames.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        self.remove = (kernel_size - 1) if causal else (kernel_size + 1) % 2

    def forward(self, x):
        # x: (batch, channels, time) — drop the surplus trailing frames, if any.
        return x[:, :, : -self.remove] if self.remove > 0 else x
+
class TransposeLast(nn.Module):
    """Swap the last two dimensions, optionally indexing into the input first.

    ``deconstruct_idx`` selects an element (e.g. from a tuple output) before
    transposing; ``None`` transposes the input directly.
    """

    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        if self.deconstruct_idx is None:
            return x.transpose(-2, -1)
        return x[self.deconstruct_idx].transpose(-2, -1)
+
+
class Fp32LayerNorm(nn.LayerNorm):
    """LayerNorm that always normalizes in float32 for numerical stability.

    The input (and affine parameters, if present) are upcast to float32,
    normalized, and the result is cast back to the input's original dtype.
    """

    def forward(self, input):
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normalized = F.layer_norm(
            input.float(), self.normalized_shape, weight, bias, self.eps
        )
        return normalized.type_as(input)
+
+
class Fp32GroupNorm(nn.GroupNorm):
    """GroupNorm that always normalizes in float32 for numerical stability.

    Mirrors :class:`Fp32LayerNorm`: inputs and affine parameters are upcast
    to float32 before normalization and the result is cast back afterwards.
    """

    def forward(self, input):
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normalized = F.group_norm(
            input.float(), self.num_groups, weight, bias, self.eps
        )
        return normalized.type_as(input)
+
+
class ConvFeatureExtractionModel(nn.Module):
    """Stack of strided 1-d convolutions turning a raw waveform into features.

    Args:
        conv_layers: list of ``(dim, kernel, stride)`` tuples, one per layer.
        dropout: dropout probability applied after each convolution.
        mode: ``"default"`` applies a single Fp32GroupNorm after the first
            layer only; ``"layer_norm"`` applies an Fp32LayerNorm (over the
            channel dim) after every layer.
        conv_bias: whether the convolutions use a bias term.
    """

    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            # One conv "block": conv -> dropout -> (optional norm) -> GELU.
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            if is_layer_norm:
                # Fix: normalize over this layer's own output width ``n_out``.
                # The original closed over the enclosing loop variable ``dim``,
                # which only worked because block() happens to be called after
                # ``dim`` is assigned (and always equals ``n_out``).
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(n_out, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                # One group per channel (i.e. instance-norm-like behavior).
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        """Run the conv stack.

        Args:
            x: waveform of shape ``(batch, time)``.

        Returns:
            Features of shape ``(batch, last_dim, time')`` with ``time'``
            reduced by the product of the strides.
        """
        # BxT -> BxCxT
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x
+
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_length: length of each masked span (exact meaning depends on mask_type)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from poisson distribution with lambda = mask length
        mask_other: secondary parameter for the uniform/normal mask types (see above)
        min_masks: minimum number of masked spans
        no_overlap: if true, uses a recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example

    Returns:
        Boolean array of shape ``(batch, timesteps)``, True at masked positions.
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    # Probabilistic rounding: adding U[0,1) before int() rounds up with
    # probability equal to the fractional part.
    all_num_mask = int(
        mask_prob * all_sz / float(mask_length)
        + np.random.rand()
    )

    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            # Only mask within the unpadded prefix of this sample.
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                mask_prob * sz / float(mask_length)
                + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        # Draw the span lengths according to the requested distribution.
        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if len(lengths) == 0:
            # Fix: nothing to mask for this sample (num_mask == 0); the
            # original crashed with an IndexError on ``lengths[0]`` below.
            mask_idcs.append(np.array([], dtype=int))
            continue

        # Guarantee at least one non-empty span.
        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # Place one span inside the free interval [s, e) and return
                # the remaining sub-intervals that can still host spans.
                span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                # Weight candidate intervals by their remaining capacity.
                # Fix: ``np.int`` was removed in NumPy 1.24; use builtin int.
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / np.sum(lens)
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                min_len = sz - num_mask - 1

            # Sample span starts, then expand each start into its full span.
            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len and require_same_masks:
            # Equalize the number of masked positions across the batch.
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        if mask_dropout > 0:
            # Randomly drop a fraction of the masked positions.
            num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
            mask_idc = np.random.choice(
                mask_idc, len(mask_idc) - num_holes, replace=False
            )

        mask[i, mask_idc] = True

    return mask
+
+
class GradMultiply(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by ``scale`` in backward.

    Typical use: ``GradMultiply.apply(x, 0.5)`` to damp gradients flowing
    into a sub-network (e.g. the feature extractor) without changing values.
    """

    @staticmethod
    def forward(ctx, x, scale):
        # Stash the scale on the context for use in backward().
        ctx.scale = scale
        # NOTE(review): ``x.new(x)`` is a legacy fairseq idiom for copying a
        # tensor on the same dtype/device; kept as-is.
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        # Scale the incoming gradient; None for the non-tensor ``scale`` input.
        return grad * ctx.scale, None
+    
+    
def is_xla_tensor(tensor):
    """Return True iff *tensor* is a torch tensor living on an XLA device."""
    if not torch.is_tensor(tensor):
        return False
    return tensor.device.type == "xla"
+
+
def index_put(tensor, indices, value):
    """Write ``value`` into ``tensor`` at the positions given by boolean ``indices``.

    On XLA devices in-place indexed assignment is unsupported, so the update
    is expressed as an out-of-place arithmetic blend; elsewhere a plain
    indexed assignment is used. Returns the (possibly new) tensor.
    """
    on_xla = torch.is_tensor(tensor) and tensor.device.type == "xla"
    if not on_xla:
        tensor[indices] = value
        return tensor

    # Broadcast the boolean mask up to the tensor's rank, then blend:
    # keep old values where the mask is False, take ``value`` where True.
    while indices.dim() < tensor.dim():
        indices = indices.unsqueeze(-1)
    if indices.size(-1) < tensor.size(-1):
        indices = indices.expand_as(tensor)
    return torch.mul(tensor, ~indices) + torch.mul(value, indices)
\ No newline at end of file
diff --git a/funasr/models/emotion2vec/model.py b/funasr/models/emotion2vec/model.py
new file mode 100644
index 0000000..e882b6e
--- /dev/null
+++ b/funasr/models/emotion2vec/model.py
@@ -0,0 +1,215 @@
+
+import logging
+from functools import partial
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+from funasr.models.emotion2vec.modules import AltBlock
+from funasr.models.emotion2vec.audio import AudioEncoder
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+
+from omegaconf import OmegaConf
+import time
+
+logger = logging.getLogger(__name__)
+
+from funasr.register import tables
+
@tables.register("model_classes", "Emotion2vec")
class Emotion2vec(nn.Module):
    """Emotion2vec: a data2vec-2.0-style self-supervised audio encoder used
    here to extract emotion representations from speech.

    Only the inference path is ported: the EMA teacher and training losses
    of the original model are absent, so :meth:`forward` is intended to be
    called with ``features_only=True`` (as :meth:`extract_features` does).
    Configuration comes from ``kwargs["model_conf"]`` and mirrors the
    fairseq data2vec 2.0 audio config (see template.yaml next to this file).
    """

    def __init__(self, **kwargs):
        super().__init__()
        cfg = OmegaConf.create(kwargs["model_conf"])
        self.cfg = cfg

        # LayerNorm factory shared by the modality encoder and all blocks.
        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
        )

        def make_block(drop_path, dim=None, heads=None):
            # Transformer-block factory; dim/heads fall back to the global config.
            return AltBlock(
                cfg.get("embed_dim") if dim is None else dim,
                cfg.get("num_heads") if heads is None else heads,
                cfg.get("mlp_ratio"),
                qkv_bias=True,
                drop=cfg.get("encoder_dropout"),
                attn_drop=cfg.get("attention_dropout"),
                mlp_drop=cfg.get("activation_dropout"),
                post_mlp_drop=cfg.get("post_mlp_drop"),
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.get("layer_norm_first"),
                ffn_targets=not cfg.get("end_of_block_targets"),
            )

        # Cache of alibi bias tensors, shared with the modality encoder.
        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()

        enc = AudioEncoder(
            cfg.modalities.audio,
            cfg.get("embed_dim"),
            make_block,
            make_layer_norm,
            cfg.get("layer_norm_first"),
            self.alibi_biases,
        )
        self.modality_encoders['AUDIO'] = enc

        # EMA teacher is only used for pre-training; inference keeps it None.
        self.ema = None

        self.average_top_k_layers = cfg.get("average_top_k_layers")
        self.loss_beta = cfg.get("loss_beta")
        self.loss_scale = cfg.get("loss_scale")

        self.dropout_input = nn.Dropout(cfg.get("dropout_input"))

        # Stochastic-depth rate grows linearly across the encoder depth.
        dpr = np.linspace(cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth"))

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])

        self.norm = None
        if cfg.get("layer_norm_first"):
            self.norm = make_layer_norm(cfg.get("embed_dim"))

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        **kwargs,
    ):
        """Encode a raw-waveform batch with the audio encoder and block stack.

        Args:
            source: waveform tensor, shape (batch, time).
            padding_mask: optional mask marking padded samples.
            mask: whether to apply feature masking (pre-training behavior).
            features_only: if True, return features and per-layer results.
            remove_extra_tokens: drop the prepended extra/CLS tokens from the
                output (their count comes from the modality config).
            precomputed_mask: externally supplied mask bypassing random masking.
            target, id, mode, kwargs: unused in this inference-only port.

        Returns:
            A dict with keys ``x``, ``padding_mask``, ``layer_results`` and
            ``mask`` when ``features_only`` is True.
            NOTE(review): when ``features_only`` is False this method falls
            through and implicitly returns ``None`` — the training losses
            were stripped from this port.
        """

        feature_extractor = self.modality_encoders['AUDIO']

        mask_seeds = None

        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.get("clone_batch") if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )

        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)
        alibi_scale = extractor_out.get("alibi_scale", None)

        if self.dropout_input is not None:
            x = self.dropout_input(x)

        layer_results = []
        for i, blk in enumerate(self.blocks):
            # LayerDrop: randomly skip blocks during training only.
            if (
                not self.training
                or self.cfg.get("layerdrop", 0) == 0
                or (np.random.random() > self.cfg.get("layerdrop", 0))
            ):
                ab = masked_alibi_bias
                if ab is not None and alibi_scale is not None:
                    # Per-layer alibi scale when one was learned per block,
                    # otherwise a single shared scalar.
                    scale = (
                        alibi_scale[i]
                        if alibi_scale.size(0) > 1
                        else alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)

                x, lr = blk(
                    x,
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                )
                if features_only:
                    layer_results.append(lr)

        if self.norm is not None:
            x = self.norm(x)

        if features_only:
            if remove_extra_tokens:
                # Strip the prepended extra/CLS tokens (and their padding).
                x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
                if masked_padding_mask is not None:
                    masked_padding_mask = masked_padding_mask[
                        :, feature_extractor.modality_cfg.num_extra_tokens :
                    ]

            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
    ):
        """Convenience wrapper: run :meth:`forward` with ``features_only=True``
        and (by default) no masking, returning the feature dict."""
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return res

    def generate(self,
                 data_in,
                 data_lengths=None,
                 key: list = None,
                 tokenizer=None,
                 frontend=None,
                 **kwargs,
                 ):
        """Inference entry point used by FunASR's AutoModel.

        Args:
            data_in: audio input(s) accepted by ``load_audio_text_image_video``
                (e.g. wav paths); resampled/loaded at 16 kHz.
            data_lengths: unused; kept for the generate() interface.
            key: list of utterance ids, parallel to ``data_in``.
            tokenizer, frontend: unused here; kept for interface compatibility.
            kwargs: must contain ``device``; optional ``granularity``
                ("utterance" | "frame"), ``fs`` and ``data_type``.

        Returns:
            ``(results, meta_data)`` where each result is ``{"key", "feats"}``
            with ``feats`` a numpy array — per-frame features for "frame"
            granularity, or their time-average for "utterance".
        """
        granularity = kwargs.get("granularity", "utterance")
        meta_data = {}
        # load raw waveforms (the model consumes 16 kHz audio directly)
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000),
                                                        data_type=kwargs.get("data_type", "sound"), tokenizer=tokenizer)
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        results = []
        for i, wav in enumerate(audio_sample_list):
            source = wav.to(device=kwargs["device"])
            if self.cfg.normalize:
                # utterance-level normalization of the raw waveform
                source = F.layer_norm(source, source.shape)
            source = source.view(1, -1)

            feats = self.extract_features(source, padding_mask=None)
            feats = feats['x'].squeeze(0).cpu().numpy()
            if granularity == 'frame':
                feats = feats  # keep per-frame features as-is
            elif granularity == 'utterance':
                feats = np.mean(feats, axis=0)

            result_i = {"key": key[i], "feats": feats}
            results.append(result_i)

        return results, meta_data
\ No newline at end of file
diff --git a/funasr/models/emotion2vec/modules.py b/funasr/models/emotion2vec/modules.py
new file mode 100644
index 0000000..33947f2
--- /dev/null
+++ b/funasr/models/emotion2vec/modules.py
@@ -0,0 +1,323 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from dataclasses import dataclass
+from funasr.models.emotion2vec.fairseq_modules import (
+    LayerNorm,
+    SamePad,
+    TransposeLast,
+)
+
+from enum import Enum, auto
class Modality(Enum):
    """Input-modality tags; this emotion2vec port only supports audio."""

    AUDIO = auto()
+
+
+   
@dataclass
class D2vDecoderConfig:
    """Configuration for the convolutional data2vec decoder (``Decoder1d``)."""

    # width (channels) of the decoder convolutions
    decoder_dim: int = 384
    # number of groups in each grouped Conv1d
    decoder_groups: int = 16
    # kernel size of each decoder convolution
    decoder_kernel: int = 5
    # number of conv blocks in the decoder stack
    decoder_layers: int = 5
    # dropout applied to the decoder input
    input_dropout: float = 0.1

    # NOTE(review): the two flags below are not referenced in this file —
    # presumably consumed by the encoder-side decoder wiring; verify there.
    add_positions_masked: bool = False
    add_positions_all: bool = False

    # residual connections between decoder conv blocks (when shapes match)
    decoder_residual: bool = True
    # depth and hidden-width expansion of the final projection MLP
    projection_layers: int = 1
    projection_ratio: float = 2.0
+
+
class FixedPositionalEncoder(nn.Module):
    """Expose a precomputed positional-embedding tensor behind the
    positional-encoder interface; the actual input is ignored."""

    def __init__(self, pos_embed):
        super().__init__()
        self.positions = pos_embed

    def forward(self, x, padding_mask):
        # Same embeddings regardless of input content or padding.
        return self.positions
+
+
class TextFeatPositionalEncoder(nn.Module):
    """Adapter letting a positional encoder that expects (B, T) long input
    consume (B, T, D) float features from a local encoder.

    Only the first feature channel is forwarded; the wrapped encoder derives
    positions from sequence layout, not content.
    """

    def __init__(self, pos_encoder):
        super().__init__()
        self.pos_encoder = pos_encoder

    def forward(self, x, padding_mask):
        # Assumes padded token embeddings are zeros.
        # TODO: consider using padding_mask as input
        first_channel = x[..., 0]
        return self.pos_encoder(first_channel)
+
+
class BlockEncoder(nn.Module):
    """Runs a sequence of transformer blocks with optional LayerDrop and an
    optional shared LayerNorm, applied before the stack (post-norm models)
    or after it (pre-norm models) depending on ``layer_norm_first``."""

    def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout):
        super().__init__()
        self.blocks = blocks
        self.norm = norm_layer
        self.layer_norm_first = layer_norm_first
        self.layerdrop = layerdrop
        self.dropout = nn.Dropout(dropout, inplace=True)

    def forward(self, x, padding_mask, alibi_bias, alibi_scale):
        if self.norm is not None and not self.layer_norm_first:
            x = self.norm(x)

        x = self.dropout(x)

        for idx, layer in enumerate(self.blocks):
            # LayerDrop: randomly skip whole blocks during training only.
            if (
                self.training
                and self.layerdrop != 0
                and np.random.random() <= self.layerdrop
            ):
                continue
            bias = alibi_bias
            if bias is not None and alibi_scale is not None:
                # Per-layer scale when one scale per block was learned,
                # otherwise a single shared scalar.
                factor = (
                    alibi_scale[idx]
                    if alibi_scale.size(0) > 1
                    else alibi_scale.squeeze(0)
                )
                bias = bias * factor.type_as(bias)
            x, _ = layer(x, padding_mask, bias)

        if self.norm is not None and self.layer_norm_first:
            x = self.norm(x)

        return x
+
+
class DecoderBase(nn.Module):
    """Shared functionality for data2vec decoders: projection re-init and
    optional residual connections between decoder blocks."""

    # Forward reference kept as a string so the class is importable on its own.
    decoder_cfg: "D2vDecoderConfig"

    def __init__(self, cfg: "D2vDecoderConfig"):
        super().__init__()

        self.decoder_cfg = cfg

    def reset_parameters(self):
        # Re-initialize only the Linear layers of the output projection.
        for mod in self.proj.modules():
            if isinstance(mod, nn.Linear):
                mod.reset_parameters()

    def add_residual(self, x, residual, i, mask_info):
        """Return ``x + residual`` when residuals are enabled and the time
        dimension matches; otherwise return ``x`` unchanged."""
        skip = (
            residual is None
            or not self.decoder_cfg.decoder_residual
            or residual.size(1) != x.size(1)
        )
        return x if skip else x + residual
+
+
class Decoder1d(DecoderBase):
    """Convolutional data2vec decoder: a stack of grouped Conv1d blocks
    (with channel LayerNorm and GELU) followed by a linear projection back
    to the encoder dimension ``input_dim``."""

    def __init__(self, cfg: "D2vDecoderConfig", input_dim):
        super().__init__(cfg)

        def conv_block(in_channels):
            # conv -> trim (SamePad) -> LayerNorm over channels -> GELU
            return nn.Sequential(
                nn.Conv1d(
                    in_channels,
                    cfg.decoder_dim,
                    kernel_size=cfg.decoder_kernel,
                    padding=cfg.decoder_kernel // 2,
                    groups=cfg.decoder_groups,
                ),
                SamePad(cfg.decoder_kernel),
                TransposeLast(),
                LayerNorm(cfg.decoder_dim, elementwise_affine=False),
                TransposeLast(),
                nn.GELU(),
            )

        in_dims = [
            input_dim if i == 0 else cfg.decoder_dim
            for i in range(cfg.decoder_layers)
        ]
        self.blocks = nn.Sequential(*[conv_block(d) for d in in_dims])

        # Final projection MLP: optional expansion layers, then back to input_dim.
        layers = []
        width = cfg.decoder_dim
        for i in range(cfg.projection_layers - 1):
            out_width = int(width * cfg.projection_ratio) if i == 0 else width
            layers.append(nn.Linear(width, out_width))
            layers.append(nn.GELU())
            width = out_width
        layers.append(nn.Linear(width, input_dim))
        self.proj = layers[0] if len(layers) == 1 else nn.Sequential(*layers)

    def forward(self, x, mask_info):
        # (B, T, C) -> (B, C, T) for the convolutions.
        x = x.transpose(1, 2)

        residual = x

        for i, layer in enumerate(self.blocks):
            x = layer(x)
            x = self.add_residual(x, residual, i, mask_info)
            residual = x

        # Back to (B, T, C), then project to the encoder dimension.
        x = x.transpose(1, 2)
        return self.proj(x)
+
+
class AltBlock(nn.Module):
    """Transformer block used by data2vec 2.0.

    ``forward`` returns both the block output and the "target" features
    ``t`` (the FFN output when ``ffn_targets`` is True, otherwise the block
    output), which the teacher/student training scheme consumes.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        mlp_drop=0.0,
        post_mlp_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_norm_first=True,
        ffn_targets=False,
        cosine_attention=False,
    ):
        super().__init__()

        self.layer_norm_first = layer_norm_first
        self.ffn_targets = ffn_targets

        # Local import keeps the timm-derived helpers off the module import path.
        from funasr.models.emotion2vec.timm_modules import DropPath, Mlp

        self.norm1 = norm_layer(dim)
        self.attn = AltAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            cosine_attention=cosine_attention,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop,
        )
        self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)

    def forward(self, x, padding_mask=None, alibi_bias=None):
        """Return ``(x, t)``: the block output and the target-feature tensor."""
        if self.layer_norm_first:
            x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
            # NOTE(review): ``r`` is bound to the MLP output rather than to the
            # attention output, so the MLP has no conventional residual in this
            # branch. This mirrors the upstream fairseq data2vec 2.0 AltBlock
            # (and its pretrained checkpoints) — do not "fix" without retraining.
            r = x = self.mlp(self.norm2(x))
            t = x
            x = r + self.drop_path(self.post_mlp_dropout(x))
            if not self.ffn_targets:
                t = x
        else:
            # Post-norm variant: residual first, then norm.
            x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
            r = x = self.norm1(x)
            x = self.mlp(x)
            t = x
            x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
            if not self.ffn_targets:
                t = x

        return x, t
+
+
class AltAttention(nn.Module):
    """Multi-head self-attention with an optional additive alibi bias and an
    optional cosine-attention variant, as used by data2vec 2.0."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        cosine_attention=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.cosine_attention = cosine_attention

        if cosine_attention:
            # Learnable per-head temperature, clamped in forward().
            self.logit_scale = nn.Parameter(
                torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
            )

    def forward(self, x, padding_mask=None, alibi_bias=None):
        batch, seq_len, channels = x.shape

        # Project to q/k/v and split heads: each is (B, H, T, head_dim).
        projected = self.qkv(x).reshape(
            batch, seq_len, 3, self.num_heads, channels // self.num_heads
        )
        q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)

        dtype = q.dtype

        if self.cosine_attention:
            # Cosine similarity scaled by a clamped learnable temperature.
            scores = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
            temperature = torch.clamp(
                self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
            ).exp()
            scores = scores * temperature
        else:
            scores = (q * self.scale) @ k.transpose(-2, -1)

        if alibi_bias is not None:
            # The bias may cover fewer heads than the layer has.
            scores = scores.type_as(alibi_bias)
            scores[:, : alibi_bias.size(1)] += alibi_bias

        if padding_mask is not None and padding_mask.any():
            scores = scores.masked_fill(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )

        # Softmax in float32 for stability, then cast back.
        weights = scores.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
        weights = self.attn_drop(weights)

        out = (weights @ v).transpose(1, 2).reshape(batch, seq_len, channels)
        out = self.proj(out)
        return self.proj_drop(out)
diff --git a/funasr/models/emotion2vec/template.yaml b/funasr/models/emotion2vec/template.yaml
new file mode 100644
index 0000000..53bca63
--- /dev/null
+++ b/funasr/models/emotion2vec/template.yaml
@@ -0,0 +1,113 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: Emotion2vec
+model_conf:
+    loss_beta: 0.0 
+    loss_scale: null
+    depth: 8
+    start_drop_path_rate: 0.0
+    end_drop_path_rate: 0.0
+    num_heads: 12
+    norm_eps: 1e-05
+    norm_affine: true
+    encoder_dropout: 0.1
+    post_mlp_drop: 0.1
+    attention_dropout: 0.1
+    activation_dropout: 0.0
+    dropout_input: 0.0
+    layerdrop: 0.05
+    embed_dim: 768
+    mlp_ratio: 4.0
+    layer_norm_first: false
+    average_top_k_layers: 8
+    end_of_block_targets: false
+    clone_batch: 8
+    layer_norm_target_layer: false
+    batch_norm_target_layer: false
+    instance_norm_target_layer: true
+    instance_norm_targets: false
+    layer_norm_targets: false
+    ema_decay: 0.999
+    ema_same_dtype: true
+    log_norms: true
+    ema_end_decay: 0.99999
+    ema_anneal_end_step: 20000
+    ema_encoder_only: false
+    max_update: 100000
+    extractor_mode: layer_norm
+    shared_decoder: null
+    min_target_var: 0.1
+    min_pred_var: 0.01
+    supported_modality: AUDIO
+    mae_init: false
+    seed: 1
+    skip_ema: false
+    cls_loss: 1.0
+    recon_loss: 0.0
+    d2v_loss: 1.0
+    decoder_group: false
+    adversarial_training: false
+    adversarial_hidden_dim: 128
+    adversarial_weight: 0.1
+    cls_type: chunk
+    normalize: true
+
+    modalities:
+        audio:
+            type: AUDIO
+            prenet_depth: 4
+            prenet_layerdrop: 0.05
+            prenet_dropout: 0.1
+            start_drop_path_rate: 0.0
+            end_drop_path_rate: 0.0
+            num_extra_tokens: 10
+            init_extra_token_zero: true
+            mask_noise_std: 0.01
+            mask_prob_min: null
+            mask_prob: 0.5
+            inverse_mask: false
+            mask_prob_adjust: 0.05
+            keep_masked_pct: 0.0
+            mask_length: 5
+            add_masks: false
+            remove_masks: false
+            mask_dropout: 0.0
+            encoder_zero_mask: true
+            mask_channel_prob: 0.0
+            mask_channel_length: 64
+            ema_local_encoder: false
+            local_grad_mult: 1.0
+            use_alibi_encoder: true
+            alibi_scale: 1.0
+            learned_alibi: false
+            alibi_max_pos: null
+            learned_alibi_scale: true
+            learned_alibi_scale_per_head: true
+            learned_alibi_scale_per_layer: false
+            num_alibi_heads: 12
+            model_depth: 8
+            decoder:
+                decoder_dim: 384
+                decoder_groups: 16
+                decoder_kernel: 7
+                decoder_layers: 4
+                input_dropout: 0.1
+                add_positions_masked: false
+                add_positions_all: false
+                decoder_residual: true
+                projection_layers: 1
+                projection_ratio: 2.0
+            extractor_mode: layer_norm
+            feature_encoder_spec: '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]'
+            conv_pos_width: 95
+            conv_pos_groups: 16
+            conv_pos_depth: 5
+            conv_pos_pre_ln: false
+
+
diff --git a/funasr/models/emotion2vec/timm_modules.py b/funasr/models/emotion2vec/timm_modules.py
new file mode 100644
index 0000000..1f6285a
--- /dev/null
+++ b/funasr/models/emotion2vec/timm_modules.py
@@ -0,0 +1,100 @@
+from itertools import repeat
+import collections.abc
+from functools import partial
+from typing import Optional, Tuple
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+
+    """
+    # Identity when disabled or when running in eval mode.
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    # Per-sample mask of shape (batch, 1, 1, ...): broadcasts over all non-batch dims.
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    # Rescale surviving samples by 1/keep_prob so the expected magnitude is unchanged.
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        # Probability of dropping the residual branch for each sample in the batch.
+        self.drop_prob = drop_prob
+        # Whether surviving samples are rescaled by 1/keep_prob (see drop_path).
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        # Delegates to the functional drop_path; self.training gates the stochastic behavior.
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        # Shown in the module's repr, e.g. "DropPath(drop_prob=0.100)".
+        return f'drop_prob={round(self.drop_prob,3):0.3f}'
+    
+
+
+
+
+# From PyTorch internals
+def _ntuple(n):
+    """Return a converter that normalizes its argument to an n-tuple."""
+    def parse(x):
+        # Non-string iterables pass through as a tuple; scalars (and strings)
+        # are repeated n times.
+        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+            return tuple(x)
+        return tuple(repeat(x, n))
+    return parse
+
+
+# Convenience converters for the common fixed arities (timm convention).
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
+
+class Mlp(nn.Module):
+    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+
+    Structure: fc1 -> activation -> dropout -> (optional) norm -> fc2 -> dropout.
+    With use_conv=True the linear layers become 1x1 Conv2d, so the block can be
+    applied to NCHW feature maps instead of (…, features) tensors.
+    """
+    def __init__(
+            self,
+            in_features,
+            hidden_features=None,
+            out_features=None,
+            act_layer=nn.GELU,
+            norm_layer=None,
+            bias=True,
+            drop=0.,
+            use_conv=False,
+    ):
+        super().__init__()
+        # Defaults: output width = input width; hidden width = input width.
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        # bias/drop may be a scalar (applied to both layers) or a 2-tuple (per layer).
+        bias = to_2tuple(bias)
+        drop_probs = to_2tuple(drop)
+        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+        self.act = act_layer()
+        self.drop1 = nn.Dropout(drop_probs[0])
+        # Optional normalization between the two projections (Identity when unused).
+        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
+        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+        self.drop2 = nn.Dropout(drop_probs[1])
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop1(x)
+        x = self.norm(x)
+        x = self.fc2(x)
+        x = self.drop2(x)
+        return x
+
+
diff --git a/funasr/register.py b/funasr/register.py
index 145a698..15363c0 100644
--- a/funasr/register.py
+++ b/funasr/register.py
@@ -19,13 +19,16 @@
     dataset_classes = {}
     index_ds_classes = {}
 
-    def print(self,):
+    def print(self, key=None):
         print("\ntables: \n")
         fields = vars(self)
         for classes_key, classes_dict in fields.items():
-            print(f"-----------    ** {classes_key.replace('_meta', '')} **    --------------")
-        
-            if classes_key.endswith("_meta"):
+            
+            flag = True
+            if key is not None:
+                flag = key in classes_key
+            if classes_key.endswith("_meta") and flag:
+                print(f"-----------    ** {classes_key.replace('_meta', '')} **    --------------")
                 headers = ["class name", "register name", "class location"]
                 metas = []
                 for register_key, meta in classes_dict.items():
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index 963d734..a6596a0 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -105,6 +105,7 @@
     else:
         buffer = BytesIO(oss_bucket.get_object(path).read())
         src_state = torch.load(buffer, map_location=map_location)
+    src_state = src_state["model"] if "model" in src_state else src_state
     if excludes is not None:
         for e in excludes.split(","):
             src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}

--
Gitblit v1.9.1