From 2acd24f0158b2c86d2fb4e6f1134b67a1150500e Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 29 二月 2024 17:14:59 +0800
Subject: [PATCH] update whisper lid (#1407)
---
funasr/models/whisper_lid/lid_predictor.py | 25
funasr/models/whisper_lid/decoder.py | 167 +++++
funasr/models/whisper_lid/eres2net/simple_avg.py | 17
funasr/models/whisper_lid/encoder.py | 119 ++++
funasr/models/whisper_lid/__init__.py | 0
funasr/models/whisper_lid/eres2net/fusion.py | 29 +
funasr/models/whisper_lid/eres2net/ResNet.py | 428 ++++++++++++++
examples/common_voice/whisper_lid/demo_funasr.py | 19
examples/common_voice/whisper_lid/demo_modelscope.py | 22
funasr/models/whisper_lid/eres2net/pooling_layers.py | 118 ++++
funasr/models/whisper_lid/model.py | 665 ++++++++++++++++++++++
funasr/models/whisper_lid/eres2net/__init__.py | 0
funasr/frontends/whisper_frontend.py | 102 +++
13 files changed, 1,711 insertions(+), 0 deletions(-)
diff --git a/examples/common_voice/whisper_lid/demo_funasr.py b/examples/common_voice/whisper_lid/demo_funasr.py
new file mode 100644
index 0000000..9af790e
--- /dev/null
+++ b/examples/common_voice/whisper_lid/demo_funasr.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+multilingual_wavs = [
+ "example_zh-CN.mp3",
+ "example_en.mp3",
+ "example_ja.mp3",
+ "example_ko.mp3",
+]
+
+model = AutoModel(model="iic/speech_whisper-large_lid_multilingual_pytorch", model_revision="v2.0.4")
+for wav_id in multilingual_wavs:
+ wav_file = f"{model.model_path}/examples/{wav_id}"
+ res = model.generate(input=wav_file, data_type="sound", inference_clip_length=250)
+ print("detect sample {}: {}".format(wav_id, res))
\ No newline at end of file
diff --git a/examples/common_voice/whisper_lid/demo_modelscope.py b/examples/common_voice/whisper_lid/demo_modelscope.py
new file mode 100644
index 0000000..cce389a
--- /dev/null
+++ b/examples/common_voice/whisper_lid/demo_modelscope.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+multilingual_wavs=[
+ "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_zh-CN.mp3",
+ "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_en.mp3",
+ "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ja.mp3",
+ "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ko.mp3",
+]
+
+inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model='iic/speech_whisper-large_lid_multilingual_pytorch', model_revision="v2.0.4")
+
+for wav in multilingual_wavs:
+ rec_result = inference_pipeline(input=wav, inference_clip_length=250)
+ print(rec_result)
\ No newline at end of file
diff --git a/funasr/frontends/whisper_frontend.py b/funasr/frontends/whisper_frontend.py
new file mode 100644
index 0000000..752fd20
--- /dev/null
+++ b/funasr/frontends/whisper_frontend.py
@@ -0,0 +1,102 @@
+from typing import Tuple
+import torch
+import torch.nn as nn
+import whisper
+from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
+from funasr.register import tables
+from torch.nn.utils.rnn import pad_sequence
+
+
+@tables.register("frontend_classes", "WhisperFrontend")
+class WhisperFrontend(nn.Module):
+ """Speech Representation Using Encoder Outputs from OpenAI's Whisper Model:
+
+ URL: https://github.com/openai/whisper
+ """
+
+ def __init__(
+ self,
+ fs: int = 16000,
+ whisper_model: str = "large-v3",
+ do_pad_trim: bool = True,
+ ):
+ super().__init__()
+ assert fs == 16000
+ self.fs = fs
+
+ self.n_fft = N_FFT
+ self.win_length = N_FFT
+ self.hop_length = HOP_LENGTH
+ self.pad_samples = N_SAMPLES
+ self.frame_shift = self.hop_length
+ self.lfr_n = 1
+ if whisper_model == "large-v3" or whisper_model == "large":
+ self.n_mels = 128
+ else:
+ self.n_mels = 80
+
+ self.mel_filters = whisper.audio.mel_filters
+ self.do_pad_trim = do_pad_trim
+ if do_pad_trim:
+ self.pad_or_trim = whisper.pad_or_trim
+
+ assert whisper_model in whisper.available_models()
+
+ def output_size(self) -> int:
+ return self.n_mels
+
+ def log_mel_spectrogram(
+ self,
+ audio: torch.Tensor,
+ ilens: torch.Tensor = None,
+ ) -> torch.Tensor:
+ window = torch.hann_window(self.win_length).to(audio.device)
+ stft = torch.stft(
+ audio, self.n_fft, self.hop_length, window=window, return_complex=True
+ )
+
+ # whisper deletes the last frame by default (Shih-Lun)
+ magnitudes = stft[..., :-1].abs() ** 2
+
+ filters = self.mel_filters(audio.device, self.n_mels)
+ mel_spec = filters @ magnitudes
+
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+
+ if ilens is not None:
+ olens = ilens // self.hop_length
+ else:
+ olens = None
+
+ log_spec = torch.maximum(
+ log_spec,
+ log_spec.view(audio.size(0), -1).max(dim=-1)[0][:, None, None] - 8.0,
+ )
+ log_spec = (log_spec + 4.0) / 4.0
+
+ return log_spec, olens
+
+ def forward(
+ self, input: torch.Tensor, input_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ batch_size = input.size(0)
+ feats = []
+ feats_lens = []
+ for i in range(batch_size):
+ if self.do_pad_trim:
+ feat = self.pad_or_trim(input[i], self.pad_samples)
+ else:
+ feat = input[i]
+ feat, feat_len = self.log_mel_spectrogram(feat[None, :], input_lengths[0])
+ feats.append(feat[0])
+ feats_lens.append(feat_len)
+ feats_lens = torch.as_tensor(feats_lens)
+
+ if batch_size == 1:
+ feats_pad = feats[0][None, :, :]
+ else:
+ feats_pad = pad_sequence(feats,
+ batch_first=True,
+ padding_value=0.0)
+
+ return feats_pad, feats_lens
\ No newline at end of file
diff --git a/funasr/models/whisper_lid/__init__.py b/funasr/models/whisper_lid/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/whisper_lid/__init__.py
diff --git a/funasr/models/whisper_lid/decoder.py b/funasr/models/whisper_lid/decoder.py
new file mode 100644
index 0000000..4db9205
--- /dev/null
+++ b/funasr/models/whisper_lid/decoder.py
@@ -0,0 +1,167 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import copy
+from typing import Any, List, Tuple
+
+import torch
+from torch import nn
+import whisper
+
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.register import tables
+
+
+@tables.register("decoder_classes", "OpenAIWhisperDecoderWarp")
+class OpenAIWhisperDecoderWarp(nn.Module):
+ """Transformer-based Speech-to-Text Decoder from OpenAI's Whisper Model:
+
+ URL: https://github.com/openai/whisper
+ """
+
+ def __init__(
+ self,
+ dropout_rate: float = 0.0,
+ whisper_model: str = "small",
+ download_dir: str = None,
+ use_padmask: bool = False,
+ ):
+ super().__init__()
+
+ assert whisper_model in whisper.available_models()
+ _model = whisper.load_model(
+ whisper_model, download_root=download_dir, device="cpu"
+ )
+ self.decoders = copy.deepcopy(_model.decoder)
+ attention_dim = self.decoders.token_embedding.embedding_dim
+
+ # note that originally Whisper doesn't use dropouts
+ self.dropout = torch.nn.Dropout(dropout_rate)
+
+ self.decoders.train()
+ del _model
+ self.use_padmask = use_padmask
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward decoder.
+
+ Args:
+ hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
+ hlens: (batch)
+ ys_in_pad:
+ input token ids, int64 (batch, maxlen_out)
+ if input_layer == "embed"
+ input tensor (batch, maxlen_out, #mels) in the other cases
+ ys_in_lens: (batch)
+ Returns:
+ (tuple): tuple containing:
+
+ x: decoded token score before softmax (batch, maxlen_out, token)
+ if use_output_layer is True,
+ olens: (batch, )
+ """
+ tgt, memory = ys_in_pad, hs_pad
+ tgt = (
+ self.decoders.token_embedding(tgt)
+ + self.decoders.positional_embedding[: tgt.size(1)]
+ )
+ tgt = self.dropout(tgt)
+
+ x = tgt.to(memory.dtype)
+
+ if self.use_padmask:
+ memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
+ else:
+ memory_mask = None
+
+ for layer, block in enumerate(self.decoders.blocks):
+ x = block(x, memory, mask=self.decoders.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
+
+ if layer < len(self.decoders.blocks) - 1:
+ x = self.dropout(x)
+
+ x = self.decoders.ln(x)
+ x = (
+ x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
+ ).float()
+
+ return x, ys_in_lens
+
+ def forward_one_step(
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ cache: List[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward one step.
+
+ Args:
+ tgt: input token ids, int64 (batch, maxlen_out)
+ tgt_mask: input token mask, (batch, maxlen_out)
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ cache: cached output list of (batch, max_time_out-1, size)
+ Returns:
+ y, cache: NN output value and cache per `self.decoders`.
+ y.shape` is (batch, maxlen_out, token)
+ NOTE (Shih-Lun):
+ cache implementation is ignored for now
+ for simplicity & correctness
+ """
+ x = (
+ self.decoders.token_embedding(tgt)
+ + self.decoders.positional_embedding[: tgt.size(1)]
+ )
+ x = self.dropout(x)
+ x = x.to(memory.dtype)
+
+ for layer, block in enumerate(self.decoders.blocks):
+ x = block(x, memory, mask=self.decoders.mask)
+ if layer < len(self.decoders.blocks) - 1:
+ x = self.dropout(x)
+
+ x = self.decoders.ln(x)
+ y = x[:, -1]
+ y = (
+ y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
+ ).float()
+ y = torch.log_softmax(y, dim=-1)
+
+ return y, None
+
+ def score(self, ys, state, x):
+ """Score."""
+ logp, state = self.forward_one_step(
+ ys.unsqueeze(0), torch.empty(0), x.unsqueeze(0), cache=state # dummy mask
+ )
+ return logp.squeeze(0), state
+
+ def batch_score(
+ self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
+ ) -> Tuple[torch.Tensor, List[Any]]:
+ """Score new token batch.
+
+ Args:
+ ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
+ states (List[Any]): Scorer states for prefix tokens.
+ xs (torch.Tensor):
+ The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+ Returns:
+ tuple[torch.Tensor, List[Any]]: Tuple of
+ batchfied scores for next token with shape of `(n_batch, n_vocab)`
+ and next state list for ys.
+
+ """
+ # batch decoding, dummy mask is passed
+ logp, states = self.forward_one_step(ys, torch.empty(0), xs, cache=None)
+
+ return logp, None
diff --git a/funasr/models/whisper_lid/encoder.py b/funasr/models/whisper_lid/encoder.py
new file mode 100644
index 0000000..7eeb643
--- /dev/null
+++ b/funasr/models/whisper_lid/encoder.py
@@ -0,0 +1,119 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import copy
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import whisper
+
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.specaug.specaug import SpecAug
+from funasr.register import tables
+
+
+@tables.register("encoder_classes", "OpenAIWhisperEncoderWarp")
+class OpenAIWhisperEncoderWarp(nn.Module):
+ """Transformer-based Speech Encoder from OpenAI's Whisper Model:
+
+ URL: https://github.com/openai/whisper
+ """
+
+ def __init__(
+ self,
+ dropout_rate: float = 0.0,
+ whisper_model: str = "small",
+ download_dir: str = None,
+ use_specaug: bool = False,
+ use_padmask: bool = False,
+ specaug_conf: Union[dict, None] = None,
+ ):
+ super().__init__()
+
+ # note that originally Whisper doesn't use dropouts
+ self.dropout = torch.nn.Dropout(dropout_rate)
+
+ assert whisper_model in whisper.available_models()
+ _model = whisper.load_model(
+ whisper_model, download_root=download_dir, device="cpu"
+ )
+ self.encoders = copy.deepcopy(_model.encoder)
+ self.encoders.train()
+
+ del _model
+
+ if use_specaug:
+ self.specaug = SpecAug(**specaug_conf)
+ else:
+ self.specaug = None
+ self.use_padmask = use_padmask
+
+ def whisper_encode(
+ self,
+ input: torch.Tensor,
+ ilens: torch.Tensor = None,
+ ) -> torch.Tensor:
+ x = F.gelu(self.encoders.conv1(input))
+ x = F.gelu(self.encoders.conv2(x))
+ x = x.permute(0, 2, 1)
+
+ n_frames = x.size(1)
+ max_pos = self.encoders.positional_embedding.size(0)
+ if n_frames <= max_pos:
+ x = (x + self.encoders.positional_embedding[: x.size(1), :]).to(x.dtype)
+ else:
+ # due to positional encoding, audios >30 sec won't be accepted
+ x = x[:, :max_pos, :] + self.encoders.positional_embedding
+
+ if ilens is not None:
+ olens = (
+ 1
+ + (
+ ilens
+ - self.encoders.conv2.kernel_size[0]
+ + 2 * self.encoders.conv2.padding[0]
+ )
+ // self.encoders.conv2.stride[0]
+ )
+ olens = torch.clamp(olens, max=max_pos)
+ else:
+ olens = None
+
+ if self.use_padmask:
+ padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
+ else:
+ padding_mask = None
+
+ x = self.dropout(x)
+
+ for layer, block in enumerate(self.encoders.blocks):
+ x = block(x)
+ if layer < len(self.encoders.blocks) - 1:
+ x = self.dropout(x)
+
+ x = self.encoders.ln_post(x)
+
+ return x, olens
+
+ def output_size(self) -> int:
+ # dummy output size
+ return self.encoders.conv2.weight.shape[0]
+
+ def forward(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ feats, feats_lens = xs_pad, ilens
+
+ if self.specaug is not None and self.encoders.training:
+ feats = torch.transpose(feats, 1, 2)
+ feats, feats_lens = self.specaug(feats, feats_lens)
+ feats = torch.transpose(feats, 1, 2)
+
+ xs_pad, olens = self.whisper_encode(feats, feats_lens)
+
+ return xs_pad, olens, None
diff --git a/funasr/models/whisper_lid/eres2net/ResNet.py b/funasr/models/whisper_lid/eres2net/ResNet.py
new file mode 100644
index 0000000..25c79f5
--- /dev/null
+++ b/funasr/models/whisper_lid/eres2net/ResNet.py
@@ -0,0 +1,428 @@
+# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
+ ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
+ The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
+ The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
+ ERes2Net-Large is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
+ recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
+"""
+
+import torch
+import math
+import torch.nn as nn
+import torch.nn.functional as F
+import funasr.models.whisper_lid.eres2net.pooling_layers as pooling_layers
+from funasr.models.whisper_lid.eres2net.fusion import AFF
+
+
+class ReLU(nn.Hardtanh):
+
+ def __init__(self, inplace=False):
+ super(ReLU, self).__init__(0, 20, inplace)
+
+ def __repr__(self):
+ inplace_str = 'inplace' if self.inplace else ''
+ return self.__class__.__name__ + ' (' \
+ + inplace_str + ')'
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ "1x1 convolution without padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
+ padding=0, bias=False)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlockERes2Net(nn.Module):
+ expansion = 2
+
+ def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
+ super(BasicBlockERes2Net, self).__init__()
+ width = int(math.floor(planes * (baseWidth / 64.0)))
+ self.conv1 = conv1x1(in_planes, width * scale, stride)
+ self.bn1 = nn.BatchNorm2d(width * scale)
+ self.nums = scale
+
+ convs = []
+ bns = []
+ for i in range(self.nums):
+ convs.append(conv3x3(width, width))
+ bns.append(nn.BatchNorm2d(width))
+ self.convs = nn.ModuleList(convs)
+ self.bns = nn.ModuleList(bns)
+ self.relu = ReLU(inplace=True)
+
+ self.conv3 = conv1x1(width * scale, planes * self.expansion)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.BatchNorm2d(self.expansion * planes))
+ self.stride = stride
+ self.width = width
+ self.scale = scale
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+ spx = torch.split(out, self.width, 1)
+ for i in range(self.nums):
+ if i == 0:
+ sp = spx[i]
+ else:
+ sp = sp + spx[i]
+ sp = self.convs[i](sp)
+ sp = self.relu(self.bns[i](sp))
+ if i == 0:
+ out = sp
+ else:
+ out = torch.cat((out, sp), 1)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ residual = self.shortcut(x)
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class BasicBlockERes2Net_diff_AFF(nn.Module):
+ expansion = 2
+
+ def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
+ super(BasicBlockERes2Net_diff_AFF, self).__init__()
+ width = int(math.floor(planes * (baseWidth / 64.0)))
+ self.conv1 = conv1x1(in_planes, width * scale, stride)
+ self.bn1 = nn.BatchNorm2d(width * scale)
+ self.nums = scale
+
+ convs = []
+ fuse_models = []
+ bns = []
+ for i in range(self.nums):
+ convs.append(conv3x3(width, width))
+ bns.append(nn.BatchNorm2d(width))
+ for j in range(self.nums - 1):
+ fuse_models.append(AFF(channels=width))
+
+ self.convs = nn.ModuleList(convs)
+ self.bns = nn.ModuleList(bns)
+ self.fuse_models = nn.ModuleList(fuse_models)
+ self.relu = ReLU(inplace=True)
+
+ self.conv3 = conv1x1(width * scale, planes * self.expansion)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.BatchNorm2d(self.expansion * planes))
+ self.stride = stride
+ self.width = width
+ self.scale = scale
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+ spx = torch.split(out, self.width, 1)
+ for i in range(self.nums):
+ if i == 0:
+ sp = spx[i]
+ else:
+ sp = self.fuse_models[i - 1](sp, spx[i])
+
+ sp = self.convs[i](sp)
+ sp = self.relu(self.bns[i](sp))
+ if i == 0:
+ out = sp
+ else:
+ out = torch.cat((out, sp), 1)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ residual = self.shortcut(x)
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ERes2Net(nn.Module):
+ def __init__(self,
+ block=BasicBlockERes2Net,
+ block_fuse=BasicBlockERes2Net_diff_AFF,
+ num_blocks=[3, 4, 6, 3],
+ m_channels=32,
+ feat_dim=80,
+ embedding_size=192,
+ pooling_func='TSTP',
+ two_emb_layer=False):
+ super(ERes2Net, self).__init__()
+ self.in_planes = m_channels
+ self.feat_dim = feat_dim
+ self.embedding_size = embedding_size
+ self.stats_dim = int(feat_dim / 8) * m_channels * 8
+ self.two_emb_layer = two_emb_layer
+ self._output_size = embedding_size
+
+ self.conv1 = nn.Conv2d(1,
+ m_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(m_channels)
+ self.layer1 = self._make_layer(block,
+ m_channels,
+ num_blocks[0],
+ stride=1)
+ self.layer2 = self._make_layer(block,
+ m_channels * 2,
+ num_blocks[1],
+ stride=2)
+ self.layer3 = self._make_layer(block_fuse,
+ m_channels * 4,
+ num_blocks[2],
+ stride=2)
+ self.layer4 = self._make_layer(block_fuse,
+ m_channels * 8,
+ num_blocks[3],
+ stride=2)
+
+ # Downsampling module for each layer
+ self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1,
+ bias=False)
+ self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2,
+ bias=False)
+ self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2,
+ bias=False)
+
+ # Bottom-up fusion module
+ self.fuse_mode12 = AFF(channels=m_channels * 4)
+ self.fuse_mode123 = AFF(channels=m_channels * 8)
+ self.fuse_mode1234 = AFF(channels=m_channels * 16)
+
+ self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
+ self.pool = getattr(pooling_layers, pooling_func)(
+ in_dim=self.stats_dim * block.expansion)
+ self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
+ embedding_size)
+ if self.two_emb_layer:
+ self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
+ self.seg_2 = nn.Linear(embedding_size, embedding_size)
+ else:
+ self.seg_bn_1 = nn.Identity()
+ self.seg_2 = nn.Identity()
+
+ def _make_layer(self, block, planes, num_blocks, stride):
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride))
+ self.in_planes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(self, x, ilens):
+ # assert x.shape[1] == ilens.max()
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
+ x = x.unsqueeze_(1)
+ out = F.relu(self.bn1(self.conv1(x)))
+ out1 = self.layer1(out)
+ out2 = self.layer2(out1)
+ out1_downsample = self.layer1_downsample(out1)
+ fuse_out12 = self.fuse_mode12(out2, out1_downsample)
+ out3 = self.layer3(out2)
+ fuse_out12_downsample = self.layer2_downsample(fuse_out12)
+ fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
+ out4 = self.layer4(out3)
+ fuse_out123_downsample = self.layer3_downsample(fuse_out123)
+ fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
+ olens = (((((ilens - 1) // 2 + 1) - 1) // 2 + 1) - 1) // 2 + 1
+ stats = self.pool(fuse_out1234, olens)
+
+ embed_a = self.seg_1(stats)
+ if self.two_emb_layer:
+ out = F.relu(embed_a)
+ out = self.seg_bn_1(out)
+ embed_b = self.seg_2(out)
+ return embed_b
+ else:
+ return embed_a
+
+
+class BasicBlockRes2Net(nn.Module):
+ expansion = 2
+
+ def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
+ super(BasicBlockRes2Net, self).__init__()
+ width = int(math.floor(planes * (baseWidth / 64.0)))
+ self.conv1 = conv1x1(in_planes, width * scale, stride)
+ self.bn1 = nn.BatchNorm2d(width * scale)
+ self.nums = scale - 1
+ convs = []
+ bns = []
+ for i in range(self.nums):
+ convs.append(conv3x3(width, width))
+ bns.append(nn.BatchNorm2d(width))
+ self.convs = nn.ModuleList(convs)
+ self.bns = nn.ModuleList(bns)
+ self.relu = ReLU(inplace=True)
+
+ self.conv3 = conv1x1(width * scale, planes * self.expansion)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.BatchNorm2d(self.expansion * planes))
+ self.stride = stride
+ self.width = width
+ self.scale = scale
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+ spx = torch.split(out, self.width, 1)
+ for i in range(self.nums):
+ if i == 0:
+ sp = spx[i]
+ else:
+ sp = sp + spx[i]
+ sp = self.convs[i](sp)
+ sp = self.relu(self.bns[i](sp))
+ if i == 0:
+ out = sp
+ else:
+ out = torch.cat((out, sp), 1)
+
+ out = torch.cat((out, spx[self.nums]), 1)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ residual = self.shortcut(x)
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Res2Net(nn.Module):
+ def __init__(self,
+ block=BasicBlockRes2Net,
+ num_blocks=[3, 4, 6, 3],
+ m_channels=32,
+ feat_dim=80,
+ embedding_size=192,
+ pooling_func='TSTP',
+ two_emb_layer=False):
+ super(Res2Net, self).__init__()
+ self.in_planes = m_channels
+ self.feat_dim = feat_dim
+ self.embedding_size = embedding_size
+ self.stats_dim = int(feat_dim / 8) * m_channels * 8
+ self.two_emb_layer = two_emb_layer
+
+ self.conv1 = nn.Conv2d(1,
+ m_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(m_channels)
+ self.layer1 = self._make_layer(block,
+ m_channels,
+ num_blocks[0],
+ stride=1)
+ self.layer2 = self._make_layer(block,
+ m_channels * 2,
+ num_blocks[1],
+ stride=2)
+ self.layer3 = self._make_layer(block,
+ m_channels * 4,
+ num_blocks[2],
+ stride=2)
+ self.layer4 = self._make_layer(block,
+ m_channels * 8,
+ num_blocks[3],
+ stride=2)
+
+ self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
+ self.pool = getattr(pooling_layers, pooling_func)(
+ in_dim=self.stats_dim * block.expansion)
+ self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
+ embedding_size)
+ if self.two_emb_layer:
+ self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
+ self.seg_2 = nn.Linear(embedding_size, embedding_size)
+ else:
+ self.seg_bn_1 = nn.Identity()
+ self.seg_2 = nn.Identity()
+
+ def _make_layer(self, block, planes, num_blocks, stride):
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride))
+ self.in_planes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
+
+ x = x.unsqueeze_(1)
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.layer1(out)
+ out = self.layer2(out)
+ out = self.layer3(out)
+ out = self.layer4(out)
+
+ stats = self.pool(out)
+
+ embed_a = self.seg_1(stats)
+ if self.two_emb_layer:
+ out = F.relu(embed_a)
+ out = self.seg_bn_1(out)
+ embed_b = self.seg_2(out)
+ return embed_b
+ else:
+ return embed_a
+
+
+
+
diff --git a/funasr/models/whisper_lid/eres2net/__init__.py b/funasr/models/whisper_lid/eres2net/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/whisper_lid/eres2net/__init__.py
diff --git a/funasr/models/whisper_lid/eres2net/fusion.py b/funasr/models/whisper_lid/eres2net/fusion.py
new file mode 100644
index 0000000..2aff7a7
--- /dev/null
+++ b/funasr/models/whisper_lid/eres2net/fusion.py
@@ -0,0 +1,29 @@
+# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import torch
+import torch.nn as nn
+
+
+class AFF(nn.Module):
+
+ def __init__(self, channels=64, r=4):
+ super(AFF, self).__init__()
+ inter_channels = int(channels // r)
+
+ self.local_att = nn.Sequential(
+ nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.SiLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+
+ def forward(self, x, ds_y):
+ xa = torch.cat((x, ds_y), dim=1)
+ x_att = self.local_att(xa)
+ x_att = 1.0 + torch.tanh(x_att)
+ xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0-x_att)
+
+ return xo
+
diff --git a/funasr/models/whisper_lid/eres2net/pooling_layers.py b/funasr/models/whisper_lid/eres2net/pooling_layers.py
new file mode 100644
index 0000000..f756ac8
--- /dev/null
+++ b/funasr/models/whisper_lid/eres2net/pooling_layers.py
@@ -0,0 +1,118 @@
+# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+""" This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
+
+import torch
+import torch.nn as nn
+
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+
+
+class TAP(nn.Module):
+ """
+ Temporal average pooling, only first-order mean is considered
+ """
+
+ def __init__(self, **kwargs):
+ super(TAP, self).__init__()
+
+ def forward(self, x):
+ pooling_mean = x.mean(dim=-1)
+ # To be compatable with 2D input
+ pooling_mean = pooling_mean.flatten(start_dim=1)
+ return pooling_mean
+
+
+class TSDP(nn.Module):
+ """
+ Temporal standard deviation pooling, only second-order std is considered
+ """
+
+ def __init__(self, **kwargs):
+ super(TSDP, self).__init__()
+
+ def forward(self, x):
+ # The last dimension is the temporal axis
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
+ pooling_std = pooling_std.flatten(start_dim=1)
+ return pooling_std
+
+
+class TSTP(nn.Module):
+ """
+ Temporal statistics pooling, concatenate mean and std, which is used in
+ x-vector
+ Comment: simple concatenation can not make full use of both statistics
+ """
+
+ def __init__(self, **kwargs):
+ super(TSTP, self).__init__()
+
+ def forward(self, x, olens):
+ # The last dimension is the temporal axis
+ masks = (~make_pad_mask(olens, maxlen=x.shape[-1])[:, None, None, :]).to(x.device)
+ x_masked = x * masks
+ sum_without_padding = torch.sum(x_masked, axis=-1)
+ count_without_padding = torch.sum(masks, axis=-1)
+ mean_without_padding = sum_without_padding / count_without_padding
+
+ var_without_padding = ((x_masked - mean_without_padding.unsqueeze(-1)) ** 2 * masks).sum(-1) / count_without_padding
+
+ pooling_mean = mean_without_padding
+ pooling_std = torch.sqrt(var_without_padding + 1e-8)
+ pooling_mean = pooling_mean.flatten(start_dim=1)
+ pooling_std = pooling_std.flatten(start_dim=1)
+
+ stats = torch.cat((pooling_mean, pooling_std), 1)
+ return stats
+
+
+class ASTP(nn.Module):
+ """ Attentive statistics pooling: Channel- and context-dependent
+ statistics pooling, first used in ECAPA_TDNN.
+ """
+
+ def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
+ super(ASTP, self).__init__()
+ self.global_context_att = global_context_att
+
+ # Use Conv1d with stride == 1 rather than Linear, then we don't
+ # need to transpose inputs.
+ if global_context_att:
+ self.linear1 = nn.Conv1d(
+ in_dim * 3, bottleneck_dim,
+ kernel_size=1) # equals W and b in the paper
+ else:
+ self.linear1 = nn.Conv1d(
+ in_dim, bottleneck_dim,
+ kernel_size=1) # equals W and b in the paper
+ self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
+ kernel_size=1) # equals V and k in the paper
+
+ def forward(self, x):
+ """
+ x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
+ or a 4-dimensional tensor in resnet architecture (B,C,F,T)
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
+ """
+ if len(x.shape) == 4:
+ x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
+ assert len(x.shape) == 3
+
+ if self.global_context_att:
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
+ context_std = torch.sqrt(
+ torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
+ else:
+ x_in = x
+
+ # DON'T use ReLU here! ReLU may be hard to converge.
+ alpha = torch.tanh(
+ self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
+ mean = torch.sum(alpha * x, dim=2)
+ var = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
+ std = torch.sqrt(var.clamp(min=1e-10))
+ return torch.cat([mean, std], dim=1)
diff --git a/funasr/models/whisper_lid/eres2net/simple_avg.py b/funasr/models/whisper_lid/eres2net/simple_avg.py
new file mode 100644
index 0000000..4fb4c0a
--- /dev/null
+++ b/funasr/models/whisper_lid/eres2net/simple_avg.py
@@ -0,0 +1,17 @@
+import torch
+
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.modules.nets_utils import make_pad_mask
+
+class SimpleAvg(AbsEncoder):
+ def __init__(self, feat_dim):
+ super(SimpleAvg, self).__init__()
+ self.feat_dim = feat_dim
+
+ def forward(self, x, ilens):
+ mask = ~make_pad_mask(ilens, maxlen=x.shape[1]).to(x.device)
+ avg_x = (x * mask[:, :, None]).sum(1) / mask.sum(-1)[:, None]
+ return avg_x
+
+ def output_size(self) -> int:
+ return self.feat_dim
\ No newline at end of file
diff --git a/funasr/models/whisper_lid/lid_predictor.py b/funasr/models/whisper_lid/lid_predictor.py
new file mode 100644
index 0000000..5e042d2
--- /dev/null
+++ b/funasr/models/whisper_lid/lid_predictor.py
@@ -0,0 +1,25 @@
+from funasr.register import tables
+from funasr.models.whisper_lid.eres2net.ResNet import ERes2Net, BasicBlockERes2Net, BasicBlockERes2Net_diff_AFF
+
+
+@tables.register("lid_predictor_classes", "LidPredictor")
+class LidPredictor(ERes2Net):
+ def __init__(self,
+ block=BasicBlockERes2Net,
+ block_fuse=BasicBlockERes2Net_diff_AFF,
+ num_blocks=[3, 4, 6, 3],
+ m_channels=32,
+ feat_dim=80,
+ embedding_size=192,
+ pooling_func='TSTP',
+ two_emb_layer=False):
+ super(LidPredictor, self).__init__(
+ block=block,
+ block_fuse=block_fuse,
+ num_blocks=num_blocks,
+ m_channels=m_channels,
+ feat_dim=feat_dim,
+ embedding_size=embedding_size,
+ pooling_func=pooling_func,
+ two_emb_layer=two_emb_layer
+ )
\ No newline at end of file
diff --git a/funasr/models/whisper_lid/model.py b/funasr/models/whisper_lid/model.py
new file mode 100644
index 0000000..6ffb43a
--- /dev/null
+++ b/funasr/models/whisper_lid/model.py
@@ -0,0 +1,665 @@
+import logging
+from typing import Union, Dict, List, Tuple, Optional
+
+import time
+import torch
+import numpy as np
+import torch.nn as nn
+from torch.cuda.amp import autocast
+
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.ctc.ctc import CTC
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.register import tables
+
+
+@tables.register("model_classes", "OpenAIWhisperModel")
+class OpenAIWhisperModel(nn.Module):
+ """CTC-attention hybrid Encoder-Decoder model"""
+
+
+ def __init__(
+ self,
+ specaug: str = None,
+ specaug_conf: dict = None,
+ normalize: str = None,
+ normalize_conf: dict = None,
+ encoder: str = None,
+ encoder_conf: dict = None,
+ decoder: str = None,
+ decoder_conf: dict = None,
+ ctc: str = None,
+ ctc_conf: dict = None,
+ ctc_weight: float = 0.5,
+ interctc_weight: float = 0.0,
+ input_size: int = 80,
+ vocab_size: int = -1,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ # extract_feats_in_collect_stats: bool = True,
+ share_embedding: bool = False,
+ # preencoder: Optional[AbsPreEncoder] = None,
+ # postencoder: Optional[AbsPostEncoder] = None,
+ **kwargs,
+ ):
+
+ super().__init__()
+
+ if specaug is not None:
+ specaug_class = tables.specaug_classes.get(specaug)
+ specaug = specaug_class(**specaug_conf)
+ if normalize is not None:
+ normalize_class = tables.normalize_classes.get(normalize)
+ normalize = normalize_class(**normalize_conf)
+ encoder_class = tables.encoder_classes.get(encoder)
+ encoder = encoder_class(input_size=input_size, **encoder_conf)
+ encoder_output_size = encoder.output_size()
+ if decoder is not None:
+ decoder_class = tables.decoder_classes.get(decoder)
+ decoder = decoder_class(decoder_conf)
+ if ctc_weight > 0.0:
+
+ if ctc_conf is None:
+ ctc_conf = {}
+
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+ )
+
+ self.blank_id = blank_id
+ self.sos = sos if sos is not None else vocab_size - 1
+ self.eos = eos if eos is not None else vocab_size - 1
+ self.vocab_size = vocab_size
+ self.ignore_id = ignore_id
+ self.ctc_weight = ctc_weight
+ self.specaug = specaug
+ self.normalize = normalize
+ self.encoder = encoder
+
+ if not hasattr(self.encoder, "interctc_use_conditioning"):
+ self.encoder.interctc_use_conditioning = False
+ if self.encoder.interctc_use_conditioning:
+ self.encoder.conditioning_layer = torch.nn.Linear(
+ vocab_size, self.encoder.output_size()
+ )
+ self.interctc_weight = interctc_weight
+
+ # self.error_calculator = None
+ if ctc_weight == 1.0:
+ self.decoder = None
+ else:
+ self.decoder = decoder
+
+ self.criterion_att = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=ignore_id,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+ #
+ # if report_cer or report_wer:
+ # self.error_calculator = ErrorCalculator(
+ # token_list, sym_space, sym_blank, report_cer, report_wer
+ # )
+ #
+ self.error_calculator = None
+ if ctc_weight == 0.0:
+ self.ctc = None
+ else:
+ self.ctc = ctc
+
+ self.share_embedding = share_embedding
+ if self.share_embedding:
+ self.decoder.embed = None
+
+ self.length_normalized_loss = length_normalized_loss
+ self.beam_search = None
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ # import pdb;
+ # pdb.set_trace()
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
+
+ # 1. Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ loss_att, acc_att, cer_att, wer_att = None, None, None, None
+ loss_ctc, cer_ctc = None, None
+ stats = dict()
+
+ # decoder: CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
+ # Intermediate CTC (optional)
+ loss_interctc = 0.0
+ if self.interctc_weight != 0.0 and intermediate_outs is not None:
+ for layer_idx, intermediate_out in intermediate_outs:
+ # we assume intermediate_out has the same length & padding
+ # as those of encoder_out
+ loss_ic, cer_ic = self._calc_ctc_loss(
+ intermediate_out, encoder_out_lens, text, text_lengths
+ )
+ loss_interctc = loss_interctc + loss_ic
+
+ # Collect Intermedaite CTC stats
+ stats["loss_interctc_layer{}".format(layer_idx)] = (
+ loss_ic.detach() if loss_ic is not None else None
+ )
+ stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+
+ loss_interctc = loss_interctc / len(intermediate_outs)
+
+ # calculate whole encoder loss
+ loss_ctc = (
+ 1 - self.interctc_weight
+ ) * loss_ctc + self.interctc_weight * loss_interctc
+
+ # decoder: Attention decoder branch
+ loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss = loss_att
+ elif self.ctc_weight == 1.0:
+ loss = loss_ctc
+ else:
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
+
+ # Collect Attn branch stats
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["acc"] = acc_att
+ stats["cer"] = cer_att
+ stats["wer"] = wer_att
+
+ # Collect total loss stats
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ if self.length_normalized_loss:
+ batch_size = int((text_lengths + 1).sum())
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ ind: int
+ """
+ with autocast(False):
+
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+ # Forward encoder
+ # feats: (Batch, Length, Dim)
+ # -> encoder_out: (Batch, Length2, Dim2)
+ if self.encoder.interctc_use_conditioning:
+ encoder_out, encoder_out_lens, _ = self.encoder(
+ speech, speech_lengths, ctc=self.ctc
+ )
+ else:
+ encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ if intermediate_outs is not None:
+ return (encoder_out, intermediate_outs), encoder_out_lens
+
+ return encoder_out, encoder_out_lens
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_in_lens = ys_pad_lens + 1
+
+ # 1. Forward decoder
+ decoder_out, _ = self.decoder(
+ encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+ )
+
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_out_pad)
+ acc_att = th_accuracy(
+ decoder_out.view(-1, self.vocab_size),
+ ys_out_pad,
+ ignore_label=self.ignore_id,
+ )
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ # Calc CTC loss
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+ # Calc CER using CTC
+ cer_ctc = None
+ if not self.training and self.error_calculator is not None:
+ ys_hat = self.ctc.argmax(encoder_out).data
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+ return loss_ctc, cer_ctc
+
+ def init_beam_search(self,
+ **kwargs,
+ ):
+ from funasr.models.transformer.search import BeamSearch
+ from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+ from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+ # 1. Build ASR model
+ scorers = {}
+
+ if self.ctc != None:
+ ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+ scorers.update(
+ ctc=ctc
+ )
+ token_list = kwargs.get("token_list")
+ scorers.update(
+ decoder=self.decoder,
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ weights = dict(
+ decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.5),
+ ctc=kwargs.get("decoding_ctc_weight", 0.5),
+ lm=kwargs.get("lm_weight", 0.0),
+ ngram=kwargs.get("ngram_weight", 0.0),
+ length_bonus=kwargs.get("penalty", 0.0),
+ )
+ beam_search = BeamSearch(
+ beam_size=kwargs.get("beam_size", 10),
+ weights=weights,
+ scorers=scorers,
+ sos=self.sos,
+ eos=self.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+ )
+
+ self.beam_search = beam_search
+
+ def inference(self,
+ data_in,
+ data_lengths=None,
+ key: list=None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+
+ if kwargs.get("batch_size", 1) > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+
+ # init beamsearch
+ if self.beam_search is None:
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
+ meta_data = {}
+ if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
+ speech, speech_lengths = data_in, data_lengths
+ if len(speech.shape) < 3:
+ speech = speech[None, :, :]
+ if speech_lengths is None:
+ speech_lengths = speech.shape[1]
+ else:
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer)
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ # c. Passed the encoder result and the beam search
+ nbest_hyps = self.beam_search(
+ x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+
+
+ results = []
+ b, n, d = encoder_out.size()
+ for i in range(b):
+
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+ ibest_writer = None
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ result_i = {"key": key[i], "token": token, "text": text_postprocessed}
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text_postprocessed
+
+ return results, meta_data
+
+
+@tables.register("model_classes", "OpenAIWhisperLIDModel")
+class OpenAIWhisperLIDModel(nn.Module):
+ """WhisperEncoder and EResNet based LID Model"""
+
+ def __init__(
+ self,
+ vocab_size: int,
+ specaug: str = None,
+ specaug_conf: dict = None,
+ encoder: str = None,
+ encoder_conf: dict = None,
+ lid_predictor: str = None,
+ lid_predictor_conf: dict = None,
+ proj_dim: int = None,
+ clip_frames: int = None,
+ random_clip: bool = False,
+ **kwargs,
+ ):
+ super().__init__()
+ if specaug is not None:
+ specaug_class = tables.specaug_classes.get(specaug)
+ specaug = specaug_class(**specaug_conf)
+ encoder_class = tables.encoder_classes.get(encoder)
+ encoder = encoder_class(**encoder_conf)
+ lid_predictor_class = tables.lid_predictor_classes.get(lid_predictor)
+ lid_predictor = lid_predictor_class(**lid_predictor_conf)
+ if encoder.output_size() != proj_dim:
+ self.proj_layer = torch.nn.Linear(encoder.output_size(), proj_dim)
+ else:
+ self.proj_layer = None
+ self.output_layer = torch.nn.Linear(lid_predictor.output_size(), vocab_size)
+ self.criterion_lid = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=-1,
+ smoothing=0.0,
+ normalize_length=False,
+ )
+
+ self.specaug = specaug
+ self.encoder = encoder
+ self.lid_predictor = lid_predictor
+ self.clip_frames = clip_frames
+ self.random_clip = random_clip
+ self.normalize = None
+ self.beam_search = None
+ if not hasattr(self.encoder, "interctc_use_conditioning"):
+ self.encoder.interctc_use_conditioning = False
+
+ def forward(self,
+ speech: torch.Tensor, # may be padding
+ speech_lengths: torch.Tensor, # actual length
+ lid: torch.Tensor, # lid label, (batch_size, 1)
+ lid_lengths: torch.Tensor,
+ ):
+ assert lid.shape[1] == 1
+ batch_size = speech.shape[0]
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ # re-generate encoder_out
+ if self.clip_frames is None:
+ reduced_encoder_out = torch.zeros(batch_size, encoder_out_lens.max(), encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device)
+ for i, enc_length in enumerate(encoder_out_lens):
+ reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
+ else:
+ reduced_encoder_out = torch.zeros(batch_size, self.clip_frames, encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device)
+ if self.random_clip:
+ for i, enc_length in enumerate(encoder_out_lens):
+ if enc_length <= self.clip_frames:
+ reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
+ encoder_out_lens[i] = enc_length
+ else:
+ max_start_index = enc_length.item() - self.clip_frames
+ start_index = np.random.randint(0, max_start_index + 1)
+ reduced_encoder_out[i, :self.clip_frames] = encoder_out[i, start_index:start_index + self.clip_frames]
+ encoder_out_lens[i] = self.clip_frames
+ else:
+ for i, enc_length in enumerate(encoder_out_lens):
+ enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length
+ reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
+ encoder_out_lens[i] = enc_length
+ if self.proj_layer is not None:
+ reduced_encoder_out = self.proj_layer(reduced_encoder_out)
+ lid_output = self.lid_predictor(reduced_encoder_out, encoder_out_lens) # (B, D)
+ lid_logits = self.output_layer(lid_output) # (B, num_classes)
+ loss = self.criterion_lid(lid_logits[:, None, :], lid)
+ with torch.no_grad():
+ _, predicted_lid = torch.max(lid_logits, 1)
+ correct = (predicted_lid == lid[:, 0]).sum().item()
+ lid_acc = correct * 1.0 / lid_logits.shape[0]
+ stats = dict()
+ stats["batch_size"] = batch_size
+ stats["loss"] = torch.clone(loss.detach())
+ stats["acc"] = lid_acc
+ stats["token_length"] = speech_lengths.max()
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ """
+ with autocast(False):
+
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech = speech.permute(0, 2, 1)
+ # suit for whisper padding
+ padded_speech_lengths = torch.ones_like(speech_lengths) * speech.shape[1]
+ speech, padded_speech_lengths = self.specaug(speech, padded_speech_lengths)
+ speech = speech.permute(0, 2, 1)
+
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+ # Forward encoder
+ # feats: (Batch, Length, Dim)
+ # -> encoder_out: (Batch, Length2, Dim2)
+ if self.encoder.interctc_use_conditioning:
+ encoder_out, encoder_out_lens, _ = self.encoder(
+ speech, speech_lengths, ctc=self.ctc
+ )
+ else:
+ encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ if intermediate_outs is not None:
+ return (encoder_out, intermediate_outs), encoder_out_lens
+
+ return encoder_out, encoder_out_lens
+
+ def inference(self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+
+ if kwargs.get("batch_size", 1) > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+
+ meta_data = {}
+ if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
+ speech, speech_lengths = data_in, data_lengths
+ if len(speech.shape) < 3:
+ speech = speech[None, :, :]
+ if speech_lengths is None:
+ speech_lengths = speech.shape[1]
+ else:
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer)
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
+ # Encoder
+ enc, enc_out_lens = self.encode(speech, speech_lengths)
+
+ inference_clip_length = kwargs.get("inference_clip_length", None)
+ if self.clip_frames is not None:
+ if inference_clip_length is None:
+ reduced_enc = torch.zeros(enc.shape[0], self.clip_frames, enc.shape[-1]).to(enc.dtype).to(enc.device)
+ for i, enc_length in enumerate(enc_out_lens):
+ enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length
+ reduced_enc[i, :enc_length] = enc[i, :enc_length]
+ enc_out_lens[i] = enc_length
+ else:
+ assert inference_clip_length > 0, "inference_clip_length must be larger than 0"
+ reduced_enc = torch.zeros(enc.shape[0], inference_clip_length, enc.shape[-1]).to(enc.dtype).to(enc.device)
+ for i, enc_length in enumerate(enc_out_lens):
+ enc_length = inference_clip_length if enc_length >= inference_clip_length else enc_length
+ reduced_enc[i, :enc_length] = enc[i, :enc_length]
+ enc_out_lens[i] = enc_length
+ else:
+ reduced_enc = torch.zeros(enc.shape[0], enc_out_lens.max(), enc.shape[-1]).to(enc.dtype).to(enc.device)
+ for i, enc_length in enumerate(enc_out_lens):
+ reduced_enc[i, :enc_length] = enc[i, :enc_length]
+
+ if self.proj_layer is not None:
+ reduced_enc = self.proj_layer(reduced_enc)
+ lid_output = self.lid_predictor(reduced_enc, enc_out_lens) # (B, D)
+ lid_logits = self.output_layer(lid_output) # (B, num_classes)
+
+ _, predicted_lid_index = torch.max(lid_logits, 1)
+ predicted_lid = tokenizer.ids2tokens([predicted_lid_index[0].cpu()])[0]
+
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+ lid_writer = self.writer["lid"]
+ lid_writer[key[0]] = predicted_lid
+
+ results = [{"key": key[0], "lid": predicted_lid}]
+
+ return results, meta_data
\ No newline at end of file
--
Gitblit v1.9.1