update whisper lid (#1407)
* update whisper lid
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | from funasr import AutoModel |
| | | |
| | | multilingual_wavs = [ |
| | | "example_zh-CN.mp3", |
| | | "example_en.mp3", |
| | | "example_ja.mp3", |
| | | "example_ko.mp3", |
| | | ] |
| | | |
| | | model = AutoModel(model="iic/speech_whisper-large_lid_multilingual_pytorch", model_revision="v2.0.4") |
| | | for wav_id in multilingual_wavs: |
| | | wav_file = f"{model.model_path}/examples/{wav_id}" |
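| | |     # inference_clip_length limits how many Whisper encoder frames feed the LID head;
| | |     # at the encoder's 20 ms frame shift, 250 frames covers roughly the first 5 s of audio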
| | | res = model.generate(input=wav_file, data_type="sound", inference_clip_length=250) |
| | | print("detect sample {}: {}".format(wav_id, res)) |
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | from modelscope.pipelines import pipeline |
| | | from modelscope.utils.constant import Tasks |
| | | |
| | | multilingual_wavs = [
| | | "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_zh-CN.mp3", |
| | | "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_en.mp3", |
| | | "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ja.mp3", |
| | | "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ko.mp3", |
| | | ] |
| | | |
| | | inference_pipeline = pipeline( |
| | | task=Tasks.auto_speech_recognition, |
| | | model='iic/speech_whisper-large_lid_multilingual_pytorch', model_revision="v2.0.4") |
| | | |
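| | | # the pipeline is registered under the ASR task, but this model only performs language
| | | # identification; the result carries the predicted language token (see OpenAIWhisperLIDModel.inference)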
| | | for wav in multilingual_wavs: |
| | | rec_result = inference_pipeline(input=wav, inference_clip_length=250) |
| | | print(rec_result) |
| New file |
| | |
| | | from typing import Tuple |
| | | import torch |
| | | import torch.nn as nn |
| | | import whisper |
| | | from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES |
| | | from funasr.register import tables |
| | | from torch.nn.utils.rnn import pad_sequence |
| | | |
| | | |
| | | @tables.register("frontend_classes", "WhisperFrontend") |
| | | class WhisperFrontend(nn.Module): |
| | | """Speech Representation Using Encoder Outputs from OpenAI's Whisper Model: |
| | | |
| | | URL: https://github.com/openai/whisper |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | fs: int = 16000, |
| | | whisper_model: str = "large-v3", |
| | | do_pad_trim: bool = True, |
| | | ): |
| | | super().__init__() |
| | | assert fs == 16000 |
| | | self.fs = fs |
| | | |
| | | self.n_fft = N_FFT |
| | | self.win_length = N_FFT |
| | | self.hop_length = HOP_LENGTH |
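| | |         # N_SAMPLES is Whisper's fixed 30 s context (480000 samples at 16 kHz)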
| | | self.pad_samples = N_SAMPLES |
| | | self.frame_shift = self.hop_length |
| | | self.lfr_n = 1 |
| | | if whisper_model == "large-v3" or whisper_model == "large": |
| | | self.n_mels = 128 |
| | | else: |
| | | self.n_mels = 80 |
| | | |
| | | self.mel_filters = whisper.audio.mel_filters |
| | | self.do_pad_trim = do_pad_trim |
| | | if do_pad_trim: |
| | | self.pad_or_trim = whisper.pad_or_trim |
| | | |
| | | assert whisper_model in whisper.available_models() |
| | | |
| | | def output_size(self) -> int: |
| | | return self.n_mels |
| | | |
| | | def log_mel_spectrogram( |
| | | self, |
| | | audio: torch.Tensor, |
| | | ilens: torch.Tensor = None, |
| | | ) -> torch.Tensor: |
| | | window = torch.hann_window(self.win_length).to(audio.device) |
| | | stft = torch.stft( |
| | | audio, self.n_fft, self.hop_length, window=window, return_complex=True |
| | | ) |
| | | |
| | | # whisper deletes the last frame by default (Shih-Lun) |
| | | magnitudes = stft[..., :-1].abs() ** 2 |
| | | |
| | | filters = self.mel_filters(audio.device, self.n_mels) |
| | | mel_spec = filters @ magnitudes |
| | | |
| | | log_spec = torch.clamp(mel_spec, min=1e-10).log10() |
| | | |
| | | if ilens is not None: |
| | | olens = ilens // self.hop_length |
| | | else: |
| | | olens = None |
| | | |
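| | |         # follow Whisper's log-mel post-processing: clamp to at most 8 (log10 units) below the
| | |         # per-utterance maximum, then apply the affine normalization (x + 4) / 4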
| | | log_spec = torch.maximum( |
| | | log_spec, |
| | | log_spec.view(audio.size(0), -1).max(dim=-1)[0][:, None, None] - 8.0, |
| | | ) |
| | | log_spec = (log_spec + 4.0) / 4.0 |
| | | |
| | | return log_spec, olens |
| | | |
| | | def forward( |
| | | self, input: torch.Tensor, input_lengths: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | batch_size = input.size(0) |
| | | feats = [] |
| | | feats_lens = [] |
| | | for i in range(batch_size): |
| | | if self.do_pad_trim: |
| | | feat = self.pad_or_trim(input[i], self.pad_samples) |
| | | else: |
| | | feat = input[i] |
| | |             feat, feat_len = self.log_mel_spectrogram(feat[None, :], input_lengths[i])
| | | feats.append(feat[0]) |
| | | feats_lens.append(feat_len) |
| | | feats_lens = torch.as_tensor(feats_lens) |
| | | |
| | | if batch_size == 1: |
| | | feats_pad = feats[0][None, :, :] |
| | | else: |
| | | feats_pad = pad_sequence(feats, |
| | | batch_first=True, |
| | | padding_value=0.0) |
| | | |
| | | return feats_pad, feats_lens |
| New file |
| | |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | import copy |
| | | from typing import Any, List, Tuple |
| | | |
| | | import torch |
| | | from torch import nn |
| | | import whisper |
| | | |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.register import tables |
| | | |
| | | |
| | | @tables.register("decoder_classes", "OpenAIWhisperDecoderWarp") |
| | | class OpenAIWhisperDecoderWarp(nn.Module): |
| | | """Transformer-based Speech-to-Text Decoder from OpenAI's Whisper Model: |
| | | |
| | | URL: https://github.com/openai/whisper |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | dropout_rate: float = 0.0, |
| | | whisper_model: str = "small", |
| | | download_dir: str = None, |
| | | use_padmask: bool = False, |
| | | ): |
| | | super().__init__() |
| | | |
| | | assert whisper_model in whisper.available_models() |
| | | _model = whisper.load_model( |
| | | whisper_model, download_root=download_dir, device="cpu" |
| | | ) |
| | | self.decoders = copy.deepcopy(_model.decoder) |
| | | attention_dim = self.decoders.token_embedding.embedding_dim |
| | | |
| | | # note that originally Whisper doesn't use dropouts |
| | | self.dropout = torch.nn.Dropout(dropout_rate) |
| | | |
| | | self.decoders.train() |
| | | del _model |
| | | self.use_padmask = use_padmask |
| | | |
| | | def forward( |
| | | self, |
| | | hs_pad: torch.Tensor, |
| | | hlens: torch.Tensor, |
| | | ys_in_pad: torch.Tensor, |
| | | ys_in_lens: torch.Tensor, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Forward decoder. |
| | | |
| | | Args: |
| | | hs_pad: encoded memory, float32 (batch, maxlen_in, feat) |
| | | hlens: (batch) |
| | | ys_in_pad: |
| | | input token ids, int64 (batch, maxlen_out) |
| | | if input_layer == "embed" |
| | | input tensor (batch, maxlen_out, #mels) in the other cases |
| | | ys_in_lens: (batch) |
| | | Returns: |
| | | (tuple): tuple containing: |
| | | |
| | | x: decoded token score before softmax (batch, maxlen_out, token) |
| | | if use_output_layer is True, |
| | | olens: (batch, ) |
| | | """ |
| | | tgt, memory = ys_in_pad, hs_pad |
| | | tgt = ( |
| | | self.decoders.token_embedding(tgt) |
| | | + self.decoders.positional_embedding[: tgt.size(1)] |
| | | ) |
| | | tgt = self.dropout(tgt) |
| | | |
| | | x = tgt.to(memory.dtype) |
| | | |
| | | if self.use_padmask: |
| | | memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device) |
| | | else: |
| | | memory_mask = None |
| | | |
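| | |         # NOTE: the memory_mask / is_pad_mask keyword arguments assume FunASR's patched Whisper
| | |         # attention blocks; the upstream openai-whisper blocks do not accept these arguments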
| | | for layer, block in enumerate(self.decoders.blocks): |
| | | x = block(x, memory, mask=self.decoders.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True) |
| | | |
| | | if layer < len(self.decoders.blocks) - 1: |
| | | x = self.dropout(x) |
| | | |
| | | x = self.decoders.ln(x) |
| | | x = ( |
| | | x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1) |
| | | ).float() |
| | | |
| | | return x, ys_in_lens |
| | | |
| | | def forward_one_step( |
| | | self, |
| | | tgt: torch.Tensor, |
| | | tgt_mask: torch.Tensor, |
| | | memory: torch.Tensor, |
| | | cache: List[torch.Tensor] = None, |
| | | ) -> Tuple[torch.Tensor, List[torch.Tensor]]: |
| | | """Forward one step. |
| | | |
| | | Args: |
| | | tgt: input token ids, int64 (batch, maxlen_out) |
| | | tgt_mask: input token mask, (batch, maxlen_out) |
| | | dtype=torch.uint8 in PyTorch 1.2- |
| | | dtype=torch.bool in PyTorch 1.2+ (include 1.2) |
| | | memory: encoded memory, float32 (batch, maxlen_in, feat) |
| | | cache: cached output list of (batch, max_time_out-1, size) |
| | | Returns: |
| | | y, cache: NN output value and cache per `self.decoders`. |
| | |                 `y.shape` is (batch, maxlen_out, token)
| | | NOTE (Shih-Lun): |
| | | cache implementation is ignored for now |
| | | for simplicity & correctness |
| | | """ |
| | | x = ( |
| | | self.decoders.token_embedding(tgt) |
| | | + self.decoders.positional_embedding[: tgt.size(1)] |
| | | ) |
| | | x = self.dropout(x) |
| | | x = x.to(memory.dtype) |
| | | |
| | | for layer, block in enumerate(self.decoders.blocks): |
| | | x = block(x, memory, mask=self.decoders.mask) |
| | | if layer < len(self.decoders.blocks) - 1: |
| | | x = self.dropout(x) |
| | | |
| | | x = self.decoders.ln(x) |
| | | y = x[:, -1] |
| | | y = ( |
| | | y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1) |
| | | ).float() |
| | | y = torch.log_softmax(y, dim=-1) |
| | | |
| | | return y, None |
| | | |
| | | def score(self, ys, state, x): |
| | | """Score.""" |
| | | logp, state = self.forward_one_step( |
| | | ys.unsqueeze(0), torch.empty(0), x.unsqueeze(0), cache=state # dummy mask |
| | | ) |
| | | return logp.squeeze(0), state |
| | | |
| | | def batch_score( |
| | | self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, List[Any]]: |
| | | """Score new token batch. |
| | | |
| | | Args: |
| | | ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). |
| | | states (List[Any]): Scorer states for prefix tokens. |
| | | xs (torch.Tensor): |
| | | The encoder feature that generates ys (n_batch, xlen, n_feat). |
| | | |
| | | Returns: |
| | | tuple[torch.Tensor, List[Any]]: Tuple of |
| | | batchfied scores for next token with shape of `(n_batch, n_vocab)` |
| | | and next state list for ys. |
| | | |
| | | """ |
| | | # batch decoding, dummy mask is passed |
| | | logp, states = self.forward_one_step(ys, torch.empty(0), xs, cache=None) |
| | | |
| | | return logp, None |
| New file |
| | |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | import copy |
| | | from typing import Optional, Tuple, Union |
| | | |
| | | import torch |
| | | from torch import nn |
| | | import torch.nn.functional as F |
| | | import whisper |
| | | |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | from funasr.register import tables |
| | | |
| | | |
| | | @tables.register("encoder_classes", "OpenAIWhisperEncoderWarp") |
| | | class OpenAIWhisperEncoderWarp(nn.Module): |
| | | """Transformer-based Speech Encoder from OpenAI's Whisper Model: |
| | | |
| | | URL: https://github.com/openai/whisper |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | dropout_rate: float = 0.0, |
| | | whisper_model: str = "small", |
| | | download_dir: str = None, |
| | | use_specaug: bool = False, |
| | | use_padmask: bool = False, |
| | | specaug_conf: Union[dict, None] = None, |
| | | ): |
| | | super().__init__() |
| | | |
| | | # note that originally Whisper doesn't use dropouts |
| | | self.dropout = torch.nn.Dropout(dropout_rate) |
| | | |
| | | assert whisper_model in whisper.available_models() |
| | | _model = whisper.load_model( |
| | | whisper_model, download_root=download_dir, device="cpu" |
| | | ) |
| | | self.encoders = copy.deepcopy(_model.encoder) |
| | | self.encoders.train() |
| | | |
| | | del _model |
| | | |
| | | if use_specaug: |
| | | self.specaug = SpecAug(**specaug_conf) |
| | | else: |
| | | self.specaug = None |
| | | self.use_padmask = use_padmask |
| | | |
| | | def whisper_encode( |
| | | self, |
| | | input: torch.Tensor, |
| | | ilens: torch.Tensor = None, |
| | | ) -> torch.Tensor: |
| | | x = F.gelu(self.encoders.conv1(input)) |
| | | x = F.gelu(self.encoders.conv2(x)) |
| | | x = x.permute(0, 2, 1) |
| | | |
| | | n_frames = x.size(1) |
| | | max_pos = self.encoders.positional_embedding.size(0) |
| | | if n_frames <= max_pos: |
| | | x = (x + self.encoders.positional_embedding[: x.size(1), :]).to(x.dtype) |
| | | else: |
| | |             # due to the fixed positional encoding, audio longer than 30 s is truncated to the first 30 s
| | | x = x[:, :max_pos, :] + self.encoders.positional_embedding |
| | | |
| | | if ilens is not None: |
| | | olens = ( |
| | | 1 |
| | | + ( |
| | | ilens |
| | | - self.encoders.conv2.kernel_size[0] |
| | | + 2 * self.encoders.conv2.padding[0] |
| | | ) |
| | | // self.encoders.conv2.stride[0] |
| | | ) |
| | | olens = torch.clamp(olens, max=max_pos) |
| | | else: |
| | | olens = None |
| | | |
| | | if self.use_padmask: |
| | | padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device) |
| | | else: |
| | | padding_mask = None |
| | | |
| | | x = self.dropout(x) |
| | | |
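| | |         # NOTE: padding_mask is computed above but not passed to the Whisper blocks below,
| | |         # so self-attention in the encoder still covers padded frames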
| | | for layer, block in enumerate(self.encoders.blocks): |
| | | x = block(x) |
| | | if layer < len(self.encoders.blocks) - 1: |
| | | x = self.dropout(x) |
| | | |
| | | x = self.encoders.ln_post(x) |
| | | |
| | | return x, olens |
| | | |
| | | def output_size(self) -> int: |
| | |         # output size equals conv2's output channels, i.e. the Whisper encoder model dimension
| | | return self.encoders.conv2.weight.shape[0] |
| | | |
| | | def forward( |
| | | self, |
| | | xs_pad: torch.Tensor, |
| | | ilens: torch.Tensor, |
| | | prev_states: torch.Tensor = None, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: |
| | | feats, feats_lens = xs_pad, ilens |
| | | |
| | | if self.specaug is not None and self.encoders.training: |
| | | feats = torch.transpose(feats, 1, 2) |
| | | feats, feats_lens = self.specaug(feats, feats_lens) |
| | | feats = torch.transpose(feats, 1, 2) |
| | | |
| | | xs_pad, olens = self.whisper_encode(feats, feats_lens) |
| | | |
| | | return xs_pad, olens, None |
| New file |
| | |
| | | # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. |
| | | # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | """ Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker. |
| | | ERes2Net incorporates both local and global feature fusion techniques to improve the performance. |
| | | The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal. |
| | | The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal. |
| | | ERes2Net-Large is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better |
| | | recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance. |
| | | """ |
| | | |
| | | import torch |
| | | import math |
| | | import torch.nn as nn |
| | | import torch.nn.functional as F |
| | | import funasr.models.whisper_lid.eres2net.pooling_layers as pooling_layers |
| | | from funasr.models.whisper_lid.eres2net.fusion import AFF |
| | | |
| | | |
| | | class ReLU(nn.Hardtanh): |
| | | |
| | | def __init__(self, inplace=False): |
| | | super(ReLU, self).__init__(0, 20, inplace) |
| | | |
| | | def __repr__(self): |
| | | inplace_str = 'inplace' if self.inplace else '' |
| | | return self.__class__.__name__ + ' (' \ |
| | | + inplace_str + ')' |
| | | |
| | | |
| | | def conv1x1(in_planes, out_planes, stride=1): |
| | | "1x1 convolution without padding" |
| | | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, |
| | | padding=0, bias=False) |
| | | |
| | | |
| | | def conv3x3(in_planes, out_planes, stride=1): |
| | | "3x3 convolution with padding" |
| | | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, |
| | | padding=1, bias=False) |
| | | |
| | | |
| | | class BasicBlockERes2Net(nn.Module): |
| | | expansion = 2 |
| | | |
| | | def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2): |
| | | super(BasicBlockERes2Net, self).__init__() |
| | | width = int(math.floor(planes * (baseWidth / 64.0))) |
| | | self.conv1 = conv1x1(in_planes, width * scale, stride) |
| | | self.bn1 = nn.BatchNorm2d(width * scale) |
| | | self.nums = scale |
| | | |
| | | convs = [] |
| | | bns = [] |
| | | for i in range(self.nums): |
| | | convs.append(conv3x3(width, width)) |
| | | bns.append(nn.BatchNorm2d(width)) |
| | | self.convs = nn.ModuleList(convs) |
| | | self.bns = nn.ModuleList(bns) |
| | | self.relu = ReLU(inplace=True) |
| | | |
| | | self.conv3 = conv1x1(width * scale, planes * self.expansion) |
| | | self.bn3 = nn.BatchNorm2d(planes * self.expansion) |
| | | self.shortcut = nn.Sequential() |
| | | if stride != 1 or in_planes != self.expansion * planes: |
| | | self.shortcut = nn.Sequential( |
| | | nn.Conv2d(in_planes, |
| | | self.expansion * planes, |
| | | kernel_size=1, |
| | | stride=stride, |
| | | bias=False), |
| | | nn.BatchNorm2d(self.expansion * planes)) |
| | | self.stride = stride |
| | | self.width = width |
| | | self.scale = scale |
| | | |
| | | def forward(self, x): |
| | | residual = x |
| | | |
| | | out = self.conv1(x) |
| | | out = self.bn1(out) |
| | | out = self.relu(out) |
| | | spx = torch.split(out, self.width, 1) |
| | | for i in range(self.nums): |
| | | if i == 0: |
| | | sp = spx[i] |
| | | else: |
| | | sp = sp + spx[i] |
| | | sp = self.convs[i](sp) |
| | | sp = self.relu(self.bns[i](sp)) |
| | | if i == 0: |
| | | out = sp |
| | | else: |
| | | out = torch.cat((out, sp), 1) |
| | | |
| | | out = self.conv3(out) |
| | | out = self.bn3(out) |
| | | |
| | | residual = self.shortcut(x) |
| | | out += residual |
| | | out = self.relu(out) |
| | | |
| | | return out |
| | | |
| | | |
| | | class BasicBlockERes2Net_diff_AFF(nn.Module): |
| | | expansion = 2 |
| | | |
| | | def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2): |
| | | super(BasicBlockERes2Net_diff_AFF, self).__init__() |
| | | width = int(math.floor(planes * (baseWidth / 64.0))) |
| | | self.conv1 = conv1x1(in_planes, width * scale, stride) |
| | | self.bn1 = nn.BatchNorm2d(width * scale) |
| | | self.nums = scale |
| | | |
| | | convs = [] |
| | | fuse_models = [] |
| | | bns = [] |
| | | for i in range(self.nums): |
| | | convs.append(conv3x3(width, width)) |
| | | bns.append(nn.BatchNorm2d(width)) |
| | | for j in range(self.nums - 1): |
| | | fuse_models.append(AFF(channels=width)) |
| | | |
| | | self.convs = nn.ModuleList(convs) |
| | | self.bns = nn.ModuleList(bns) |
| | | self.fuse_models = nn.ModuleList(fuse_models) |
| | | self.relu = ReLU(inplace=True) |
| | | |
| | | self.conv3 = conv1x1(width * scale, planes * self.expansion) |
| | | self.bn3 = nn.BatchNorm2d(planes * self.expansion) |
| | | self.shortcut = nn.Sequential() |
| | | if stride != 1 or in_planes != self.expansion * planes: |
| | | self.shortcut = nn.Sequential( |
| | | nn.Conv2d(in_planes, |
| | | self.expansion * planes, |
| | | kernel_size=1, |
| | | stride=stride, |
| | | bias=False), |
| | | nn.BatchNorm2d(self.expansion * planes)) |
| | | self.stride = stride |
| | | self.width = width |
| | | self.scale = scale |
| | | |
| | | def forward(self, x): |
| | | residual = x |
| | | |
| | | out = self.conv1(x) |
| | | out = self.bn1(out) |
| | | out = self.relu(out) |
| | | spx = torch.split(out, self.width, 1) |
| | | for i in range(self.nums): |
| | | if i == 0: |
| | | sp = spx[i] |
| | | else: |
| | | sp = self.fuse_models[i - 1](sp, spx[i]) |
| | | |
| | | sp = self.convs[i](sp) |
| | | sp = self.relu(self.bns[i](sp)) |
| | | if i == 0: |
| | | out = sp |
| | | else: |
| | | out = torch.cat((out, sp), 1) |
| | | |
| | | out = self.conv3(out) |
| | | out = self.bn3(out) |
| | | |
| | | residual = self.shortcut(x) |
| | | out += residual |
| | | out = self.relu(out) |
| | | |
| | | return out |
| | | |
| | | |
| | | class ERes2Net(nn.Module): |
| | | def __init__(self, |
| | | block=BasicBlockERes2Net, |
| | | block_fuse=BasicBlockERes2Net_diff_AFF, |
| | | num_blocks=[3, 4, 6, 3], |
| | | m_channels=32, |
| | | feat_dim=80, |
| | | embedding_size=192, |
| | | pooling_func='TSTP', |
| | | two_emb_layer=False): |
| | | super(ERes2Net, self).__init__() |
| | | self.in_planes = m_channels |
| | | self.feat_dim = feat_dim |
| | | self.embedding_size = embedding_size |
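| | |         # after the three stride-2 stages the frequency axis is reduced 8x while the channel
| | |         # count grows to m_channels * 8 (times block.expansion when pooled below)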
| | | self.stats_dim = int(feat_dim / 8) * m_channels * 8 |
| | | self.two_emb_layer = two_emb_layer |
| | | self._output_size = embedding_size |
| | | |
| | | self.conv1 = nn.Conv2d(1, |
| | | m_channels, |
| | | kernel_size=3, |
| | | stride=1, |
| | | padding=1, |
| | | bias=False) |
| | | self.bn1 = nn.BatchNorm2d(m_channels) |
| | | self.layer1 = self._make_layer(block, |
| | | m_channels, |
| | | num_blocks[0], |
| | | stride=1) |
| | | self.layer2 = self._make_layer(block, |
| | | m_channels * 2, |
| | | num_blocks[1], |
| | | stride=2) |
| | | self.layer3 = self._make_layer(block_fuse, |
| | | m_channels * 4, |
| | | num_blocks[2], |
| | | stride=2) |
| | | self.layer4 = self._make_layer(block_fuse, |
| | | m_channels * 8, |
| | | num_blocks[3], |
| | | stride=2) |
| | | |
| | | # Downsampling module for each layer |
| | | self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, |
| | | bias=False) |
| | | self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, |
| | | bias=False) |
| | | self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, |
| | | bias=False) |
| | | |
| | | # Bottom-up fusion module |
| | | self.fuse_mode12 = AFF(channels=m_channels * 4) |
| | | self.fuse_mode123 = AFF(channels=m_channels * 8) |
| | | self.fuse_mode1234 = AFF(channels=m_channels * 16) |
| | | |
| | | self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2 |
| | | self.pool = getattr(pooling_layers, pooling_func)( |
| | | in_dim=self.stats_dim * block.expansion) |
| | | self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, |
| | | embedding_size) |
| | | if self.two_emb_layer: |
| | | self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False) |
| | | self.seg_2 = nn.Linear(embedding_size, embedding_size) |
| | | else: |
| | | self.seg_bn_1 = nn.Identity() |
| | | self.seg_2 = nn.Identity() |
| | | |
| | | def _make_layer(self, block, planes, num_blocks, stride): |
| | | strides = [stride] + [1] * (num_blocks - 1) |
| | | layers = [] |
| | | for stride in strides: |
| | | layers.append(block(self.in_planes, planes, stride)) |
| | | self.in_planes = planes * block.expansion |
| | | return nn.Sequential(*layers) |
| | | |
| | | def output_size(self) -> int: |
| | | return self._output_size |
| | | |
| | | def forward(self, x, ilens): |
| | | # assert x.shape[1] == ilens.max() |
| | | x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) |
| | | x = x.unsqueeze_(1) |
| | | out = F.relu(self.bn1(self.conv1(x))) |
| | | out1 = self.layer1(out) |
| | | out2 = self.layer2(out1) |
| | | out1_downsample = self.layer1_downsample(out1) |
| | | fuse_out12 = self.fuse_mode12(out2, out1_downsample) |
| | | out3 = self.layer3(out2) |
| | | fuse_out12_downsample = self.layer2_downsample(fuse_out12) |
| | | fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample) |
| | | out4 = self.layer4(out3) |
| | | fuse_out123_downsample = self.layer3_downsample(fuse_out123) |
| | | fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample) |
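| | |         # each of the three stride-2 stages halves the time axis (ceil division), giving the
| | |         # number of valid frames used by the masked pooling layer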
| | | olens = (((((ilens - 1) // 2 + 1) - 1) // 2 + 1) - 1) // 2 + 1 |
| | | stats = self.pool(fuse_out1234, olens) |
| | | |
| | | embed_a = self.seg_1(stats) |
| | | if self.two_emb_layer: |
| | | out = F.relu(embed_a) |
| | | out = self.seg_bn_1(out) |
| | | embed_b = self.seg_2(out) |
| | | return embed_b |
| | | else: |
| | | return embed_a |
| | | |
| | | |
| | | class BasicBlockRes2Net(nn.Module): |
| | | expansion = 2 |
| | | |
| | | def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2): |
| | | super(BasicBlockRes2Net, self).__init__() |
| | | width = int(math.floor(planes * (baseWidth / 64.0))) |
| | | self.conv1 = conv1x1(in_planes, width * scale, stride) |
| | | self.bn1 = nn.BatchNorm2d(width * scale) |
| | | self.nums = scale - 1 |
| | | convs = [] |
| | | bns = [] |
| | | for i in range(self.nums): |
| | | convs.append(conv3x3(width, width)) |
| | | bns.append(nn.BatchNorm2d(width)) |
| | | self.convs = nn.ModuleList(convs) |
| | | self.bns = nn.ModuleList(bns) |
| | | self.relu = ReLU(inplace=True) |
| | | |
| | | self.conv3 = conv1x1(width * scale, planes * self.expansion) |
| | | self.bn3 = nn.BatchNorm2d(planes * self.expansion) |
| | | self.shortcut = nn.Sequential() |
| | | if stride != 1 or in_planes != self.expansion * planes: |
| | | self.shortcut = nn.Sequential( |
| | | nn.Conv2d(in_planes, |
| | | self.expansion * planes, |
| | | kernel_size=1, |
| | | stride=stride, |
| | | bias=False), |
| | | nn.BatchNorm2d(self.expansion * planes)) |
| | | self.stride = stride |
| | | self.width = width |
| | | self.scale = scale |
| | | |
| | | def forward(self, x): |
| | | residual = x |
| | | |
| | | out = self.conv1(x) |
| | | out = self.bn1(out) |
| | | out = self.relu(out) |
| | | spx = torch.split(out, self.width, 1) |
| | | for i in range(self.nums): |
| | | if i == 0: |
| | | sp = spx[i] |
| | | else: |
| | | sp = sp + spx[i] |
| | | sp = self.convs[i](sp) |
| | | sp = self.relu(self.bns[i](sp)) |
| | | if i == 0: |
| | | out = sp |
| | | else: |
| | | out = torch.cat((out, sp), 1) |
| | | |
| | | out = torch.cat((out, spx[self.nums]), 1) |
| | | |
| | | out = self.conv3(out) |
| | | out = self.bn3(out) |
| | | |
| | | residual = self.shortcut(x) |
| | | out += residual |
| | | out = self.relu(out) |
| | | |
| | | return out |
| | | |
| | | |
| | | class Res2Net(nn.Module): |
| | | def __init__(self, |
| | | block=BasicBlockRes2Net, |
| | | num_blocks=[3, 4, 6, 3], |
| | | m_channels=32, |
| | | feat_dim=80, |
| | | embedding_size=192, |
| | | pooling_func='TSTP', |
| | | two_emb_layer=False): |
| | | super(Res2Net, self).__init__() |
| | | self.in_planes = m_channels |
| | | self.feat_dim = feat_dim |
| | | self.embedding_size = embedding_size |
| | | self.stats_dim = int(feat_dim / 8) * m_channels * 8 |
| | | self.two_emb_layer = two_emb_layer |
| | | |
| | | self.conv1 = nn.Conv2d(1, |
| | | m_channels, |
| | | kernel_size=3, |
| | | stride=1, |
| | | padding=1, |
| | | bias=False) |
| | | self.bn1 = nn.BatchNorm2d(m_channels) |
| | | self.layer1 = self._make_layer(block, |
| | | m_channels, |
| | | num_blocks[0], |
| | | stride=1) |
| | | self.layer2 = self._make_layer(block, |
| | | m_channels * 2, |
| | | num_blocks[1], |
| | | stride=2) |
| | | self.layer3 = self._make_layer(block, |
| | | m_channels * 4, |
| | | num_blocks[2], |
| | | stride=2) |
| | | self.layer4 = self._make_layer(block, |
| | | m_channels * 8, |
| | | num_blocks[3], |
| | | stride=2) |
| | | |
| | | self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2 |
| | | self.pool = getattr(pooling_layers, pooling_func)( |
| | | in_dim=self.stats_dim * block.expansion) |
| | | self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, |
| | | embedding_size) |
| | | if self.two_emb_layer: |
| | | self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False) |
| | | self.seg_2 = nn.Linear(embedding_size, embedding_size) |
| | | else: |
| | | self.seg_bn_1 = nn.Identity() |
| | | self.seg_2 = nn.Identity() |
| | | |
| | | def _make_layer(self, block, planes, num_blocks, stride): |
| | | strides = [stride] + [1] * (num_blocks - 1) |
| | | layers = [] |
| | | for stride in strides: |
| | | layers.append(block(self.in_planes, planes, stride)) |
| | | self.in_planes = planes * block.expansion |
| | | return nn.Sequential(*layers) |
| | | |
| | | def forward(self, x): |
| | | x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) |
| | | |
| | | x = x.unsqueeze_(1) |
| | | out = F.relu(self.bn1(self.conv1(x))) |
| | | out = self.layer1(out) |
| | | out = self.layer2(out) |
| | | out = self.layer3(out) |
| | | out = self.layer4(out) |
| | | |
| | | stats = self.pool(out) |
| | | |
| | | embed_a = self.seg_1(stats) |
| | | if self.two_emb_layer: |
| | | out = F.relu(embed_a) |
| | | out = self.seg_bn_1(out) |
| | | embed_b = self.seg_2(out) |
| | | return embed_b |
| | | else: |
| | | return embed_a |
| | | |
| | | |
| | | |
| | | |
| New file |
| | |
| | | # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. |
| | | # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | |
| | | class AFF(nn.Module): |
| | | |
| | | def __init__(self, channels=64, r=4): |
| | | super(AFF, self).__init__() |
| | | inter_channels = int(channels // r) |
| | | |
| | | self.local_att = nn.Sequential( |
| | | nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0), |
| | | nn.BatchNorm2d(inter_channels), |
| | | nn.SiLU(inplace=True), |
| | | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), |
| | | nn.BatchNorm2d(channels), |
| | | ) |
| | | |
| | | def forward(self, x, ds_y): |
| | | xa = torch.cat((x, ds_y), dim=1) |
| | | x_att = self.local_att(xa) |
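| | |         # 1 + tanh keeps the gate in (0, 2); the two inputs then receive complementary
| | |         # weights x_att and 2 - x_att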
| | | x_att = 1.0 + torch.tanh(x_att) |
| | | xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0-x_att) |
| | | |
| | | return xo |
| | | |
| New file |
| | |
| | | # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. |
| | | # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | """ This implementation is adapted from https://github.com/wenet-e2e/wespeaker.""" |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | |
| | | |
| | | class TAP(nn.Module): |
| | | """ |
| | | Temporal average pooling, only first-order mean is considered |
| | | """ |
| | | |
| | | def __init__(self, **kwargs): |
| | | super(TAP, self).__init__() |
| | | |
| | | def forward(self, x): |
| | | pooling_mean = x.mean(dim=-1) |
| | |         # To be compatible with 2D input
| | | pooling_mean = pooling_mean.flatten(start_dim=1) |
| | | return pooling_mean |
| | | |
| | | |
| | | class TSDP(nn.Module): |
| | | """ |
| | | Temporal standard deviation pooling, only second-order std is considered |
| | | """ |
| | | |
| | | def __init__(self, **kwargs): |
| | | super(TSDP, self).__init__() |
| | | |
| | | def forward(self, x): |
| | | # The last dimension is the temporal axis |
| | | pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8) |
| | | pooling_std = pooling_std.flatten(start_dim=1) |
| | | return pooling_std |
| | | |
| | | |
| | | class TSTP(nn.Module): |
| | | """ |
| | | Temporal statistics pooling, concatenate mean and std, which is used in |
| | | x-vector |
| | |     Comment: simple concatenation cannot make full use of both statistics
| | | """ |
| | | |
| | | def __init__(self, **kwargs): |
| | | super(TSTP, self).__init__() |
| | | |
| | | def forward(self, x, olens): |
| | | # The last dimension is the temporal axis |
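| | |         # mask out padded frames so mean and std are computed over valid frames only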
| | | masks = (~make_pad_mask(olens, maxlen=x.shape[-1])[:, None, None, :]).to(x.device) |
| | | x_masked = x * masks |
| | | sum_without_padding = torch.sum(x_masked, axis=-1) |
| | | count_without_padding = torch.sum(masks, axis=-1) |
| | | mean_without_padding = sum_without_padding / count_without_padding |
| | | |
| | | var_without_padding = ((x_masked - mean_without_padding.unsqueeze(-1)) ** 2 * masks).sum(-1) / count_without_padding |
| | | |
| | | pooling_mean = mean_without_padding |
| | | pooling_std = torch.sqrt(var_without_padding + 1e-8) |
| | | pooling_mean = pooling_mean.flatten(start_dim=1) |
| | | pooling_std = pooling_std.flatten(start_dim=1) |
| | | |
| | | stats = torch.cat((pooling_mean, pooling_std), 1) |
| | | return stats |
| | | |
| | | |
| | | class ASTP(nn.Module): |
| | | """ Attentive statistics pooling: Channel- and context-dependent |
| | | statistics pooling, first used in ECAPA_TDNN. |
| | | """ |
| | | |
| | | def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False): |
| | | super(ASTP, self).__init__() |
| | | self.global_context_att = global_context_att |
| | | |
| | | # Use Conv1d with stride == 1 rather than Linear, then we don't |
| | | # need to transpose inputs. |
| | | if global_context_att: |
| | | self.linear1 = nn.Conv1d( |
| | | in_dim * 3, bottleneck_dim, |
| | | kernel_size=1) # equals W and b in the paper |
| | | else: |
| | | self.linear1 = nn.Conv1d( |
| | | in_dim, bottleneck_dim, |
| | | kernel_size=1) # equals W and b in the paper |
| | | self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, |
| | | kernel_size=1) # equals V and k in the paper |
| | | |
| | | def forward(self, x): |
| | | """ |
| | | x: a 3-dimensional tensor in tdnn-based architecture (B,F,T) |
| | | or a 4-dimensional tensor in resnet architecture (B,C,F,T) |
| | | 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) |
| | | """ |
| | | if len(x.shape) == 4: |
| | | x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3]) |
| | | assert len(x.shape) == 3 |
| | | |
| | | if self.global_context_att: |
| | | context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) |
| | | context_std = torch.sqrt( |
| | | torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x) |
| | | x_in = torch.cat((x, context_mean, context_std), dim=1) |
| | | else: |
| | | x_in = x |
| | | |
| | |         # DON'T use ReLU here; it can make training harder to converge.
| | | alpha = torch.tanh( |
| | | self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in)) |
| | | alpha = torch.softmax(self.linear2(alpha), dim=2) |
| | | mean = torch.sum(alpha * x, dim=2) |
| | | var = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2 |
| | | std = torch.sqrt(var.clamp(min=1e-10)) |
| | | return torch.cat([mean, std], dim=1) |
| New file |
| | |
| | | import torch |
| | | |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | |
| | | class SimpleAvg(AbsEncoder): |
| | | def __init__(self, feat_dim): |
| | | super(SimpleAvg, self).__init__() |
| | | self.feat_dim = feat_dim |
| | | |
| | | def forward(self, x, ilens): |
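| | |         # average over valid (non-padded) time steps only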
| | | mask = ~make_pad_mask(ilens, maxlen=x.shape[1]).to(x.device) |
| | | avg_x = (x * mask[:, :, None]).sum(1) / mask.sum(-1)[:, None] |
| | | return avg_x |
| | | |
| | | def output_size(self) -> int: |
| | | return self.feat_dim |
| New file |
| | |
| | | from funasr.register import tables |
| | | from funasr.models.whisper_lid.eres2net.ResNet import ERes2Net, BasicBlockERes2Net, BasicBlockERes2Net_diff_AFF |
| | | |
| | | |
| | | @tables.register("lid_predictor_classes", "LidPredictor") |
| | | class LidPredictor(ERes2Net): |
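| | |     """ERes2Net-based language-ID head operating on (projected) Whisper encoder features."""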
| | | def __init__(self, |
| | | block=BasicBlockERes2Net, |
| | | block_fuse=BasicBlockERes2Net_diff_AFF, |
| | | num_blocks=[3, 4, 6, 3], |
| | | m_channels=32, |
| | | feat_dim=80, |
| | | embedding_size=192, |
| | | pooling_func='TSTP', |
| | | two_emb_layer=False): |
| | | super(LidPredictor, self).__init__( |
| | | block=block, |
| | | block_fuse=block_fuse, |
| | | num_blocks=num_blocks, |
| | | m_channels=m_channels, |
| | | feat_dim=feat_dim, |
| | | embedding_size=embedding_size, |
| | | pooling_func=pooling_func, |
| | | two_emb_layer=two_emb_layer |
| | | ) |
| New file |
| | |
| | | import logging |
| | | from typing import Union, Dict, List, Tuple, Optional |
| | | |
| | | import time |
| | | import torch |
| | | import numpy as np |
| | | import torch.nn as nn |
| | | from torch.cuda.amp import autocast |
| | | |
| | | from funasr.losses.label_smoothing_loss import LabelSmoothingLoss |
| | | from funasr.models.ctc.ctc import CTC |
| | | from funasr.models.transformer.utils.add_sos_eos import add_sos_eos |
| | | from funasr.metrics.compute_acc import th_accuracy |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | from funasr.utils import postprocess_utils |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | from funasr.register import tables |
| | | |
| | | |
| | | @tables.register("model_classes", "OpenAIWhisperModel") |
| | | class OpenAIWhisperModel(nn.Module): |
| | | """CTC-attention hybrid Encoder-Decoder model""" |
| | | |
| | | |
| | | def __init__( |
| | | self, |
| | | specaug: str = None, |
| | | specaug_conf: dict = None, |
| | | normalize: str = None, |
| | | normalize_conf: dict = None, |
| | | encoder: str = None, |
| | | encoder_conf: dict = None, |
| | | decoder: str = None, |
| | | decoder_conf: dict = None, |
| | | ctc: str = None, |
| | | ctc_conf: dict = None, |
| | | ctc_weight: float = 0.5, |
| | | interctc_weight: float = 0.0, |
| | | input_size: int = 80, |
| | | vocab_size: int = -1, |
| | | ignore_id: int = -1, |
| | | blank_id: int = 0, |
| | | sos: int = 1, |
| | | eos: int = 2, |
| | | lsm_weight: float = 0.0, |
| | | length_normalized_loss: bool = False, |
| | | report_cer: bool = True, |
| | | report_wer: bool = True, |
| | | sym_space: str = "<space>", |
| | | sym_blank: str = "<blank>", |
| | | # extract_feats_in_collect_stats: bool = True, |
| | | share_embedding: bool = False, |
| | | # preencoder: Optional[AbsPreEncoder] = None, |
| | | # postencoder: Optional[AbsPostEncoder] = None, |
| | | **kwargs, |
| | | ): |
| | | |
| | | super().__init__() |
| | | |
| | | if specaug is not None: |
| | | specaug_class = tables.specaug_classes.get(specaug) |
| | | specaug = specaug_class(**specaug_conf) |
| | | if normalize is not None: |
| | | normalize_class = tables.normalize_classes.get(normalize) |
| | | normalize = normalize_class(**normalize_conf) |
| | | encoder_class = tables.encoder_classes.get(encoder) |
| | | encoder = encoder_class(input_size=input_size, **encoder_conf) |
| | | encoder_output_size = encoder.output_size() |
| | | if decoder is not None: |
| | | decoder_class = tables.decoder_classes.get(decoder) |
| | |             decoder = decoder_class(**decoder_conf)
| | | if ctc_weight > 0.0: |
| | | |
| | | if ctc_conf is None: |
| | | ctc_conf = {} |
| | | |
| | | ctc = CTC( |
| | | odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf |
| | | ) |
| | | |
| | | self.blank_id = blank_id |
| | | self.sos = sos if sos is not None else vocab_size - 1 |
| | | self.eos = eos if eos is not None else vocab_size - 1 |
| | | self.vocab_size = vocab_size |
| | | self.ignore_id = ignore_id |
| | | self.ctc_weight = ctc_weight |
| | | self.specaug = specaug |
| | | self.normalize = normalize |
| | | self.encoder = encoder |
| | | |
| | | if not hasattr(self.encoder, "interctc_use_conditioning"): |
| | | self.encoder.interctc_use_conditioning = False |
| | | if self.encoder.interctc_use_conditioning: |
| | | self.encoder.conditioning_layer = torch.nn.Linear( |
| | | vocab_size, self.encoder.output_size() |
| | | ) |
| | | self.interctc_weight = interctc_weight |
| | | |
| | | # self.error_calculator = None |
| | | if ctc_weight == 1.0: |
| | | self.decoder = None |
| | | else: |
| | | self.decoder = decoder |
| | | |
| | | self.criterion_att = LabelSmoothingLoss( |
| | | size=vocab_size, |
| | | padding_idx=ignore_id, |
| | | smoothing=lsm_weight, |
| | | normalize_length=length_normalized_loss, |
| | | ) |
| | | # |
| | | # if report_cer or report_wer: |
| | | # self.error_calculator = ErrorCalculator( |
| | | # token_list, sym_space, sym_blank, report_cer, report_wer |
| | | # ) |
| | | # |
| | | self.error_calculator = None |
| | | if ctc_weight == 0.0: |
| | | self.ctc = None |
| | | else: |
| | | self.ctc = ctc |
| | | |
| | | self.share_embedding = share_embedding |
| | | if self.share_embedding: |
| | | self.decoder.embed = None |
| | | |
| | | self.length_normalized_loss = length_normalized_loss |
| | | self.beam_search = None |
| | | |
| | | def forward( |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | **kwargs, |
| | | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: |
| | | """Encoder + Decoder + Calc loss |
| | | Args: |
| | | speech: (Batch, Length, ...) |
| | | speech_lengths: (Batch, ) |
| | | text: (Batch, Length) |
| | | text_lengths: (Batch,) |
| | | """ |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | if len(text_lengths.size()) > 1: |
| | | text_lengths = text_lengths[:, 0] |
| | | if len(speech_lengths.size()) > 1: |
| | | speech_lengths = speech_lengths[:, 0] |
| | | |
| | | batch_size = speech.shape[0] |
| | | |
| | | # 1. Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | intermediate_outs = None |
| | | if isinstance(encoder_out, tuple): |
| | | intermediate_outs = encoder_out[1] |
| | | encoder_out = encoder_out[0] |
| | | |
| | | loss_att, acc_att, cer_att, wer_att = None, None, None, None |
| | | loss_ctc, cer_ctc = None, None |
| | | stats = dict() |
| | | |
| | | # decoder: CTC branch |
| | | if self.ctc_weight != 0.0: |
| | | loss_ctc, cer_ctc = self._calc_ctc_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | |
| | | # Collect CTC branch stats |
| | | stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None |
| | | stats["cer_ctc"] = cer_ctc |
| | | |
| | | # Intermediate CTC (optional) |
| | | loss_interctc = 0.0 |
| | | if self.interctc_weight != 0.0 and intermediate_outs is not None: |
| | | for layer_idx, intermediate_out in intermediate_outs: |
| | | # we assume intermediate_out has the same length & padding |
| | | # as those of encoder_out |
| | | loss_ic, cer_ic = self._calc_ctc_loss( |
| | | intermediate_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | loss_interctc = loss_interctc + loss_ic |
| | | |
| | |                 # Collect intermediate CTC stats
| | | stats["loss_interctc_layer{}".format(layer_idx)] = ( |
| | | loss_ic.detach() if loss_ic is not None else None |
| | | ) |
| | | stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic |
| | | |
| | | loss_interctc = loss_interctc / len(intermediate_outs) |
| | | |
| | | # calculate whole encoder loss |
| | | loss_ctc = ( |
| | | 1 - self.interctc_weight |
| | | ) * loss_ctc + self.interctc_weight * loss_interctc |
| | | |
| | | # decoder: Attention decoder branch |
| | | loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | |
| | | # 3. CTC-Att loss definition |
| | | if self.ctc_weight == 0.0: |
| | | loss = loss_att |
| | | elif self.ctc_weight == 1.0: |
| | | loss = loss_ctc |
| | | else: |
| | | loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att |
| | | |
| | | # Collect Attn branch stats |
| | | stats["loss_att"] = loss_att.detach() if loss_att is not None else None |
| | | stats["acc"] = acc_att |
| | | stats["cer"] = cer_att |
| | | stats["wer"] = wer_att |
| | | |
| | | # Collect total loss stats |
| | | stats["loss"] = torch.clone(loss.detach()) |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | | batch_size = int((text_lengths + 1).sum()) |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | | |
| | | def encode( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Frontend + Encoder. Note that this method is used by asr_inference.py |
| | | Args: |
| | | speech: (Batch, Length, ...) |
| | | speech_lengths: (Batch, ) |
| | | ind: int |
| | | """ |
| | | with autocast(False): |
| | | |
| | | # Data augmentation |
| | | if self.specaug is not None and self.training: |
| | | speech, speech_lengths = self.specaug(speech, speech_lengths) |
| | | |
| | | # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN |
| | | if self.normalize is not None: |
| | | speech, speech_lengths = self.normalize(speech, speech_lengths) |
| | | |
| | | # Forward encoder |
| | | # feats: (Batch, Length, Dim) |
| | | # -> encoder_out: (Batch, Length2, Dim2) |
| | | if self.encoder.interctc_use_conditioning: |
| | | encoder_out, encoder_out_lens, _ = self.encoder( |
| | | speech, speech_lengths, ctc=self.ctc |
| | | ) |
| | | else: |
| | | encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths) |
| | | intermediate_outs = None |
| | | if isinstance(encoder_out, tuple): |
| | | intermediate_outs = encoder_out[1] |
| | | encoder_out = encoder_out[0] |
| | | |
| | | if intermediate_outs is not None: |
| | | return (encoder_out, intermediate_outs), encoder_out_lens |
| | | |
| | | return encoder_out, encoder_out_lens |
| | | |
| | | def _calc_att_loss( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | ): |
| | | ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) |
| | | ys_in_lens = ys_pad_lens + 1 |
| | | |
| | | # 1. Forward decoder |
| | | decoder_out, _ = self.decoder( |
| | | encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens |
| | | ) |
| | | |
| | | # 2. Compute attention loss |
| | | loss_att = self.criterion_att(decoder_out, ys_out_pad) |
| | | acc_att = th_accuracy( |
| | | decoder_out.view(-1, self.vocab_size), |
| | | ys_out_pad, |
| | | ignore_label=self.ignore_id, |
| | | ) |
| | | |
| | | # Compute cer/wer using attention-decoder |
| | | if self.training or self.error_calculator is None: |
| | | cer_att, wer_att = None, None |
| | | else: |
| | | ys_hat = decoder_out.argmax(dim=-1) |
| | | cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) |
| | | |
| | | return loss_att, acc_att, cer_att, wer_att |
| | | |
| | | def _calc_ctc_loss( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | ): |
| | | # Calc CTC loss |
| | | loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) |
| | | |
| | | # Calc CER using CTC |
| | | cer_ctc = None |
| | | if not self.training and self.error_calculator is not None: |
| | | ys_hat = self.ctc.argmax(encoder_out).data |
| | | cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) |
| | | return loss_ctc, cer_ctc |
| | | |
| | | def init_beam_search(self, |
| | | **kwargs, |
| | | ): |
| | | from funasr.models.transformer.search import BeamSearch |
| | | from funasr.models.transformer.scorers.ctc import CTCPrefixScorer |
| | | from funasr.models.transformer.scorers.length_bonus import LengthBonus |
| | | |
| | | # 1. Build ASR model |
| | | scorers = {} |
| | | |
| | |         if self.ctc is not None:
| | | ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos) |
| | | scorers.update( |
| | | ctc=ctc |
| | | ) |
| | | token_list = kwargs.get("token_list") |
| | | scorers.update( |
| | | decoder=self.decoder, |
| | | length_bonus=LengthBonus(len(token_list)), |
| | | ) |
| | | |
| | | |
| | | # 3. Build ngram model |
| | | # ngram is not supported now |
| | | ngram = None |
| | | scorers["ngram"] = ngram |
| | | |
| | | weights = dict( |
| | | decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.5), |
| | | ctc=kwargs.get("decoding_ctc_weight", 0.5), |
| | | lm=kwargs.get("lm_weight", 0.0), |
| | | ngram=kwargs.get("ngram_weight", 0.0), |
| | | length_bonus=kwargs.get("penalty", 0.0), |
| | | ) |
| | | beam_search = BeamSearch( |
| | | beam_size=kwargs.get("beam_size", 10), |
| | | weights=weights, |
| | | scorers=scorers, |
| | | sos=self.sos, |
| | | eos=self.eos, |
| | | vocab_size=len(token_list), |
| | | token_list=token_list, |
| | | pre_beam_score_key=None if self.ctc_weight == 1.0 else "full", |
| | | ) |
| | | |
| | | self.beam_search = beam_search |
| | | |
| | | def inference(self, |
| | | data_in, |
| | | data_lengths=None, |
| | | key: list=None, |
| | | tokenizer=None, |
| | | frontend=None, |
| | | **kwargs, |
| | | ): |
| | | |
| | | if kwargs.get("batch_size", 1) > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | |
| | | # init beamsearch |
| | | if self.beam_search is None: |
| | | logging.info("enable beam_search") |
| | | self.init_beam_search(**kwargs) |
| | | self.nbest = kwargs.get("nbest", 1) |
| | | |
| | | meta_data = {} |
| | | if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank |
| | | speech, speech_lengths = data_in, data_lengths |
| | | if len(speech.shape) < 3: |
| | | speech = speech[None, :, :] |
| | | if speech_lengths is None: |
| | | speech_lengths = speech.shape[1] |
| | | else: |
| | | # extract fbank feats |
| | | time1 = time.perf_counter() |
| | | audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), |
| | | data_type=kwargs.get("data_type", "sound"), |
| | | tokenizer=tokenizer) |
| | | time2 = time.perf_counter() |
| | | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| | | speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), |
| | | frontend=frontend) |
| | | time3 = time.perf_counter() |
| | | meta_data["extract_feat"] = f"{time3 - time2:0.3f}" |
| | | meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 |
| | | |
| | | speech = speech.to(device=kwargs["device"]) |
| | | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| | | # Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | if isinstance(encoder_out, tuple): |
| | | encoder_out = encoder_out[0] |
| | | |
| | |         # c. Pass the encoder output to the beam search
| | | nbest_hyps = self.beam_search( |
| | | x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0) |
| | | ) |
| | | |
| | | nbest_hyps = nbest_hyps[: self.nbest] |
| | | |
| | | |
| | | results = [] |
| | | b, n, d = encoder_out.size() |
| | | for i in range(b): |
| | | |
| | | for nbest_idx, hyp in enumerate(nbest_hyps): |
| | | ibest_writer = None |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] |
| | | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | | token_int = hyp.yseq[1:last_pos] |
| | | else: |
| | | token_int = hyp.yseq[1:last_pos].tolist() |
| | | |
| | | # remove blank symbol id, which is assumed to be 0 |
| | | token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) |
| | | |
| | | # Change integer-ids to tokens |
| | | token = tokenizer.ids2tokens(token_int) |
| | | text = tokenizer.tokens2text(token) |
| | | |
| | | text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) |
| | | result_i = {"key": key[i], "token": token, "text": text_postprocessed} |
| | | results.append(result_i) |
| | | |
| | | if ibest_writer is not None: |
| | | ibest_writer["token"][key[i]] = " ".join(token) |
| | | ibest_writer["text"][key[i]] = text_postprocessed |
| | | |
| | | return results, meta_data |
| | | |
| | | |
| | | @tables.register("model_classes", "OpenAIWhisperLIDModel") |
| | | class OpenAIWhisperLIDModel(nn.Module): |
| | | """WhisperEncoder and EResNet based LID Model""" |
| | | |
| | | def __init__( |
| | | self, |
| | | vocab_size: int, |
| | | specaug: str = None, |
| | | specaug_conf: dict = None, |
| | | encoder: str = None, |
| | | encoder_conf: dict = None, |
| | | lid_predictor: str = None, |
| | | lid_predictor_conf: dict = None, |
| | | proj_dim: int = None, |
| | | clip_frames: int = None, |
| | | random_clip: bool = False, |
| | | **kwargs, |
| | | ): |
| | | super().__init__() |
| | | if specaug is not None: |
| | | specaug_class = tables.specaug_classes.get(specaug) |
| | | specaug = specaug_class(**specaug_conf) |
| | | encoder_class = tables.encoder_classes.get(encoder) |
| | | encoder = encoder_class(**encoder_conf) |
| | | lid_predictor_class = tables.lid_predictor_classes.get(lid_predictor) |
| | | lid_predictor = lid_predictor_class(**lid_predictor_conf) |
| | | if encoder.output_size() != proj_dim: |
| | | self.proj_layer = torch.nn.Linear(encoder.output_size(), proj_dim) |
| | | else: |
| | | self.proj_layer = None |
| | | self.output_layer = torch.nn.Linear(lid_predictor.output_size(), vocab_size) |
| | | self.criterion_lid = LabelSmoothingLoss( |
| | | size=vocab_size, |
| | | padding_idx=-1, |
| | | smoothing=0.0, |
| | | normalize_length=False, |
| | | ) |
| | | |
| | | self.specaug = specaug |
| | | self.encoder = encoder |
| | | self.lid_predictor = lid_predictor |
| | | self.clip_frames = clip_frames |
| | | self.random_clip = random_clip |
| | | self.normalize = None |
| | | self.beam_search = None |
| | | if not hasattr(self.encoder, "interctc_use_conditioning"): |
| | | self.encoder.interctc_use_conditioning = False |
| | | |
| | | def forward(self, |
| | |                 speech: torch.Tensor,  # may contain padding
| | | speech_lengths: torch.Tensor, # actual length |
| | | lid: torch.Tensor, # lid label, (batch_size, 1) |
| | | lid_lengths: torch.Tensor, |
| | | ): |
| | | assert lid.shape[1] == 1 |
| | | batch_size = speech.shape[0] |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | |
| | |         # re-generate encoder_out: drop padding frames and optionally clip to clip_frames frames (a random segment when random_clip is set)
| | | if self.clip_frames is None: |
| | | reduced_encoder_out = torch.zeros(batch_size, encoder_out_lens.max(), encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device) |
| | | for i, enc_length in enumerate(encoder_out_lens): |
| | | reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length] |
| | | else: |
| | | reduced_encoder_out = torch.zeros(batch_size, self.clip_frames, encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device) |
| | | if self.random_clip: |
| | | for i, enc_length in enumerate(encoder_out_lens): |
| | | if enc_length <= self.clip_frames: |
| | | reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length] |
| | | encoder_out_lens[i] = enc_length |
| | | else: |
| | | max_start_index = enc_length.item() - self.clip_frames |
| | | start_index = np.random.randint(0, max_start_index + 1) |
| | | reduced_encoder_out[i, :self.clip_frames] = encoder_out[i, start_index:start_index + self.clip_frames] |
| | | encoder_out_lens[i] = self.clip_frames |
| | | else: |
| | | for i, enc_length in enumerate(encoder_out_lens): |
| | | enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length |
| | | reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length] |
| | | encoder_out_lens[i] = enc_length |
| | | if self.proj_layer is not None: |
| | | reduced_encoder_out = self.proj_layer(reduced_encoder_out) |
| | | lid_output = self.lid_predictor(reduced_encoder_out, encoder_out_lens) # (B, D) |
| | | lid_logits = self.output_layer(lid_output) # (B, num_classes) |
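| | |         # add a length-1 time axis so the (batch, seq, vocab) label-smoothing loss sees one "token" per utterance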
| | | loss = self.criterion_lid(lid_logits[:, None, :], lid) |
| | | with torch.no_grad(): |
| | | _, predicted_lid = torch.max(lid_logits, 1) |
| | | correct = (predicted_lid == lid[:, 0]).sum().item() |
| | | lid_acc = correct * 1.0 / lid_logits.shape[0] |
| | | stats = dict() |
| | | stats["batch_size"] = batch_size |
| | | stats["loss"] = torch.clone(loss.detach()) |
| | | stats["acc"] = lid_acc |
| | | stats["token_length"] = speech_lengths.max() |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | | def encode( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Frontend + Encoder. Note that this method is used by asr_inference.py |
| | | Args: |
| | | speech: (Batch, Length, ...) |
| | | speech_lengths: (Batch, ) |
| | | """ |
| | | with autocast(False): |
| | | |
| | | # Data augmentation |
| | | if self.specaug is not None and self.training: |
| | | speech = speech.permute(0, 2, 1) |
| | |                 # features are already padded to Whisper's fixed length, so apply SpecAug over the full padded length
| | | padded_speech_lengths = torch.ones_like(speech_lengths) * speech.shape[1] |
| | | speech, padded_speech_lengths = self.specaug(speech, padded_speech_lengths) |
| | | speech = speech.permute(0, 2, 1) |
| | | |
| | | # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN |
| | | if self.normalize is not None: |
| | | speech, speech_lengths = self.normalize(speech, speech_lengths) |
| | | |
| | | # Forward encoder |
| | | # feats: (Batch, Length, Dim) |
| | | # -> encoder_out: (Batch, Length2, Dim2) |
| | | if self.encoder.interctc_use_conditioning: |
| | | encoder_out, encoder_out_lens, _ = self.encoder( |
| | | speech, speech_lengths, ctc=self.ctc |
| | | ) |
| | | else: |
| | | encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths) |
| | | intermediate_outs = None |
| | | if isinstance(encoder_out, tuple): |
| | | intermediate_outs = encoder_out[1] |
| | | encoder_out = encoder_out[0] |
| | | |
| | | if intermediate_outs is not None: |
| | | return (encoder_out, intermediate_outs), encoder_out_lens |
| | | |
| | | return encoder_out, encoder_out_lens |
| | | |
| | | def inference(self, |
| | | data_in, |
| | | data_lengths=None, |
| | | key: list = None, |
| | | tokenizer=None, |
| | | frontend=None, |
| | | **kwargs, |
| | | ): |
| | | |
| | | if kwargs.get("batch_size", 1) > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | |
| | | meta_data = {} |
| | | if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank |
| | | speech, speech_lengths = data_in, data_lengths |
| | | if len(speech.shape) < 3: |
| | | speech = speech[None, :, :] |
| | | if speech_lengths is None: |
| | | speech_lengths = speech.shape[1] |
| | | else: |
| | | # extract fbank feats |
| | | time1 = time.perf_counter() |
| | | audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), |
| | | data_type=kwargs.get("data_type", "sound"), |
| | | tokenizer=tokenizer) |
| | | time2 = time.perf_counter() |
| | | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| | | speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), |
| | | frontend=frontend) |
| | | time3 = time.perf_counter() |
| | | meta_data["extract_feat"] = f"{time3 - time2:0.3f}" |
| | | meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 |
| | | |
| | | speech = speech.to(device=kwargs["device"]) |
| | | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| | | # Encoder |
| | | enc, enc_out_lens = self.encode(speech, speech_lengths) |
| | | |
| | | inference_clip_length = kwargs.get("inference_clip_length", None) |
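| | |         # inference_clip_length, if given, overrides the training-time clip_frames when truncating encoder frames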
| | | if self.clip_frames is not None: |
| | | if inference_clip_length is None: |
| | | reduced_enc = torch.zeros(enc.shape[0], self.clip_frames, enc.shape[-1]).to(enc.dtype).to(enc.device) |
| | | for i, enc_length in enumerate(enc_out_lens): |
| | | enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length |
| | | reduced_enc[i, :enc_length] = enc[i, :enc_length] |
| | | enc_out_lens[i] = enc_length |
| | | else: |
| | | assert inference_clip_length > 0, "inference_clip_length must be larger than 0" |
| | | reduced_enc = torch.zeros(enc.shape[0], inference_clip_length, enc.shape[-1]).to(enc.dtype).to(enc.device) |
| | | for i, enc_length in enumerate(enc_out_lens): |
| | | enc_length = inference_clip_length if enc_length >= inference_clip_length else enc_length |
| | | reduced_enc[i, :enc_length] = enc[i, :enc_length] |
| | | enc_out_lens[i] = enc_length |
| | | else: |
| | | reduced_enc = torch.zeros(enc.shape[0], enc_out_lens.max(), enc.shape[-1]).to(enc.dtype).to(enc.device) |
| | | for i, enc_length in enumerate(enc_out_lens): |
| | | reduced_enc[i, :enc_length] = enc[i, :enc_length] |
| | | |
| | | if self.proj_layer is not None: |
| | | reduced_enc = self.proj_layer(reduced_enc) |
| | | lid_output = self.lid_predictor(reduced_enc, enc_out_lens) # (B, D) |
| | | lid_logits = self.output_layer(lid_output) # (B, num_classes) |
| | | |
| | | _, predicted_lid_index = torch.max(lid_logits, 1) |
| | | predicted_lid = tokenizer.ids2tokens([predicted_lid_index[0].cpu()])[0] |
| | | |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | lid_writer = self.writer["lid"] |
| | | lid_writer[key[0]] = predicted_lid |
| | | |
| | | results = [{"key": key[0], "lid": predicted_lid}] |
| | | |
| | | return results, meta_data |