jmwang66
2024-02-29 2acd24f0158b2c86d2fb4e6f1134b67a1150500e
update whisper lid (#1407)

* update whisper lid
13 files added, 1711 lines changed
examples/common_voice/whisper_lid/demo_funasr.py 19
examples/common_voice/whisper_lid/demo_modelscope.py 22
funasr/frontends/whisper_frontend.py 102
funasr/models/whisper_lid/__init__.py
funasr/models/whisper_lid/decoder.py 167
funasr/models/whisper_lid/encoder.py 119
funasr/models/whisper_lid/eres2net/ResNet.py 428
funasr/models/whisper_lid/eres2net/__init__.py
funasr/models/whisper_lid/eres2net/fusion.py 29
funasr/models/whisper_lid/eres2net/pooling_layers.py 118
funasr/models/whisper_lid/eres2net/simple_avg.py 17
funasr/models/whisper_lid/lid_predictor.py 25
funasr/models/whisper_lid/model.py 665
examples/common_voice/whisper_lid/demo_funasr.py
New file
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
from funasr import AutoModel
multilingual_wavs = [
    "example_zh-CN.mp3",
    "example_en.mp3",
    "example_ja.mp3",
    "example_ko.mp3",
]
model = AutoModel(model="iic/speech_whisper-large_lid_multilingual_pytorch", model_revision="v2.0.4")
for wav_id in multilingual_wavs:
    wav_file = f"{model.model_path}/examples/{wav_id}"
    res = model.generate(input=wav_file, data_type="sound", inference_clip_length=250)
    print("detect sample {}: {}".format(wav_id, res))
examples/common_voice/whisper_lid/demo_modelscope.py
New file
@@ -0,0 +1,22 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
multilingual_wavs = [
    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_zh-CN.mp3",
    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_en.mp3",
    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ja.mp3",
    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ko.mp3",
]
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='iic/speech_whisper-large_lid_multilingual_pytorch', model_revision="v2.0.4")
for wav in multilingual_wavs:
    rec_result = inference_pipeline(input=wav, inference_clip_length=250)
    print(rec_result)
funasr/frontends/whisper_frontend.py
New file
@@ -0,0 +1,102 @@
from typing import Tuple
import torch
import torch.nn as nn
import whisper
from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
from funasr.register import tables
from torch.nn.utils.rnn import pad_sequence
@tables.register("frontend_classes", "WhisperFrontend")
class WhisperFrontend(nn.Module):
    """Speech Representation Using Encoder Outputs from OpenAI's Whisper Model:
    URL: https://github.com/openai/whisper
    """
    def __init__(
            self,
            fs: int = 16000,
            whisper_model: str = "large-v3",
            do_pad_trim: bool = True,
    ):
        super().__init__()
        assert fs == 16000
        self.fs = fs
        self.n_fft = N_FFT
        self.win_length = N_FFT
        self.hop_length = HOP_LENGTH
        self.pad_samples = N_SAMPLES
        self.frame_shift = self.hop_length
        self.lfr_n = 1
        if whisper_model == "large-v3" or whisper_model == "large":
            self.n_mels = 128
        else:
            self.n_mels = 80
        self.mel_filters = whisper.audio.mel_filters
        self.do_pad_trim = do_pad_trim
        if do_pad_trim:
            self.pad_or_trim = whisper.pad_or_trim
        assert whisper_model in whisper.available_models()
    def output_size(self) -> int:
        return self.n_mels
    def log_mel_spectrogram(
            self,
            audio: torch.Tensor,
            ilens: torch.Tensor = None,
    ) -> torch.Tensor:
        window = torch.hann_window(self.win_length).to(audio.device)
        stft = torch.stft(
            audio, self.n_fft, self.hop_length, window=window, return_complex=True
        )
        # whisper deletes the last frame by default (Shih-Lun)
        magnitudes = stft[..., :-1].abs() ** 2
        filters = self.mel_filters(audio.device, self.n_mels)
        mel_spec = filters @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        if ilens is not None:
            olens = ilens // self.hop_length
        else:
            olens = None
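        # Whisper-style dynamic range compression: floor each utterance at (max - 8) in the log10 domain, then rescale to roughly [-1, 1]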
        log_spec = torch.maximum(
            log_spec,
            log_spec.view(audio.size(0), -1).max(dim=-1)[0][:, None, None] - 8.0,
        )
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec, olens
    def forward(
            self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
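        # compute log-mel features one utterance at a time, then pad to the longest sequence in the batch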
        for i in range(batch_size):
            if self.do_pad_trim:
                feat = self.pad_or_trim(input[i], self.pad_samples)
            else:
                feat = input[i]
            feat, feat_len = self.log_mel_spectrogram(feat[None, :], input_lengths[i])
            feats.append(feat[0])
            feats_lens.append(feat_len)
        feats_lens = torch.as_tensor(feats_lens)
        if batch_size == 1:
            feats_pad = feats[0][None, :, :]
        else:
            feats_pad = pad_sequence(feats,
                                     batch_first=True,
                                     padding_value=0.0)
        return feats_pad, feats_lens
funasr/models/whisper_lid/__init__.py
funasr/models/whisper_lid/decoder.py
New file
@@ -0,0 +1,167 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
import copy
from typing import Any, List, Tuple
import torch
from torch import nn
import whisper
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables
@tables.register("decoder_classes", "OpenAIWhisperDecoderWarp")
class OpenAIWhisperDecoderWarp(nn.Module):
    """Transformer-based Speech-to-Text Decoder from OpenAI's Whisper Model:
    URL: https://github.com/openai/whisper
    """
    def __init__(
        self,
        dropout_rate: float = 0.0,
        whisper_model: str = "small",
        download_dir: str = None,
        use_padmask: bool = False,
    ):
        super().__init__()
        assert whisper_model in whisper.available_models()
        _model = whisper.load_model(
            whisper_model, download_root=download_dir, device="cpu"
        )
        self.decoders = copy.deepcopy(_model.decoder)
        attention_dim = self.decoders.token_embedding.embedding_dim
        # note that originally Whisper doesn't use dropouts
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.decoders.train()
        del _model
        self.use_padmask = use_padmask
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.
        Args:
            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:
            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
        """
        tgt, memory = ys_in_pad, hs_pad
        tgt = (
            self.decoders.token_embedding(tgt)
            + self.decoders.positional_embedding[: tgt.size(1)]
        )
        tgt = self.dropout(tgt)
        x = tgt.to(memory.dtype)
        if self.use_padmask:
            memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
        else:
            memory_mask = None
        for layer, block in enumerate(self.decoders.blocks):
            x = block(x, memory, mask=self.decoders.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
            if layer < len(self.decoders.blocks) - 1:
                x = self.dropout(x)
        x = self.decoders.ln(x)
        x = (
            x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()
        return x, ys_in_lens
    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.
        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask,  (batch, maxlen_out)
                      dtype=torch.uint8 before PyTorch 1.2,
                      dtype=torch.bool in PyTorch 1.2 and later
            memory: encoded memory, float32  (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
            `y.shape` is (batch, maxlen_out, token)
        NOTE (Shih-Lun):
            cache implementation is ignored for now
            for simplicity & correctness
        """
        x = (
            self.decoders.token_embedding(tgt)
            + self.decoders.positional_embedding[: tgt.size(1)]
        )
        x = self.dropout(x)
        x = x.to(memory.dtype)
        for layer, block in enumerate(self.decoders.blocks):
            x = block(x, memory, mask=self.decoders.mask)
            if layer < len(self.decoders.blocks) - 1:
                x = self.dropout(x)
        x = self.decoders.ln(x)
        y = x[:, -1]
        y = (
            y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()
        y = torch.log_softmax(y, dim=-1)
        return y, None
    def score(self, ys, state, x):
        """Score."""
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), torch.empty(0), x.unsqueeze(0), cache=state  # dummy mask
        )
        return logp.squeeze(0), state
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.
        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).
        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # batch decoding, dummy mask is passed
        logp, states = self.forward_one_step(ys, torch.empty(0), xs, cache=None)
        return logp, None
funasr/models/whisper_lid/encoder.py
New file
@@ -0,0 +1,119 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
import copy
from typing import Optional, Tuple, Union
import torch
from torch import nn
import torch.nn.functional as F
import whisper
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.specaug.specaug import SpecAug
from funasr.register import tables
@tables.register("encoder_classes", "OpenAIWhisperEncoderWarp")
class OpenAIWhisperEncoderWarp(nn.Module):
    """Transformer-based Speech Encoder from OpenAI's Whisper Model:
    URL: https://github.com/openai/whisper
    """
    def __init__(
            self,
            dropout_rate: float = 0.0,
            whisper_model: str = "small",
            download_dir: str = None,
            use_specaug: bool = False,
            use_padmask: bool = False,
            specaug_conf: Union[dict, None] = None,
    ):
        super().__init__()
        # note that originally Whisper doesn't use dropouts
        self.dropout = torch.nn.Dropout(dropout_rate)
        assert whisper_model in whisper.available_models()
        _model = whisper.load_model(
            whisper_model, download_root=download_dir, device="cpu"
        )
        self.encoders = copy.deepcopy(_model.encoder)
        self.encoders.train()
        del _model
        if use_specaug:
            self.specaug = SpecAug(**specaug_conf)
        else:
            self.specaug = None
        self.use_padmask = use_padmask
    def whisper_encode(
            self,
            input: torch.Tensor,
            ilens: torch.Tensor = None,
    ) -> torch.Tensor:
        x = F.gelu(self.encoders.conv1(input))
        x = F.gelu(self.encoders.conv2(x))
        x = x.permute(0, 2, 1)
        n_frames = x.size(1)
        max_pos = self.encoders.positional_embedding.size(0)
        if n_frames <= max_pos:
            x = (x + self.encoders.positional_embedding[: x.size(1), :]).to(x.dtype)
        else:
            # the positional embedding only covers 30 s of audio, so longer inputs are truncated to max_pos frames
            x = x[:, :max_pos, :] + self.encoders.positional_embedding
        if ilens is not None:
            olens = (
                    1
                    + (
                            ilens
                            - self.encoders.conv2.kernel_size[0]
                            + 2 * self.encoders.conv2.padding[0]
                    )
                    // self.encoders.conv2.stride[0]
            )
            olens = torch.clamp(olens, max=max_pos)
        else:
            olens = None
        if self.use_padmask:
            padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
        else:
            padding_mask = None
        x = self.dropout(x)
        for layer, block in enumerate(self.encoders.blocks):
            x = block(x)
            if layer < len(self.encoders.blocks) - 1:
                x = self.dropout(x)
        x = self.encoders.ln_post(x)
        return x, olens
    def output_size(self) -> int:
        # encoder hidden dimension (d_model), read from conv2's output channels
        return self.encoders.conv2.weight.shape[0]
    def forward(
            self,
            xs_pad: torch.Tensor,
            ilens: torch.Tensor,
            prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        feats, feats_lens = xs_pad, ilens
        if self.specaug is not None and self.encoders.training:
            feats = torch.transpose(feats, 1, 2)
            feats, feats_lens = self.specaug(feats, feats_lens)
            feats = torch.transpose(feats, 1, 2)
        xs_pad, olens = self.whisper_encode(feats, feats_lens)
        return xs_pad, olens, None
funasr/models/whisper_lid/eres2net/ResNet.py
New file
@@ -0,0 +1,428 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
    ERes2Net incorporates both local and global feature fusion techniques to improve performance.
    Local feature fusion (LFF) fuses features within a single residual block to extract the local signal.
    Global feature fusion (GFF) takes acoustic features at different scales as input to aggregate the global signal.
    ERes2Net-Large is an upgraded version of ERes2Net that uses more parameters to achieve better recognition
    performance. The expansion, baseWidth, and scale parameters can be tuned for optimal performance.
"""
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import funasr.models.whisper_lid.eres2net.pooling_layers as pooling_layers
from funasr.models.whisper_lid.eres2net.fusion import AFF
class ReLU(nn.Hardtanh):
    def __init__(self, inplace=False):
        super(ReLU, self).__init__(0, 20, inplace)
    def __repr__(self):
        inplace_str = 'inplace' if self.inplace else ''
        return self.__class__.__name__ + ' (' \
               + inplace_str + ')'
def conv1x1(in_planes, out_planes, stride=1):
    "1x1 convolution without padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                     padding=0, bias=False)
def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
class BasicBlockERes2Net(nn.Module):
    expansion = 2
    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
        super(BasicBlockERes2Net, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = conv1x1(in_planes, width * scale, stride)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale
        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(conv3x3(width, width))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.relu = ReLU(inplace=True)
        self.conv3 = conv1x1(width * scale, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes,
                          self.expansion * planes,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
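        # Res2Net-style hierarchical residual: each split is added to the previous branch's output before its own 3x3 conv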
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        out = self.conv3(out)
        out = self.bn3(out)
        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)
        return out
class BasicBlockERes2Net_diff_AFF(nn.Module):
    expansion = 2
    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
        super(BasicBlockERes2Net_diff_AFF, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = conv1x1(in_planes, width * scale, stride)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale
        convs = []
        fuse_models = []
        bns = []
        for i in range(self.nums):
            convs.append(conv3x3(width, width))
            bns.append(nn.BatchNorm2d(width))
        for j in range(self.nums - 1):
            fuse_models.append(AFF(channels=width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.fuse_models = nn.ModuleList(fuse_models)
        self.relu = ReLU(inplace=True)
        self.conv3 = conv1x1(width * scale, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes,
                          self.expansion * planes,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = self.fuse_models[i - 1](sp, spx[i])
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        out = self.conv3(out)
        out = self.bn3(out)
        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)
        return out
class ERes2Net(nn.Module):
    def __init__(self,
                 block=BasicBlockERes2Net,
                 block_fuse=BasicBlockERes2Net_diff_AFF,
                 num_blocks=[3, 4, 6, 3],
                 m_channels=32,
                 feat_dim=80,
                 embedding_size=192,
                 pooling_func='TSTP',
                 two_emb_layer=False):
        super(ERes2Net, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embedding_size = embedding_size
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer
        self._output_size = embedding_size
        self.conv1 = nn.Conv2d(1,
                               m_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block,
                                       m_channels,
                                       num_blocks[0],
                                       stride=1)
        self.layer2 = self._make_layer(block,
                                       m_channels * 2,
                                       num_blocks[1],
                                       stride=2)
        self.layer3 = self._make_layer(block_fuse,
                                       m_channels * 4,
                                       num_blocks[2],
                                       stride=2)
        self.layer4 = self._make_layer(block_fuse,
                                       m_channels * 8,
                                       num_blocks[3],
                                       stride=2)
        # Downsampling module for each layer
        self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1,
                                           bias=False)
        self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2,
                                           bias=False)
        self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2,
                                           bias=False)
        # Bottom-up fusion module
        self.fuse_mode12 = AFF(channels=m_channels * 4)
        self.fuse_mode123 = AFF(channels=m_channels * 8)
        self.fuse_mode1234 = AFF(channels=m_channels * 16)
        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(
            in_dim=self.stats_dim * block.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
                               embedding_size)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
            self.seg_2 = nn.Linear(embedding_size, embedding_size)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
    def output_size(self) -> int:
        return self._output_size
    def forward(self, x, ilens):
        # assert x.shape[1] == ilens.max()
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
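        # layers 2-4 each halve the time axis with a stride-2 stage: out_len = (in_len - 1) // 2 + 1 per stage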
        olens = (((((ilens - 1) // 2 + 1) - 1) // 2 + 1) - 1) // 2 + 1
        stats = self.pool(fuse_out1234, olens)
        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a
class BasicBlockRes2Net(nn.Module):
    expansion = 2
    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
        super(BasicBlockRes2Net, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = conv1x1(in_planes, width * scale, stride)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale - 1
        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(conv3x3(width, width))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.relu = ReLU(inplace=True)
        self.conv3 = conv1x1(width * scale, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes,
                          self.expansion * planes,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        out = torch.cat((out, spx[self.nums]), 1)
        out = self.conv3(out)
        out = self.bn3(out)
        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)
        return out
class Res2Net(nn.Module):
    def __init__(self,
                 block=BasicBlockRes2Net,
                 num_blocks=[3, 4, 6, 3],
                 m_channels=32,
                 feat_dim=80,
                 embedding_size=192,
                 pooling_func='TSTP',
                 two_emb_layer=False):
        super(Res2Net, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embedding_size = embedding_size
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer
        self.conv1 = nn.Conv2d(1,
                               m_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block,
                                       m_channels,
                                       num_blocks[0],
                                       stride=1)
        self.layer2 = self._make_layer(block,
                                       m_channels * 2,
                                       num_blocks[1],
                                       stride=2)
        self.layer3 = self._make_layer(block,
                                       m_channels * 4,
                                       num_blocks[2],
                                       stride=2)
        self.layer4 = self._make_layer(block,
                                       m_channels * 8,
                                       num_blocks[3],
                                       stride=2)
        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(
            in_dim=self.stats_dim * block.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
                               embedding_size)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
            self.seg_2 = nn.Linear(embedding_size, embedding_size)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        stats = self.pool(out)
        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a
funasr/models/whisper_lid/eres2net/__init__.py
funasr/models/whisper_lid/eres2net/fusion.py
New file
@@ -0,0 +1,29 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
import torch.nn as nn
class AFF(nn.Module):
    def __init__(self, channels=64, r=4):
        super(AFF, self).__init__()
        inter_channels = int(channels // r)
        self.local_att = nn.Sequential(
            nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )
    def forward(self, x, ds_y):
        xa = torch.cat((x, ds_y), dim=1)
        x_att = self.local_att(xa)
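        # fusion weights lie in (0, 2): x is weighted by x_att and ds_y by (2 - x_att), so the pair of weights sums to 2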
        x_att = 1.0 + torch.tanh(x_att)
        xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0-x_att)
        return xo
funasr/models/whisper_lid/eres2net/pooling_layers.py
New file
@@ -0,0 +1,118 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
""" This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
import torch
import torch.nn as nn
from funasr.models.transformer.utils.nets_utils import make_pad_mask
class TAP(nn.Module):
    """
    Temporal average pooling, only first-order mean is considered
    """
    def __init__(self, **kwargs):
        super(TAP, self).__init__()
    def forward(self, x):
        pooling_mean = x.mean(dim=-1)
        # To be compatible with 2D input
        pooling_mean = pooling_mean.flatten(start_dim=1)
        return pooling_mean
class TSDP(nn.Module):
    """
    Temporal standard deviation pooling, only second-order std is considered
    """
    def __init__(self, **kwargs):
        super(TSDP, self).__init__()
    def forward(self, x):
        # The last dimension is the temporal axis
        pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
        pooling_std = pooling_std.flatten(start_dim=1)
        return pooling_std
class TSTP(nn.Module):
    """
    Temporal statistics pooling, concatenate mean and std, which is used in
    x-vector
    Comment: simple concatenation cannot make full use of both statistics
    """
    def __init__(self, **kwargs):
        super(TSTP, self).__init__()
    def forward(self, x, olens):
        # The last dimension is the temporal axis
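        # build a (B, 1, 1, T) mask from olens so padded frames are excluded from the mean and std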
        masks = (~make_pad_mask(olens, maxlen=x.shape[-1])[:, None, None, :]).to(x.device)
        x_masked = x * masks
        sum_without_padding = torch.sum(x_masked, axis=-1)
        count_without_padding = torch.sum(masks, axis=-1)
        mean_without_padding = sum_without_padding / count_without_padding
        var_without_padding = ((x_masked - mean_without_padding.unsqueeze(-1)) ** 2 * masks).sum(-1) / count_without_padding
        pooling_mean = mean_without_padding
        pooling_std = torch.sqrt(var_without_padding + 1e-8)
        pooling_mean = pooling_mean.flatten(start_dim=1)
        pooling_std = pooling_std.flatten(start_dim=1)
        stats = torch.cat((pooling_mean, pooling_std), 1)
        return stats
class ASTP(nn.Module):
    """ Attentive statistics pooling: Channel- and context-dependent
        statistics pooling, first used in ECAPA_TDNN.
    """
    def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
        super(ASTP, self).__init__()
        self.global_context_att = global_context_att
        # Use Conv1d with stride == 1 rather than Linear, then we don't
        # need to transpose inputs.
        if global_context_att:
            self.linear1 = nn.Conv1d(
                in_dim * 3, bottleneck_dim,
                kernel_size=1)  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(
                in_dim, bottleneck_dim,
                kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
                                 kernel_size=1)  # equals V and k in the paper
    def forward(self, x):
        """
        x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
            or a 4-dimensional tensor in resnet architecture (B,C,F,T)
            0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
        """
        if len(x.shape) == 4:
            x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
        assert len(x.shape) == 3
        if self.global_context_att:
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            context_std = torch.sqrt(
                torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x
        # DON'T use ReLU here! With ReLU, the attention weights may be hard to train to convergence.
        alpha = torch.tanh(
            self.linear1(x_in))  # alpha = F.relu(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        var = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
        std = torch.sqrt(var.clamp(min=1e-10))
        return torch.cat([mean, std], dim=1)
funasr/models/whisper_lid/eres2net/simple_avg.py
New file
@@ -0,0 +1,17 @@
import torch
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.transformer.utils.nets_utils import make_pad_mask
class SimpleAvg(AbsEncoder):
    def __init__(self, feat_dim):
        super(SimpleAvg, self).__init__()
        self.feat_dim = feat_dim
    def forward(self, x, ilens):
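        # masked temporal average: padded frames contribute to neither the sum nor the count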
        mask = ~make_pad_mask(ilens, maxlen=x.shape[1]).to(x.device)
        avg_x = (x * mask[:, :, None]).sum(1) / mask.sum(-1)[:, None]
        return avg_x
    def output_size(self) -> int:
        return self.feat_dim
funasr/models/whisper_lid/lid_predictor.py
New file
@@ -0,0 +1,25 @@
from funasr.register import tables
from funasr.models.whisper_lid.eres2net.ResNet import ERes2Net, BasicBlockERes2Net, BasicBlockERes2Net_diff_AFF
@tables.register("lid_predictor_classes", "LidPredictor")
class LidPredictor(ERes2Net):
    def __init__(self,
                 block=BasicBlockERes2Net,
                 block_fuse=BasicBlockERes2Net_diff_AFF,
                 num_blocks=[3, 4, 6, 3],
                 m_channels=32,
                 feat_dim=80,
                 embedding_size=192,
                 pooling_func='TSTP',
                 two_emb_layer=False):
        super(LidPredictor, self).__init__(
                block=block,
                block_fuse=block_fuse,
                num_blocks=num_blocks,
                m_channels=m_channels,
                feat_dim=feat_dim,
                embedding_size=embedding_size,
                pooling_func=pooling_func,
                two_emb_layer=two_emb_layer
        )
funasr/models/whisper_lid/model.py
New file
@@ -0,0 +1,665 @@
import logging
from typing import Union, Dict, List, Tuple, Optional
import time
import torch
import numpy as np
import torch.nn as nn
from torch.cuda.amp import autocast
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
@tables.register("model_classes", "OpenAIWhisperModel")
class OpenAIWhisperModel(nn.Module):
    """CTC-attention hybrid Encoder-Decoder model"""
    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        # extract_feats_in_collect_stats: bool = True,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        **kwargs,
    ):
        super().__init__()
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()
        if decoder is not None:
            decoder_class = tables.decoder_classes.get(decoder)
            decoder = decoder_class(**decoder_conf)
        if ctc_weight > 0.0:
            if ctc_conf is None:
                ctc_conf = {}
            ctc = CTC(
                odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
            )
        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )
        self.interctc_weight = interctc_weight
        # self.error_calculator = None
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        #
        # if report_cer or report_wer:
        #     self.error_calculator = ErrorCalculator(
        #         token_list, sym_space, sym_blank, report_cer, report_wer
        #     )
        #
        self.error_calculator = None
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.share_embedding = share_embedding
        if self.share_embedding:
            self.decoder.embed = None
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()
        # decoder: CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc
        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic
                # Collect intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
            loss_interctc = loss_interctc / len(intermediate_outs)
            # calculate whole encoder loss
            loss_ctc = (
                           1 - self.interctc_weight
                       ) * loss_ctc + self.interctc_weight * loss_interctc
        # decoder: Attention decoder branch
        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )
        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)
        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                speech, speech_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens
    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
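        # prepend <sos> to the decoder inputs and append <eos> to the targets, so input lengths grow by one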
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1
        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att
    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc
    def init_beam_search(self,
                         **kwargs,
                         ):
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus
        # 1. Build ASR model
        scorers = {}
        if self.ctc is not None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(
                ctc=ctc
            )
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=self.decoder,
            length_bonus=LengthBonus(len(token_list)),
        )
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.5),
            ctc=kwargs.get("decoding_ctc_weight", 0.5),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 10),
            weights=weights,
            scorers=scorers,
            sos=self.sos,
            eos=self.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
        )
        self.beam_search = beam_search
    def inference(self,
             data_in,
             data_lengths=None,
             key: list=None,
             tokenizer=None,
             frontend=None,
             **kwargs,
             ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        # init beamsearch
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64)
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                            data_type=kwargs.get("data_type", "sound"),
                                                            tokenizer=tokenizer)
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend)
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        # c. Pass the encoder result to the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
        )
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        b, n, d = encoder_out.size()
        for i in range(b):
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)
                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
                results.append(result_i)
                if ibest_writer is not None:
                    ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text_postprocessed
        return results, meta_data
@tables.register("model_classes", "OpenAIWhisperLIDModel")
class OpenAIWhisperLIDModel(nn.Module):
    """WhisperEncoder and EResNet based LID Model"""
    def __init__(
            self,
            vocab_size: int,
            specaug: str = None,
            specaug_conf: dict = None,
            encoder: str = None,
            encoder_conf: dict = None,
            lid_predictor:  str = None,
            lid_predictor_conf: dict = None,
            proj_dim: int = None,
            clip_frames: int = None,
            random_clip: bool = False,
            **kwargs,
    ):
        super().__init__()
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(**encoder_conf)
        lid_predictor_class = tables.lid_predictor_classes.get(lid_predictor)
        lid_predictor = lid_predictor_class(**lid_predictor_conf)
        if encoder.output_size() != proj_dim:
            self.proj_layer =  torch.nn.Linear(encoder.output_size(), proj_dim)
        else:
            self.proj_layer = None
        self.output_layer = torch.nn.Linear(lid_predictor.output_size(), vocab_size)
        self.criterion_lid = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=-1,
            smoothing=0.0,
            normalize_length=False,
        )
        self.specaug = specaug
        self.encoder = encoder
        self.lid_predictor = lid_predictor
        self.clip_frames = clip_frames
        self.random_clip = random_clip
        self.normalize = None
        self.beam_search = None
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
    def forward(self,
                speech: torch.Tensor,  # may be padded
                speech_lengths: torch.Tensor,  # actual length
                lid: torch.Tensor,  # lid label, (batch_size, 1)
                lid_lengths: torch.Tensor,
                ):
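        """LID training forward: encode speech, clip/pool the encoder frames, and compute the classification loss and accuracy."""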
        assert lid.shape[1] == 1
        batch_size = speech.shape[0]
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        # re-generate encoder_out: zero out padded frames and optionally clip to clip_frames
        if self.clip_frames is None:
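            # no clipping: copy each utterance up to its valid length, leaving padded frames zeroed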
            reduced_encoder_out = torch.zeros(batch_size, encoder_out_lens.max(), encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device)
            for i, enc_length in enumerate(encoder_out_lens):
                reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
        else:
            reduced_encoder_out = torch.zeros(batch_size, self.clip_frames, encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device)
            if self.random_clip:
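                # sample a random clip_frames-long window from utterances longer than clip_frames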
                for i, enc_length in enumerate(encoder_out_lens):
                    if enc_length <= self.clip_frames:
                        reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
                        encoder_out_lens[i] = enc_length
                    else:
                        max_start_index = enc_length.item() - self.clip_frames
                        start_index = np.random.randint(0, max_start_index + 1)
                        reduced_encoder_out[i, :self.clip_frames] = encoder_out[i, start_index:start_index + self.clip_frames]
                        encoder_out_lens[i] = self.clip_frames
            else:
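                # keep only the leading clip_frames frames of longer utterances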
                for i, enc_length in enumerate(encoder_out_lens):
                    enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length
                    reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
                    encoder_out_lens[i] = enc_length
        if self.proj_layer is not None:
            reduced_encoder_out = self.proj_layer(reduced_encoder_out)
        lid_output = self.lid_predictor(reduced_encoder_out, encoder_out_lens)  # (B, D)
        lid_logits = self.output_layer(lid_output)  # (B, num_classes)
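        # LabelSmoothingLoss expects (batch, time, classes) logits and (batch, time) targets, so add a length-1 time axis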
        loss = self.criterion_lid(lid_logits[:, None, :], lid)
        with torch.no_grad():
            _, predicted_lid = torch.max(lid_logits, 1)
            correct = (predicted_lid == lid[:, 0]).sum().item()
            lid_acc = correct * 1.0 / lid_logits.shape[0]
        stats = dict()
        stats["batch_size"] = batch_size
        stats["loss"] = torch.clone(loss.detach())
        stats["acc"] = lid_acc
        stats["token_length"] = speech_lengths.max()
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def encode(
            self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech = speech.permute(0, 2, 1)
                # Whisper features are padded to a fixed length, so run SpecAug over the full padded length
                padded_speech_lengths = torch.ones_like(speech_lengths) * speech.shape[1]
                speech, padded_speech_lengths = self.specaug(speech, padded_speech_lengths)
                speech = speech.permute(0, 2, 1)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)
        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
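        # note: this model defines no self.ctc, so the branch below assumes the encoder does not enable interctc_use_conditioning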
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                speech, speech_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens
    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):
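        """Predict the language of a single utterance.

        data_in may be an audio path/URL (decoded via the frontend) or a precomputed fbank tensor
        (data_type="fbank"). The encoder output is optionally clipped to inference_clip_length
        (falling back to clip_frames) before LID prediction; results are returned as
        [{"key": ..., "lid": ...}].
        """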
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.new_full((speech.shape[0],), speech.shape[1], dtype=torch.int64)  # keep lengths as a tensor for the later .to(device) call
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                            data_type=kwargs.get("data_type", "sound"),
                                                            tokenizer=tokenizer)
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend)
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # Encoder
        enc, enc_out_lens = self.encode(speech, speech_lengths)
        inference_clip_length = kwargs.get("inference_clip_length", None)
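        # when clipping is enabled, inference_clip_length (if given) overrides the training-time clip_frames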
        if self.clip_frames is not None:
            if inference_clip_length is None:
                reduced_enc = torch.zeros(enc.shape[0], self.clip_frames, enc.shape[-1]).to(enc.dtype).to(enc.device)
                for i, enc_length in enumerate(enc_out_lens):
                    enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length
                    reduced_enc[i, :enc_length] = enc[i, :enc_length]
                    enc_out_lens[i] = enc_length
            else:
                assert inference_clip_length > 0, "inference_clip_length must be larger than 0"
                reduced_enc = torch.zeros(enc.shape[0], inference_clip_length, enc.shape[-1]).to(enc.dtype).to(enc.device)
                for i, enc_length in enumerate(enc_out_lens):
                    enc_length = inference_clip_length if enc_length >= inference_clip_length else enc_length
                    reduced_enc[i, :enc_length] = enc[i, :enc_length]
                    enc_out_lens[i] = enc_length
        else:
            reduced_enc = torch.zeros(enc.shape[0], enc_out_lens.max(), enc.shape[-1]).to(enc.dtype).to(enc.device)
            for i, enc_length in enumerate(enc_out_lens):
                reduced_enc[i, :enc_length] = enc[i, :enc_length]
        if self.proj_layer is not None:
            reduced_enc = self.proj_layer(reduced_enc)
        lid_output = self.lid_predictor(reduced_enc, enc_out_lens)  # (B, D)
        lid_logits = self.output_layer(lid_output)  # (B, num_classes)
        _, predicted_lid_index = torch.max(lid_logits, 1)
        predicted_lid = tokenizer.ids2tokens([predicted_lid_index[0].cpu()])[0]
        if kwargs.get("output_dir") is not None:
            if not hasattr(self, "writer"):
                self.writer = DatadirWriter(kwargs.get("output_dir"))
            lid_writer = self.writer["lid"]
            lid_writer[key[0]] = predicted_lid
        results = [{"key": key[0], "lid": predicted_lid}]
        return results, meta_data