From 2acd24f0158b2c86d2fb4e6f1134b67a1150500e Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 29 二月 2024 17:14:59 +0800
Subject: [PATCH] update whisper lid (#1407)

---
 funasr/models/whisper_lid/lid_predictor.py           |   25 
 funasr/models/whisper_lid/decoder.py                 |  167 +++++
 funasr/models/whisper_lid/eres2net/simple_avg.py     |   17 
 funasr/models/whisper_lid/encoder.py                 |  119 ++++
 funasr/models/whisper_lid/__init__.py                |    0 
 funasr/models/whisper_lid/eres2net/fusion.py         |   29 +
 funasr/models/whisper_lid/eres2net/ResNet.py         |  428 ++++++++++++++
 examples/common_voice/whisper_lid/demo_funasr.py     |   19 
 examples/common_voice/whisper_lid/demo_modelscope.py |   22 
 funasr/models/whisper_lid/eres2net/pooling_layers.py |  118 ++++
 funasr/models/whisper_lid/model.py                   |  665 ++++++++++++++++++++++
 funasr/models/whisper_lid/eres2net/__init__.py       |    0 
 funasr/frontends/whisper_frontend.py                 |  102 +++
 13 files changed, 1,711 insertions(+), 0 deletions(-)

diff --git a/examples/common_voice/whisper_lid/demo_funasr.py b/examples/common_voice/whisper_lid/demo_funasr.py
new file mode 100644
index 0000000..9af790e
--- /dev/null
+++ b/examples/common_voice/whisper_lid/demo_funasr.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+multilingual_wavs = [
+    "example_zh-CN.mp3",
+    "example_en.mp3",
+    "example_ja.mp3",
+    "example_ko.mp3",
+]
+
+model = AutoModel(model="iic/speech_whisper-large_lid_multilingual_pytorch", model_revision="v2.0.4")
+for wav_id in multilingual_wavs:
+    wav_file = f"{model.model_path}/examples/{wav_id}"
+    res = model.generate(input=wav_file, data_type="sound", inference_clip_length=250)
+    print("detect sample {}: {}".format(wav_id, res))
\ No newline at end of file
diff --git a/examples/common_voice/whisper_lid/demo_modelscope.py b/examples/common_voice/whisper_lid/demo_modelscope.py
new file mode 100644
index 0000000..cce389a
--- /dev/null
+++ b/examples/common_voice/whisper_lid/demo_modelscope.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+multilingual_wavs=[
+    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_zh-CN.mp3",
+    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_en.mp3",
+    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ja.mp3",
+    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ko.mp3",
+]
+
+inference_pipeline = pipeline(
+    task=Tasks.auto_speech_recognition,
+    model='iic/speech_whisper-large_lid_multilingual_pytorch', model_revision="v2.0.4")
+
+for wav in multilingual_wavs:
+    rec_result = inference_pipeline(input=wav, inference_clip_length=250)
+    print(rec_result)
\ No newline at end of file
diff --git a/funasr/frontends/whisper_frontend.py b/funasr/frontends/whisper_frontend.py
new file mode 100644
index 0000000..752fd20
--- /dev/null
+++ b/funasr/frontends/whisper_frontend.py
@@ -0,0 +1,102 @@
+from typing import Tuple
+import torch
+import torch.nn as nn
+import whisper
+from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
+from funasr.register import tables
+from torch.nn.utils.rnn import pad_sequence
+
+
+@tables.register("frontend_classes", "WhisperFrontend")
+class WhisperFrontend(nn.Module):
+    """Speech Representation Using Encoder Outputs from OpenAI's Whisper Model:
+
+    URL: https://github.com/openai/whisper
+    """
+
+    def __init__(
+            self,
+            fs: int = 16000,
+            whisper_model: str = "large-v3",
+            do_pad_trim: bool = True,
+    ):
+        super().__init__()
+        assert fs == 16000
+        self.fs = fs
+
+        self.n_fft = N_FFT
+        self.win_length = N_FFT
+        self.hop_length = HOP_LENGTH
+        self.pad_samples = N_SAMPLES
+        self.frame_shift = self.hop_length
+        self.lfr_n = 1
+        if whisper_model == "large-v3" or whisper_model == "large":
+            self.n_mels = 128
+        else:
+            self.n_mels = 80
+
+        self.mel_filters = whisper.audio.mel_filters
+        self.do_pad_trim = do_pad_trim
+        if do_pad_trim:
+            self.pad_or_trim = whisper.pad_or_trim
+
+        assert whisper_model in whisper.available_models()
+
+    def output_size(self) -> int:
+        return self.n_mels
+
+    def log_mel_spectrogram(
+            self,
+            audio: torch.Tensor,
+            ilens: torch.Tensor = None,
+    ) -> torch.Tensor:
+        window = torch.hann_window(self.win_length).to(audio.device)
+        stft = torch.stft(
+            audio, self.n_fft, self.hop_length, window=window, return_complex=True
+        )
+
+        # whisper deletes the last frame by default (Shih-Lun)
+        magnitudes = stft[..., :-1].abs() ** 2
+
+        filters = self.mel_filters(audio.device, self.n_mels)
+        mel_spec = filters @ magnitudes
+
+        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+
+        if ilens is not None:
+            olens = ilens // self.hop_length
+        else:
+            olens = None
+
+        log_spec = torch.maximum(
+            log_spec,
+            log_spec.view(audio.size(0), -1).max(dim=-1)[0][:, None, None] - 8.0,
+        )
+        log_spec = (log_spec + 4.0) / 4.0
+
+        return log_spec, olens
+
+    def forward(
+            self, input: torch.Tensor, input_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = input.size(0)
+        feats = []
+        feats_lens = []
+        for i in range(batch_size):
+            if self.do_pad_trim:
+                feat = self.pad_or_trim(input[i], self.pad_samples)
+            else:
+                feat = input[i]
+            feat, feat_len = self.log_mel_spectrogram(feat[None, :], input_lengths[0])
+            feats.append(feat[0])
+            feats_lens.append(feat_len)
+        feats_lens = torch.as_tensor(feats_lens)
+
+        if batch_size == 1:
+            feats_pad = feats[0][None, :, :]
+        else:
+            feats_pad = pad_sequence(feats,
+                                     batch_first=True,
+                                     padding_value=0.0)
+
+        return feats_pad, feats_lens
\ No newline at end of file
diff --git a/funasr/models/whisper_lid/__init__.py b/funasr/models/whisper_lid/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/whisper_lid/__init__.py
diff --git a/funasr/models/whisper_lid/decoder.py b/funasr/models/whisper_lid/decoder.py
new file mode 100644
index 0000000..4db9205
--- /dev/null
+++ b/funasr/models/whisper_lid/decoder.py
@@ -0,0 +1,167 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+import copy
+from typing import Any, List, Tuple
+
+import torch
+from torch import nn
+import whisper
+
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.register import tables
+
+
+@tables.register("decoder_classes", "OpenAIWhisperDecoderWarp")
+class OpenAIWhisperDecoderWarp(nn.Module):
+    """Transformer-based Speech-to-Text Decoder from OpenAI's Whisper Model:
+
+    URL: https://github.com/openai/whisper
+    """
+
+    def __init__(
+        self,
+        dropout_rate: float = 0.0,
+        whisper_model: str = "small",
+        download_dir: str = None,
+        use_padmask: bool = False,
+    ):
+        super().__init__()
+
+        assert whisper_model in whisper.available_models()
+        _model = whisper.load_model(
+            whisper_model, download_root=download_dir, device="cpu"
+        )
+        self.decoders = copy.deepcopy(_model.decoder)
+        attention_dim = self.decoders.token_embedding.embedding_dim
+
+        # note that originally Whisper doesn't use dropouts
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+        self.decoders.train()
+        del _model
+        self.use_padmask = use_padmask
+
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out)
+                if input_layer == "embed"
+                input tensor (batch, maxlen_out, #mels) in the other cases
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token score before softmax (batch, maxlen_out, token)
+                if use_output_layer is True,
+            olens: (batch, )
+        """
+        tgt, memory = ys_in_pad, hs_pad
+        tgt = (
+            self.decoders.token_embedding(tgt)
+            + self.decoders.positional_embedding[: tgt.size(1)]
+        )
+        tgt = self.dropout(tgt)
+
+        x = tgt.to(memory.dtype)
+
+        if self.use_padmask:
+            memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
+        else:
+            memory_mask = None
+
+        for layer, block in enumerate(self.decoders.blocks):
+            x = block(x, memory, mask=self.decoders.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
+
+            if layer < len(self.decoders.blocks) - 1:
+                x = self.dropout(x)
+
+        x = self.decoders.ln(x)
+        x = (
+            x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
+        ).float()
+
+        return x, ys_in_lens
+
+    def forward_one_step(
+        self,
+        tgt: torch.Tensor,
+        tgt_mask: torch.Tensor,
+        memory: torch.Tensor,
+        cache: List[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Forward one step.
+
+        Args:
+            tgt: input token ids, int64 (batch, maxlen_out)
+            tgt_mask: input token mask,  (batch, maxlen_out)
+                      dtype=torch.uint8 in PyTorch 1.2-
+                      dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+            memory: encoded memory, float32  (batch, maxlen_in, feat)
+            cache: cached output list of (batch, max_time_out-1, size)
+        Returns:
+            y, cache: NN output value and cache per `self.decoders`.
+            y.shape` is (batch, maxlen_out, token)
+        NOTE (Shih-Lun):
+            cache implementation is ignored for now
+            for simplicity & correctness
+        """
+        x = (
+            self.decoders.token_embedding(tgt)
+            + self.decoders.positional_embedding[: tgt.size(1)]
+        )
+        x = self.dropout(x)
+        x = x.to(memory.dtype)
+
+        for layer, block in enumerate(self.decoders.blocks):
+            x = block(x, memory, mask=self.decoders.mask)
+            if layer < len(self.decoders.blocks) - 1:
+                x = self.dropout(x)
+
+        x = self.decoders.ln(x)
+        y = x[:, -1]
+        y = (
+            y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
+        ).float()
+        y = torch.log_softmax(y, dim=-1)
+
+        return y, None
+
+    def score(self, ys, state, x):
+        """Score."""
+        logp, state = self.forward_one_step(
+            ys.unsqueeze(0), torch.empty(0), x.unsqueeze(0), cache=state  # dummy mask
+        )
+        return logp.squeeze(0), state
+
+    def batch_score(
+        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
+    ) -> Tuple[torch.Tensor, List[Any]]:
+        """Score new token batch.
+
+        Args:
+            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
+            states (List[Any]): Scorer states for prefix tokens.
+            xs (torch.Tensor):
+                The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+        Returns:
+            tuple[torch.Tensor, List[Any]]: Tuple of
+                batchfied scores for next token with shape of `(n_batch, n_vocab)`
+                and next state list for ys.
+
+        """
+        # batch decoding, dummy mask is passed
+        logp, states = self.forward_one_step(ys, torch.empty(0), xs, cache=None)
+
+        return logp, None
diff --git a/funasr/models/whisper_lid/encoder.py b/funasr/models/whisper_lid/encoder.py
new file mode 100644
index 0000000..7eeb643
--- /dev/null
+++ b/funasr/models/whisper_lid/encoder.py
@@ -0,0 +1,119 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+import copy
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import whisper
+
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.specaug.specaug import SpecAug
+from funasr.register import tables
+
+
+@tables.register("encoder_classes", "OpenAIWhisperEncoderWarp")
+class OpenAIWhisperEncoderWarp(nn.Module):
+    """Transformer-based Speech Encoder from OpenAI's Whisper Model:
+
+    URL: https://github.com/openai/whisper
+    """
+
+    def __init__(
+            self,
+            dropout_rate: float = 0.0,
+            whisper_model: str = "small",
+            download_dir: str = None,
+            use_specaug: bool = False,
+            use_padmask: bool = False,
+            specaug_conf: Union[dict, None] = None,
+    ):
+        super().__init__()
+
+        # note that originally Whisper doesn't use dropouts
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+        assert whisper_model in whisper.available_models()
+        _model = whisper.load_model(
+            whisper_model, download_root=download_dir, device="cpu"
+        )
+        self.encoders = copy.deepcopy(_model.encoder)
+        self.encoders.train()
+
+        del _model
+
+        if use_specaug:
+            self.specaug = SpecAug(**specaug_conf)
+        else:
+            self.specaug = None
+        self.use_padmask = use_padmask
+
+    def whisper_encode(
+            self,
+            input: torch.Tensor,
+            ilens: torch.Tensor = None,
+    ) -> torch.Tensor:
+        x = F.gelu(self.encoders.conv1(input))
+        x = F.gelu(self.encoders.conv2(x))
+        x = x.permute(0, 2, 1)
+
+        n_frames = x.size(1)
+        max_pos = self.encoders.positional_embedding.size(0)
+        if n_frames <= max_pos:
+            x = (x + self.encoders.positional_embedding[: x.size(1), :]).to(x.dtype)
+        else:
+            # due to positional encoding, audios >30 sec won't be accepted
+            x = x[:, :max_pos, :] + self.encoders.positional_embedding
+
+        if ilens is not None:
+            olens = (
+                    1
+                    + (
+                            ilens
+                            - self.encoders.conv2.kernel_size[0]
+                            + 2 * self.encoders.conv2.padding[0]
+                    )
+                    // self.encoders.conv2.stride[0]
+            )
+            olens = torch.clamp(olens, max=max_pos)
+        else:
+            olens = None
+
+        if self.use_padmask:
+            padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
+        else:
+            padding_mask = None
+
+        x = self.dropout(x)
+
+        for layer, block in enumerate(self.encoders.blocks):
+            x = block(x)
+            if layer < len(self.encoders.blocks) - 1:
+                x = self.dropout(x)
+
+        x = self.encoders.ln_post(x)
+
+        return x, olens
+
+    def output_size(self) -> int:
+        # dummy output size
+        return self.encoders.conv2.weight.shape[0]
+
+    def forward(
+            self,
+            xs_pad: torch.Tensor,
+            ilens: torch.Tensor,
+            prev_states: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        feats, feats_lens = xs_pad, ilens
+
+        if self.specaug is not None and self.encoders.training:
+            feats = torch.transpose(feats, 1, 2)
+            feats, feats_lens = self.specaug(feats, feats_lens)
+            feats = torch.transpose(feats, 1, 2)
+
+        xs_pad, olens = self.whisper_encode(feats, feats_lens)
+
+        return xs_pad, olens, None
diff --git a/funasr/models/whisper_lid/eres2net/ResNet.py b/funasr/models/whisper_lid/eres2net/ResNet.py
new file mode 100644
index 0000000..25c79f5
--- /dev/null
+++ b/funasr/models/whisper_lid/eres2net/ResNet.py
@@ -0,0 +1,428 @@
+# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
+    ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
+    The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
+    The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
+    ERes2Net-Large is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
+    recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
+"""
+
+import torch
+import math
+import torch.nn as nn
+import torch.nn.functional as F
+import funasr.models.whisper_lid.eres2net.pooling_layers as pooling_layers
+from funasr.models.whisper_lid.eres2net.fusion import AFF
+
+
+class ReLU(nn.Hardtanh):
+
+    def __init__(self, inplace=False):
+        super(ReLU, self).__init__(0, 20, inplace)
+
+    def __repr__(self):
+        inplace_str = 'inplace' if self.inplace else ''
+        return self.__class__.__name__ + ' (' \
+               + inplace_str + ')'
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    "1x1 convolution without padding"
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
+                     padding=0, bias=False)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    "3x3 convolution with padding"
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class BasicBlockERes2Net(nn.Module):
+    expansion = 2
+
+    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
+        super(BasicBlockERes2Net, self).__init__()
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = conv1x1(in_planes, width * scale, stride)
+        self.bn1 = nn.BatchNorm2d(width * scale)
+        self.nums = scale
+
+        convs = []
+        bns = []
+        for i in range(self.nums):
+            convs.append(conv3x3(width, width))
+            bns.append(nn.BatchNorm2d(width))
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList(bns)
+        self.relu = ReLU(inplace=True)
+
+        self.conv3 = conv1x1(width * scale, planes * self.expansion)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes,
+                          self.expansion * planes,
+                          kernel_size=1,
+                          stride=stride,
+                          bias=False),
+                nn.BatchNorm2d(self.expansion * planes))
+        self.stride = stride
+        self.width = width
+        self.scale = scale
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        spx = torch.split(out, self.width, 1)
+        for i in range(self.nums):
+            if i == 0:
+                sp = spx[i]
+            else:
+                sp = sp + spx[i]
+            sp = self.convs[i](sp)
+            sp = self.relu(self.bns[i](sp))
+            if i == 0:
+                out = sp
+            else:
+                out = torch.cat((out, sp), 1)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        residual = self.shortcut(x)
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class BasicBlockERes2Net_diff_AFF(nn.Module):
+    expansion = 2
+
+    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
+        super(BasicBlockERes2Net_diff_AFF, self).__init__()
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = conv1x1(in_planes, width * scale, stride)
+        self.bn1 = nn.BatchNorm2d(width * scale)
+        self.nums = scale
+
+        convs = []
+        fuse_models = []
+        bns = []
+        for i in range(self.nums):
+            convs.append(conv3x3(width, width))
+            bns.append(nn.BatchNorm2d(width))
+        for j in range(self.nums - 1):
+            fuse_models.append(AFF(channels=width))
+
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList(bns)
+        self.fuse_models = nn.ModuleList(fuse_models)
+        self.relu = ReLU(inplace=True)
+
+        self.conv3 = conv1x1(width * scale, planes * self.expansion)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes,
+                          self.expansion * planes,
+                          kernel_size=1,
+                          stride=stride,
+                          bias=False),
+                nn.BatchNorm2d(self.expansion * planes))
+        self.stride = stride
+        self.width = width
+        self.scale = scale
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        spx = torch.split(out, self.width, 1)
+        for i in range(self.nums):
+            if i == 0:
+                sp = spx[i]
+            else:
+                sp = self.fuse_models[i - 1](sp, spx[i])
+
+            sp = self.convs[i](sp)
+            sp = self.relu(self.bns[i](sp))
+            if i == 0:
+                out = sp
+            else:
+                out = torch.cat((out, sp), 1)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        residual = self.shortcut(x)
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ERes2Net(nn.Module):
+    def __init__(self,
+                 block=BasicBlockERes2Net,
+                 block_fuse=BasicBlockERes2Net_diff_AFF,
+                 num_blocks=[3, 4, 6, 3],
+                 m_channels=32,
+                 feat_dim=80,
+                 embedding_size=192,
+                 pooling_func='TSTP',
+                 two_emb_layer=False):
+        super(ERes2Net, self).__init__()
+        self.in_planes = m_channels
+        self.feat_dim = feat_dim
+        self.embedding_size = embedding_size
+        self.stats_dim = int(feat_dim / 8) * m_channels * 8
+        self.two_emb_layer = two_emb_layer
+        self._output_size = embedding_size
+
+        self.conv1 = nn.Conv2d(1,
+                               m_channels,
+                               kernel_size=3,
+                               stride=1,
+                               padding=1,
+                               bias=False)
+        self.bn1 = nn.BatchNorm2d(m_channels)
+        self.layer1 = self._make_layer(block,
+                                       m_channels,
+                                       num_blocks[0],
+                                       stride=1)
+        self.layer2 = self._make_layer(block,
+                                       m_channels * 2,
+                                       num_blocks[1],
+                                       stride=2)
+        self.layer3 = self._make_layer(block_fuse,
+                                       m_channels * 4,
+                                       num_blocks[2],
+                                       stride=2)
+        self.layer4 = self._make_layer(block_fuse,
+                                       m_channels * 8,
+                                       num_blocks[3],
+                                       stride=2)
+
+        # Downsampling module for each layer
+        self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1,
+                                           bias=False)
+        self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2,
+                                           bias=False)
+        self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2,
+                                           bias=False)
+
+        # Bottom-up fusion module
+        self.fuse_mode12 = AFF(channels=m_channels * 4)
+        self.fuse_mode123 = AFF(channels=m_channels * 8)
+        self.fuse_mode1234 = AFF(channels=m_channels * 16)
+
+        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
+        self.pool = getattr(pooling_layers, pooling_func)(
+            in_dim=self.stats_dim * block.expansion)
+        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
+                               embedding_size)
+        if self.two_emb_layer:
+            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
+            self.seg_2 = nn.Linear(embedding_size, embedding_size)
+        else:
+            self.seg_bn_1 = nn.Identity()
+            self.seg_2 = nn.Identity()
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(self, x, ilens):
+        # assert x.shape[1] == ilens.max()
+        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
+        x = x.unsqueeze_(1)
+        out = F.relu(self.bn1(self.conv1(x)))
+        out1 = self.layer1(out)
+        out2 = self.layer2(out1)
+        out1_downsample = self.layer1_downsample(out1)
+        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
+        out3 = self.layer3(out2)
+        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
+        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
+        out4 = self.layer4(out3)
+        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
+        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
+        olens = (((((ilens - 1) // 2 + 1) - 1) // 2 + 1) - 1) // 2 + 1
+        stats = self.pool(fuse_out1234, olens)
+
+        embed_a = self.seg_1(stats)
+        if self.two_emb_layer:
+            out = F.relu(embed_a)
+            out = self.seg_bn_1(out)
+            embed_b = self.seg_2(out)
+            return embed_b
+        else:
+            return embed_a
+
+
+class BasicBlockRes2Net(nn.Module):
+    expansion = 2
+
+    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
+        super(BasicBlockRes2Net, self).__init__()
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = conv1x1(in_planes, width * scale, stride)
+        self.bn1 = nn.BatchNorm2d(width * scale)
+        self.nums = scale - 1
+        convs = []
+        bns = []
+        for i in range(self.nums):
+            convs.append(conv3x3(width, width))
+            bns.append(nn.BatchNorm2d(width))
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList(bns)
+        self.relu = ReLU(inplace=True)
+
+        self.conv3 = conv1x1(width * scale, planes * self.expansion)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes,
+                          self.expansion * planes,
+                          kernel_size=1,
+                          stride=stride,
+                          bias=False),
+                nn.BatchNorm2d(self.expansion * planes))
+        self.stride = stride
+        self.width = width
+        self.scale = scale
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        spx = torch.split(out, self.width, 1)
+        for i in range(self.nums):
+            if i == 0:
+                sp = spx[i]
+            else:
+                sp = sp + spx[i]
+            sp = self.convs[i](sp)
+            sp = self.relu(self.bns[i](sp))
+            if i == 0:
+                out = sp
+            else:
+                out = torch.cat((out, sp), 1)
+
+        out = torch.cat((out, spx[self.nums]), 1)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        residual = self.shortcut(x)
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Res2Net(nn.Module):
+    def __init__(self,
+                 block=BasicBlockRes2Net,
+                 num_blocks=[3, 4, 6, 3],
+                 m_channels=32,
+                 feat_dim=80,
+                 embedding_size=192,
+                 pooling_func='TSTP',
+                 two_emb_layer=False):
+        super(Res2Net, self).__init__()
+        self.in_planes = m_channels
+        self.feat_dim = feat_dim
+        self.embedding_size = embedding_size
+        self.stats_dim = int(feat_dim / 8) * m_channels * 8
+        self.two_emb_layer = two_emb_layer
+
+        self.conv1 = nn.Conv2d(1,
+                               m_channels,
+                               kernel_size=3,
+                               stride=1,
+                               padding=1,
+                               bias=False)
+        self.bn1 = nn.BatchNorm2d(m_channels)
+        self.layer1 = self._make_layer(block,
+                                       m_channels,
+                                       num_blocks[0],
+                                       stride=1)
+        self.layer2 = self._make_layer(block,
+                                       m_channels * 2,
+                                       num_blocks[1],
+                                       stride=2)
+        self.layer3 = self._make_layer(block,
+                                       m_channels * 4,
+                                       num_blocks[2],
+                                       stride=2)
+        self.layer4 = self._make_layer(block,
+                                       m_channels * 8,
+                                       num_blocks[3],
+                                       stride=2)
+
+        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
+        self.pool = getattr(pooling_layers, pooling_func)(
+            in_dim=self.stats_dim * block.expansion)
+        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
+                               embedding_size)
+        if self.two_emb_layer:
+            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
+            self.seg_2 = nn.Linear(embedding_size, embedding_size)
+        else:
+            self.seg_bn_1 = nn.Identity()
+            self.seg_2 = nn.Identity()
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
+
+        x = x.unsqueeze_(1)
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
+
+        stats = self.pool(out)
+
+        embed_a = self.seg_1(stats)
+        if self.two_emb_layer:
+            out = F.relu(embed_a)
+            out = self.seg_bn_1(out)
+            embed_b = self.seg_2(out)
+            return embed_b
+        else:
+            return embed_a
+
+
+
+
diff --git a/funasr/models/whisper_lid/eres2net/__init__.py b/funasr/models/whisper_lid/eres2net/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/whisper_lid/eres2net/__init__.py
diff --git a/funasr/models/whisper_lid/eres2net/fusion.py b/funasr/models/whisper_lid/eres2net/fusion.py
new file mode 100644
index 0000000..2aff7a7
--- /dev/null
+++ b/funasr/models/whisper_lid/eres2net/fusion.py
@@ -0,0 +1,29 @@
+# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import torch
+import torch.nn as nn
+
+
+class AFF(nn.Module):
+
+    def __init__(self, channels=64, r=4):
+        super(AFF, self).__init__()
+        inter_channels = int(channels // r)
+
+        self.local_att = nn.Sequential(
+            nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
+            nn.BatchNorm2d(inter_channels),
+            nn.SiLU(inplace=True),
+            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+            nn.BatchNorm2d(channels),
+        )
+
+    def forward(self, x, ds_y):
+        xa = torch.cat((x, ds_y), dim=1)
+        x_att = self.local_att(xa)
+        x_att = 1.0 + torch.tanh(x_att)
+        xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0-x_att)
+
+        return xo
+
diff --git a/funasr/models/whisper_lid/eres2net/pooling_layers.py b/funasr/models/whisper_lid/eres2net/pooling_layers.py
new file mode 100644
index 0000000..f756ac8
--- /dev/null
+++ b/funasr/models/whisper_lid/eres2net/pooling_layers.py
@@ -0,0 +1,118 @@
+# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+""" This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
+
+import torch
+import torch.nn as nn
+
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+
+
+class TAP(nn.Module):
+    """
+    Temporal average pooling, only first-order mean is considered
+    """
+
+    def __init__(self, **kwargs):
+        super(TAP, self).__init__()
+
+    def forward(self, x):
+        pooling_mean = x.mean(dim=-1)
+        # To be compatable with 2D input
+        pooling_mean = pooling_mean.flatten(start_dim=1)
+        return pooling_mean
+
+
+class TSDP(nn.Module):
+    """
+    Temporal standard deviation pooling, only second-order std is considered
+    """
+
+    def __init__(self, **kwargs):
+        super(TSDP, self).__init__()
+
+    def forward(self, x):
+        # The last dimension is the temporal axis
+        pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
+        pooling_std = pooling_std.flatten(start_dim=1)
+        return pooling_std
+
+
+class TSTP(nn.Module):
+    """
+    Temporal statistics pooling, concatenate mean and std, which is used in
+    x-vector
+    Comment: simple concatenation can not make full use of both statistics
+    """
+
+    def __init__(self, **kwargs):
+        super(TSTP, self).__init__()
+
+    def forward(self, x, olens):
+        # The last dimension is the temporal axis
+        masks = (~make_pad_mask(olens, maxlen=x.shape[-1])[:, None, None, :]).to(x.device)
+        x_masked = x * masks
+        sum_without_padding = torch.sum(x_masked, axis=-1)
+        count_without_padding = torch.sum(masks, axis=-1)
+        mean_without_padding = sum_without_padding / count_without_padding
+
+        var_without_padding = ((x_masked - mean_without_padding.unsqueeze(-1)) ** 2 * masks).sum(-1) / count_without_padding
+
+        pooling_mean = mean_without_padding
+        pooling_std = torch.sqrt(var_without_padding + 1e-8)
+        pooling_mean = pooling_mean.flatten(start_dim=1)
+        pooling_std = pooling_std.flatten(start_dim=1)
+
+        stats = torch.cat((pooling_mean, pooling_std), 1)
+        return stats
+
+
+class ASTP(nn.Module):
+    """ Attentive statistics pooling: Channel- and context-dependent
+        statistics pooling, first used in ECAPA_TDNN.
+    """
+
+    def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
+        super(ASTP, self).__init__()
+        self.global_context_att = global_context_att
+
+        # Use Conv1d with stride == 1 rather than Linear, then we don't
+        # need to transpose inputs.
+        if global_context_att:
+            self.linear1 = nn.Conv1d(
+                in_dim * 3, bottleneck_dim,
+                kernel_size=1)  # equals W and b in the paper
+        else:
+            self.linear1 = nn.Conv1d(
+                in_dim, bottleneck_dim,
+                kernel_size=1)  # equals W and b in the paper
+        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
+                                 kernel_size=1)  # equals V and k in the paper
+
+    def forward(self, x):
+        """
+        x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
+            or a 4-dimensional tensor in resnet architecture (B,C,F,T)
+            0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
+        """
+        if len(x.shape) == 4:
+            x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
+        assert len(x.shape) == 3
+
+        if self.global_context_att:
+            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
+            context_std = torch.sqrt(
+                torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
+            x_in = torch.cat((x, context_mean, context_std), dim=1)
+        else:
+            x_in = x
+
+        # DON'T use ReLU here! ReLU may be hard to converge.
+        alpha = torch.tanh(
+            self.linear1(x_in))  # alpha = F.relu(self.linear1(x_in))
+        alpha = torch.softmax(self.linear2(alpha), dim=2)
+        mean = torch.sum(alpha * x, dim=2)
+        var = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
+        std = torch.sqrt(var.clamp(min=1e-10))
+        return torch.cat([mean, std], dim=1)
diff --git a/funasr/models/whisper_lid/eres2net/simple_avg.py b/funasr/models/whisper_lid/eres2net/simple_avg.py
new file mode 100644
index 0000000..4fb4c0a
--- /dev/null
+++ b/funasr/models/whisper_lid/eres2net/simple_avg.py
@@ -0,0 +1,17 @@
+import torch
+
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.modules.nets_utils import make_pad_mask
+
+class SimpleAvg(AbsEncoder):
+    def __init__(self, feat_dim):
+        super(SimpleAvg, self).__init__()
+        self.feat_dim = feat_dim
+
+    def forward(self, x, ilens):
+        mask = ~make_pad_mask(ilens, maxlen=x.shape[1]).to(x.device)
+        avg_x = (x * mask[:, :, None]).sum(1) / mask.sum(-1)[:, None]
+        return avg_x
+
+    def output_size(self) -> int:
+        return self.feat_dim
\ No newline at end of file
diff --git a/funasr/models/whisper_lid/lid_predictor.py b/funasr/models/whisper_lid/lid_predictor.py
new file mode 100644
index 0000000..5e042d2
--- /dev/null
+++ b/funasr/models/whisper_lid/lid_predictor.py
@@ -0,0 +1,25 @@
+from funasr.register import tables
+from funasr.models.whisper_lid.eres2net.ResNet import ERes2Net, BasicBlockERes2Net, BasicBlockERes2Net_diff_AFF
+
+
+@tables.register("lid_predictor_classes", "LidPredictor")
+class LidPredictor(ERes2Net):
+    def __init__(self,
+                 block=BasicBlockERes2Net,
+                 block_fuse=BasicBlockERes2Net_diff_AFF,
+                 num_blocks=[3, 4, 6, 3],
+                 m_channels=32,
+                 feat_dim=80,
+                 embedding_size=192,
+                 pooling_func='TSTP',
+                 two_emb_layer=False):
+        super(LidPredictor, self).__init__(
+                block=block,
+                block_fuse=block_fuse,
+                num_blocks=num_blocks,
+                m_channels=m_channels,
+                feat_dim=feat_dim,
+                embedding_size=embedding_size,
+                pooling_func=pooling_func,
+                two_emb_layer=two_emb_layer
+        )
\ No newline at end of file
diff --git a/funasr/models/whisper_lid/model.py b/funasr/models/whisper_lid/model.py
new file mode 100644
index 0000000..6ffb43a
--- /dev/null
+++ b/funasr/models/whisper_lid/model.py
@@ -0,0 +1,665 @@
+import logging
+from typing import Union, Dict, List, Tuple, Optional
+
+import time
+import torch
+import numpy as np
+import torch.nn as nn
+from torch.cuda.amp import autocast
+
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.ctc.ctc import CTC
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.register import tables
+
+
+@tables.register("model_classes", "OpenAIWhisperModel")
+class OpenAIWhisperModel(nn.Module):
+    """CTC-attention hybrid Encoder-Decoder model"""
+
+    
+    def __init__(
+        self,
+        specaug: str = None,
+        specaug_conf: dict = None,
+        normalize: str = None,
+        normalize_conf: dict = None,
+        encoder: str = None,
+        encoder_conf: dict = None,
+        decoder: str = None,
+        decoder_conf: dict = None,
+        ctc: str = None,
+        ctc_conf: dict = None,
+        ctc_weight: float = 0.5,
+        interctc_weight: float = 0.0,
+        input_size: int = 80,
+        vocab_size: int = -1,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        sos: int = 1,
+        eos: int = 2,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        # extract_feats_in_collect_stats: bool = True,
+        share_embedding: bool = False,
+        # preencoder: Optional[AbsPreEncoder] = None,
+        # postencoder: Optional[AbsPostEncoder] = None,
+        **kwargs,
+    ):
+
+        super().__init__()
+
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize)
+            normalize = normalize_class(**normalize_conf)
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(input_size=input_size, **encoder_conf)
+        encoder_output_size = encoder.output_size()
+        if decoder is not None:
+            decoder_class = tables.decoder_classes.get(decoder)
+            decoder = decoder_class(decoder_conf)
+        if ctc_weight > 0.0:
+            
+            if ctc_conf is None:
+                ctc_conf = {}
+            
+            ctc = CTC(
+                odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+            )
+    
+        self.blank_id = blank_id
+        self.sos = sos if sos is not None else vocab_size - 1
+        self.eos = eos if eos is not None else vocab_size - 1
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.ctc_weight = ctc_weight
+        self.specaug = specaug
+        self.normalize = normalize
+        self.encoder = encoder
+
+        if not hasattr(self.encoder, "interctc_use_conditioning"):
+            self.encoder.interctc_use_conditioning = False
+        if self.encoder.interctc_use_conditioning:
+            self.encoder.conditioning_layer = torch.nn.Linear(
+                vocab_size, self.encoder.output_size()
+            )
+        self.interctc_weight = interctc_weight
+
+        # self.error_calculator = None
+        if ctc_weight == 1.0:
+            self.decoder = None
+        else:
+            self.decoder = decoder
+        
+        self.criterion_att = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+        #
+        # if report_cer or report_wer:
+        #     self.error_calculator = ErrorCalculator(
+        #         token_list, sym_space, sym_blank, report_cer, report_wer
+        #     )
+        #
+        self.error_calculator = None
+        if ctc_weight == 0.0:
+            self.ctc = None
+        else:
+            self.ctc = ctc
+            
+        self.share_embedding = share_embedding
+        if self.share_embedding:
+            self.decoder.embed = None
+        
+        self.length_normalized_loss = length_normalized_loss
+        self.beam_search = None
+    
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+        # import pdb;
+        # pdb.set_trace()
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+        
+        batch_size = speech.shape[0]
+        
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+        
+        loss_att, acc_att, cer_att, wer_att = None, None, None, None
+        loss_ctc, cer_ctc = None, None
+        stats = dict()
+        
+        # decoder: CTC branch
+        if self.ctc_weight != 0.0:
+            loss_ctc, cer_ctc = self._calc_ctc_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+            
+            # Collect CTC branch stats
+            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+            stats["cer_ctc"] = cer_ctc
+        
+        # Intermediate CTC (optional)
+        loss_interctc = 0.0
+        if self.interctc_weight != 0.0 and intermediate_outs is not None:
+            for layer_idx, intermediate_out in intermediate_outs:
+                # we assume intermediate_out has the same length & padding
+                # as those of encoder_out
+                loss_ic, cer_ic = self._calc_ctc_loss(
+                    intermediate_out, encoder_out_lens, text, text_lengths
+                )
+                loss_interctc = loss_interctc + loss_ic
+                
+                # Collect Intermedaite CTC stats
+                stats["loss_interctc_layer{}".format(layer_idx)] = (
+                    loss_ic.detach() if loss_ic is not None else None
+                )
+                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+            
+            loss_interctc = loss_interctc / len(intermediate_outs)
+            
+            # calculate whole encoder loss
+            loss_ctc = (
+                           1 - self.interctc_weight
+                       ) * loss_ctc + self.interctc_weight * loss_interctc
+        
+        # decoder: Attention decoder branch
+        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+            encoder_out, encoder_out_lens, text, text_lengths
+        )
+        
+        # 3. CTC-Att loss definition
+        if self.ctc_weight == 0.0:
+            loss = loss_att
+        elif self.ctc_weight == 1.0:
+            loss = loss_ctc
+        else:
+            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
+        
+        # Collect Attn branch stats
+        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+        stats["acc"] = acc_att
+        stats["cer"] = cer_att
+        stats["wer"] = wer_att
+        
+        # Collect total loss stats
+        stats["loss"] = torch.clone(loss.detach())
+        
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + 1).sum())
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+    
+
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                ind: int
+        """
+        with autocast(False):
+
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+            
+            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                speech, speech_lengths = self.normalize(speech, speech_lengths)
+        
+        # Forward encoder
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+        if self.encoder.interctc_use_conditioning:
+            encoder_out, encoder_out_lens, _ = self.encoder(
+                speech, speech_lengths, ctc=self.ctc
+            )
+        else:
+            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+        
+        if intermediate_outs is not None:
+            return (encoder_out, intermediate_outs), encoder_out_lens
+        
+        return encoder_out, encoder_out_lens
+    
+    def _calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+        ys_in_lens = ys_pad_lens + 1
+        
+        # 1. Forward decoder
+        decoder_out, _ = self.decoder(
+            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+        )
+        
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_out_pad)
+        acc_att = th_accuracy(
+            decoder_out.view(-1, self.vocab_size),
+            ys_out_pad,
+            ignore_label=self.ignore_id,
+        )
+        
+        # Compute cer/wer using attention-decoder
+        if self.training or self.error_calculator is None:
+            cer_att, wer_att = None, None
+        else:
+            ys_hat = decoder_out.argmax(dim=-1)
+            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+        
+        return loss_att, acc_att, cer_att, wer_att
+    
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # Calc CTC loss
+        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+        
+        # Calc CER using CTC
+        cer_ctc = None
+        if not self.training and self.error_calculator is not None:
+            ys_hat = self.ctc.argmax(encoder_out).data
+            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+        return loss_ctc, cer_ctc
+    
+    def init_beam_search(self,
+                         **kwargs,
+                         ):
+        from funasr.models.transformer.search import BeamSearch
+        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+        from funasr.models.transformer.scorers.length_bonus import LengthBonus
+    
+        # 1. Build ASR model
+        scorers = {}
+        
+        if self.ctc != None:
+            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+            scorers.update(
+                ctc=ctc
+            )
+        token_list = kwargs.get("token_list")
+        scorers.update(
+            decoder=self.decoder,
+            length_bonus=LengthBonus(len(token_list)),
+        )
+
+        
+        # 3. Build ngram model
+        # ngram is not supported now
+        ngram = None
+        scorers["ngram"] = ngram
+        
+        weights = dict(
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.5),
+            ctc=kwargs.get("decoding_ctc_weight", 0.5),
+            lm=kwargs.get("lm_weight", 0.0),
+            ngram=kwargs.get("ngram_weight", 0.0),
+            length_bonus=kwargs.get("penalty", 0.0),
+        )
+        beam_search = BeamSearch(
+            beam_size=kwargs.get("beam_size", 10),
+            weights=weights,
+            scorers=scorers,
+            sos=self.sos,
+            eos=self.eos,
+            vocab_size=len(token_list),
+            token_list=token_list,
+            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+        )
+
+        self.beam_search = beam_search
+        
+    def inference(self,
+             data_in,
+             data_lengths=None,
+             key: list=None,
+             tokenizer=None,
+             frontend=None,
+             **kwargs,
+             ):
+        
+        if kwargs.get("batch_size", 1) > 1:
+            raise NotImplementedError("batch decoding is not implemented")
+        
+        # init beamsearch
+        if self.beam_search is None:
+            logging.info("enable beam_search")
+            self.init_beam_search(**kwargs)
+            self.nbest = kwargs.get("nbest", 1)
+
+        meta_data = {}
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+                                                            data_type=kwargs.get("data_type", "sound"),
+                                                            tokenizer=tokenizer)
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+        # Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+        
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
+        )
+        
+        nbest_hyps = nbest_hyps[: self.nbest]
+
+
+        results = []
+        b, n, d = encoder_out.size()
+        for i in range(b):
+
+            for nbest_idx, hyp in enumerate(nbest_hyps):
+                ibest_writer = None
+                if kwargs.get("output_dir") is not None:
+                    if not hasattr(self, "writer"):
+                        self.writer = DatadirWriter(kwargs.get("output_dir"))
+                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+                    
+                # remove sos/eos and get results
+                last_pos = -1
+                if isinstance(hyp.yseq, list):
+                    token_int = hyp.yseq[1:last_pos]
+                else:
+                    token_int = hyp.yseq[1:last_pos].tolist()
+                    
+                # remove blank symbol id, which is assumed to be 0
+                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+                
+                # Change integer-ids to tokens
+                token = tokenizer.ids2tokens(token_int)
+                text = tokenizer.tokens2text(token)
+                
+                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
+                results.append(result_i)
+                
+                if ibest_writer is not None:
+                    ibest_writer["token"][key[i]] = " ".join(token)
+                    ibest_writer["text"][key[i]] = text_postprocessed
+        
+        return results, meta_data
+
+
+@tables.register("model_classes", "OpenAIWhisperLIDModel")
+class OpenAIWhisperLIDModel(nn.Module):
+    """WhisperEncoder and EResNet based LID Model"""
+
+    def __init__(
+            self,
+            vocab_size: int,
+            specaug: str = None,
+            specaug_conf: dict = None,
+            encoder: str = None,
+            encoder_conf: dict = None,
+            lid_predictor:  str = None,
+            lid_predictor_conf: dict = None,
+            proj_dim: int = None,
+            clip_frames: int = None,
+            random_clip: bool = False,
+            **kwargs,
+    ):
+        super().__init__()
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(**encoder_conf)
+        lid_predictor_class = tables.lid_predictor_classes.get(lid_predictor)
+        lid_predictor = lid_predictor_class(**lid_predictor_conf)
+        if encoder.output_size() != proj_dim:
+            self.proj_layer =  torch.nn.Linear(encoder.output_size(), proj_dim)
+        else:
+            self.proj_layer = None
+        self.output_layer = torch.nn.Linear(lid_predictor.output_size(), vocab_size)
+        self.criterion_lid = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=-1,
+            smoothing=0.0,
+            normalize_length=False,
+        )
+
+        self.specaug = specaug
+        self.encoder = encoder
+        self.lid_predictor = lid_predictor
+        self.clip_frames = clip_frames
+        self.random_clip = random_clip
+        self.normalize = None
+        self.beam_search = None
+        if not hasattr(self.encoder, "interctc_use_conditioning"):
+            self.encoder.interctc_use_conditioning = False
+
+    def forward(self,
+                speech: torch.Tensor,  # may be padding
+                speech_lengths: torch.Tensor,  # actual length
+                lid: torch.Tensor,  # lid label, (batch_size, 1)
+                lid_lengths: torch.Tensor,
+                ):
+        assert lid.shape[1] == 1
+        batch_size = speech.shape[0]
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        # re-generate encoder_out
+        if self.clip_frames is None:
+            reduced_encoder_out = torch.zeros(batch_size, encoder_out_lens.max(), encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device)
+            for i, enc_length in enumerate(encoder_out_lens):
+                reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
+        else:
+            reduced_encoder_out = torch.zeros(batch_size, self.clip_frames, encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device)
+            if self.random_clip:
+                for i, enc_length in enumerate(encoder_out_lens):
+                    if enc_length <= self.clip_frames:
+                        reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
+                        encoder_out_lens[i] = enc_length
+                    else:
+                        max_start_index = enc_length.item() - self.clip_frames
+                        start_index = np.random.randint(0, max_start_index + 1)
+                        reduced_encoder_out[i, :self.clip_frames] = encoder_out[i, start_index:start_index + self.clip_frames]
+                        encoder_out_lens[i] = self.clip_frames
+            else:
+                for i, enc_length in enumerate(encoder_out_lens):
+                    enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length
+                    reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
+                    encoder_out_lens[i] = enc_length
+        if self.proj_layer is not None:
+            reduced_encoder_out = self.proj_layer(reduced_encoder_out)
+        lid_output = self.lid_predictor(reduced_encoder_out, encoder_out_lens)  # (B, D)
+        lid_logits = self.output_layer(lid_output)  # (B, num_classes)
+        loss = self.criterion_lid(lid_logits[:, None, :], lid)
+        with torch.no_grad():
+            _, predicted_lid = torch.max(lid_logits, 1)
+            correct = (predicted_lid == lid[:, 0]).sum().item()
+            lid_acc = correct * 1.0 / lid_logits.shape[0]
+        stats = dict()
+        stats["batch_size"] = batch_size
+        stats["loss"] = torch.clone(loss.detach())
+        stats["acc"] = lid_acc
+        stats["token_length"] = speech_lengths.max()
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode(
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+        """
+        with autocast(False):
+
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech = speech.permute(0, 2, 1)
+                # suit for whisper padding
+                padded_speech_lengths = torch.ones_like(speech_lengths) * speech.shape[1]
+                speech, padded_speech_lengths = self.specaug(speech, padded_speech_lengths)
+                speech = speech.permute(0, 2, 1)
+
+            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+        # Forward encoder
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+        if self.encoder.interctc_use_conditioning:
+            encoder_out, encoder_out_lens, _ = self.encoder(
+                speech, speech_lengths, ctc=self.ctc
+            )
+        else:
+            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+
+        if intermediate_outs is not None:
+            return (encoder_out, intermediate_outs), encoder_out_lens
+
+        return encoder_out, encoder_out_lens
+
+    def inference(self,
+                  data_in,
+                  data_lengths=None,
+                  key: list = None,
+                  tokenizer=None,
+                  frontend=None,
+                  **kwargs,
+                  ):
+
+        if kwargs.get("batch_size", 1) > 1:
+            raise NotImplementedError("batch decoding is not implemented")
+
+        meta_data = {}
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+                                                            data_type=kwargs.get("data_type", "sound"),
+                                                            tokenizer=tokenizer)
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+        # Encoder
+        enc, enc_out_lens  = self.encode(speech, speech_lengths)
+
+        inference_clip_length = kwargs.get("inference_clip_length", None)
+        if self.clip_frames is not None:
+            if inference_clip_length is None:
+                reduced_enc = torch.zeros(enc.shape[0], self.clip_frames, enc.shape[-1]).to(enc.dtype).to(enc.device)
+                for i, enc_length in enumerate(enc_out_lens):
+                    enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length
+                    reduced_enc[i, :enc_length] = enc[i, :enc_length]
+                    enc_out_lens[i] = enc_length
+            else:
+                assert inference_clip_length > 0, "inference_clip_length must be larger than 0"
+                reduced_enc = torch.zeros(enc.shape[0], inference_clip_length, enc.shape[-1]).to(enc.dtype).to(enc.device)
+                for i, enc_length in enumerate(enc_out_lens):
+                    enc_length = inference_clip_length if enc_length >= inference_clip_length else enc_length
+                    reduced_enc[i, :enc_length] = enc[i, :enc_length]
+                    enc_out_lens[i] = enc_length
+        else:
+            reduced_enc = torch.zeros(enc.shape[0], enc_out_lens.max(), enc.shape[-1]).to(enc.dtype).to(enc.device)
+            for i, enc_length in enumerate(enc_out_lens):
+                reduced_enc[i, :enc_length] = enc[i, :enc_length]
+
+        if self.proj_layer is not None:
+            reduced_enc = self.proj_layer(reduced_enc)
+        lid_output = self.lid_predictor(reduced_enc, enc_out_lens)  # (B, D)
+        lid_logits = self.output_layer(lid_output)  # (B, num_classes)
+
+        _, predicted_lid_index = torch.max(lid_logits, 1)
+        predicted_lid = tokenizer.ids2tokens([predicted_lid_index[0].cpu()])[0]
+
+        if kwargs.get("output_dir") is not None:
+            if not hasattr(self, "writer"):
+                self.writer = DatadirWriter(kwargs.get("output_dir"))
+            lid_writer = self.writer["lid"]
+            lid_writer[key[0]] = predicted_lid
+
+        results = [{"key": key[0], "lid": predicted_lid}]
+
+        return results, meta_data
\ No newline at end of file

--
Gitblit v1.9.1