Merge pull request #328 from alibaba-damo-academy/dev_cmz2
Dev cmz2
18个文件已修改
4个文件已删除
12个文件已添加
6 文件已重命名
| | |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | |
| | | |
| | | header_colors = '\033[95m' |
| | | end_colors = '\033[0m' |
| | | |
| | | |
| | | class Speech2Text: |
| | | """Speech2Text class |
| | |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.punctuation.text_preprocessor import split_to_mini_sentence |
| | | from funasr.datasets.preprocessor import split_to_mini_sentence |
| | | |
| | | |
| | | class Text2Punc: |
| | |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.punctuation.text_preprocessor import split_to_mini_sentence |
| | | from funasr.datasets.preprocessor import split_to_mini_sentence |
| | | |
| | | |
| | | class Text2Punc: |
| | |
| | | data[self.vad_name] = np.array([vad], dtype=np.int64) |
| | | text_ints = self.token_id_converter[i].tokens2ids(tokens) |
| | | data[text_name] = np.array(text_ints, dtype=np.int64) |
| | | |
| | | |
| | | def split_to_mini_sentence(words: list, word_limit: int = 20): |
| | | assert word_limit > 1 |
| | | if len(words) <= word_limit: |
| | | return [words] |
| | | sentences = [] |
| | | length = len(words) |
| | | sentence_len = length // word_limit |
| | | for i in range(sentence_len): |
| | | sentences.append(words[i * word_limit:(i + 1) * word_limit]) |
| | | if length % word_limit > 0: |
| | | sentences.append(words[sentence_len * word_limit:]) |
| | | return sentences |
| | |
| | | |
| | | def export(self, |
| | | tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', |
| | | mode: str = 'paraformer', |
| | | mode: str = None, |
| | | ): |
| | | |
| | | model_dir = tag_name |
| | | if model_dir.startswith('damo/'): |
| | | if model_dir.startswith('damo'): |
| | | from modelscope.hub.snapshot_download import snapshot_download |
| | | model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir) |
| | | asr_train_config = os.path.join(model_dir, 'config.yaml') |
| | | asr_model_file = os.path.join(model_dir, 'model.pb') |
| | | cmvn_file = os.path.join(model_dir, 'am.mvn') |
| | | json_file = os.path.join(model_dir, 'configuration.json') |
| | | |
| | | if mode is None: |
| | | import json |
| | | json_file = os.path.join(model_dir, 'configuration.json') |
| | | with open(json_file, 'r') as f: |
| | | config_data = json.load(f) |
| | | mode = config_data['model']['model_config']['mode'] |
| | | if config_data['task'] == "punctuation": |
| | | mode = config_data['model']['punc_model_config']['mode'] |
| | | else: |
| | | mode = config_data['model']['model_config']['mode'] |
| | | if mode.startswith('paraformer'): |
| | | from funasr.tasks.asr import ASRTaskParaformer as ASRTask |
| | | elif mode.startswith('uniasr'): |
| | | from funasr.tasks.asr import ASRTaskUniASR as ASRTask |
| | | config = os.path.join(model_dir, 'config.yaml') |
| | | model_file = os.path.join(model_dir, 'model.pb') |
| | | cmvn_file = os.path.join(model_dir, 'am.mvn') |
| | | model, asr_train_args = ASRTask.build_model_from_file( |
| | | config, model_file, cmvn_file, 'cpu' |
| | | ) |
| | | self.frontend = model.frontend |
| | | elif mode.startswith('offline'): |
| | | from funasr.tasks.vad import VADTask |
| | | config = os.path.join(model_dir, 'vad.yaml') |
| | | model_file = os.path.join(model_dir, 'vad.pb') |
| | | cmvn_file = os.path.join(model_dir, 'vad.mvn') |
| | | |
| | | model, asr_train_args = ASRTask.build_model_from_file( |
| | | asr_train_config, asr_model_file, cmvn_file, 'cpu' |
| | | ) |
| | | self.frontend = model.frontend |
| | | model, vad_infer_args = VADTask.build_model_from_file( |
| | | config, model_file, cmvn_file=cmvn_file, device='cpu' |
| | | ) |
| | | self.export_config["feats_dim"] = 400 |
| | | self.frontend = model.frontend |
| | | elif mode.startswith('punc'): |
| | | from funasr.tasks.punctuation import PunctuationTask as PUNCTask |
| | | punc_train_config = os.path.join(model_dir, 'config.yaml') |
| | | punc_model_file = os.path.join(model_dir, 'punc.pb') |
| | | model, punc_train_args = PUNCTask.build_model_from_file( |
| | | punc_train_config, punc_model_file, 'cpu' |
| | | ) |
| | | elif mode.startswith('punc_VadRealtime'): |
| | | from funasr.tasks.punctuation import PunctuationTask as PUNCTask |
| | | punc_train_config = os.path.join(model_dir, 'config.yaml') |
| | | punc_model_file = os.path.join(model_dir, 'punc.pb') |
| | | model, punc_train_args = PUNCTask.build_model_from_file( |
| | | punc_train_config, punc_model_file, 'cpu' |
| | | ) |
| | | self._export(model, tag_name) |
| | | |
| | | |
| | |
| | | from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer |
| | | from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export |
| | | from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export |
| | | from funasr.models.e2e_uni_asr import UniASR |
| | | |
| | | from funasr.models.e2e_vad import E2EVadModel |
| | | from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export |
| | | from funasr.models.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.export.models.target_delay_transformer import CT_Transformer as CT_Transformer_export |
| | | from funasr.train.abs_model import PunctuationModel |
| | | from funasr.models.vad_realtime_transformer import VadRealtimeTransformer |
| | | from funasr.export.models.target_delay_transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export |
| | | |
| | | def get_model(model, export_config=None): |
| | | if isinstance(model, BiCifParaformer): |
| | | return BiCifParaformer_export(model, **export_config) |
| | | elif isinstance(model, Paraformer): |
| | | return Paraformer_export(model, **export_config) |
| | | elif isinstance(model, E2EVadModel): |
| | | return E2EVadModel_export(model, **export_config) |
| | | elif isinstance(model, PunctuationModel): |
| | | if isinstance(model.punc_model, TargetDelayTransformer): |
| | | return CT_Transformer_export(model.punc_model, **export_config) |
| | | elif isinstance(model.punc_model, VadRealtimeTransformer): |
| | | return CT_Transformer_VadRealtime_export(model.punc_model, **export_config) |
| | | else: |
| | | raise "Funasr does not support the given model type currently." |
| | | raise "Funasr does not support the given model type currently." |
| New file |
| | |
| | | from enum import Enum |
| | | from typing import List, Tuple, Dict, Any |
| | | |
| | | import torch |
| | | from torch import nn |
| | | import math |
| | | |
| | | from funasr.models.encoder.fsmn_encoder import FSMN |
| | | from funasr.export.models.encoder.fsmn_encoder import FSMN as FSMN_export |
| | | |
| | | class E2EVadModel(nn.Module): |
| | | def __init__(self, model, |
| | | max_seq_len=512, |
| | | feats_dim=400, |
| | | model_name='model', |
| | | **kwargs,): |
| | | super(E2EVadModel, self).__init__() |
| | | self.feats_dim = feats_dim |
| | | self.max_seq_len = max_seq_len |
| | | self.model_name = model_name |
| | | if isinstance(model.encoder, FSMN): |
| | | self.encoder = FSMN_export(model.encoder) |
| | | else: |
| | | raise "unsupported encoder" |
| | | |
| | | |
| | | def forward(self, feats: torch.Tensor, *args, ): |
| | | |
| | | scores, out_caches = self.encoder(feats, *args) |
| | | return scores, out_caches |
| | | |
| | | def get_dummy_inputs(self, frame=30): |
| | | speech = torch.randn(1, frame, self.feats_dim) |
| | | in_cache0 = torch.randn(1, 128, 19, 1) |
| | | in_cache1 = torch.randn(1, 128, 19, 1) |
| | | in_cache2 = torch.randn(1, 128, 19, 1) |
| | | in_cache3 = torch.randn(1, 128, 19, 1) |
| | | |
| | | return (speech, in_cache0, in_cache1, in_cache2, in_cache3) |
| | | |
| | | # def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"): |
| | | # import numpy as np |
| | | # fbank = np.loadtxt(txt_file) |
| | | # fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32) |
| | | # speech = torch.from_numpy(fbank[None, :, :].astype(np.float32)) |
| | | # speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32)) |
| | | # return (speech, speech_lengths) |
| | | |
| | | def get_input_names(self): |
| | | return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3'] |
| | | |
| | | def get_output_names(self): |
| | | return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3'] |
| | | |
| | | def get_dynamic_axes(self): |
| | | return { |
| | | 'speech': { |
| | | 1: 'feats_length' |
| | | }, |
| | | } |
| New file |
| | |
| | | from typing import Tuple, Dict |
| | | import copy |
| | | |
| | | import numpy as np |
| | | import torch |
| | | import torch.nn as nn |
| | | import torch.nn.functional as F |
| | | from funasr.models.encoder.fsmn_encoder import BasicBlock |
| | | |
| | | class LinearTransform(nn.Module): |
| | | |
| | | def __init__(self, input_dim, output_dim): |
| | | super(LinearTransform, self).__init__() |
| | | self.input_dim = input_dim |
| | | self.output_dim = output_dim |
| | | self.linear = nn.Linear(input_dim, output_dim, bias=False) |
| | | |
| | | def forward(self, input): |
| | | output = self.linear(input) |
| | | |
| | | return output |
| | | |
| | | |
| | | class AffineTransform(nn.Module): |
| | | |
| | | def __init__(self, input_dim, output_dim): |
| | | super(AffineTransform, self).__init__() |
| | | self.input_dim = input_dim |
| | | self.output_dim = output_dim |
| | | self.linear = nn.Linear(input_dim, output_dim) |
| | | |
| | | def forward(self, input): |
| | | output = self.linear(input) |
| | | |
| | | return output |
| | | |
| | | |
| | | class RectifiedLinear(nn.Module): |
| | | |
| | | def __init__(self, input_dim, output_dim): |
| | | super(RectifiedLinear, self).__init__() |
| | | self.dim = input_dim |
| | | self.relu = nn.ReLU() |
| | | self.dropout = nn.Dropout(0.1) |
| | | |
| | | def forward(self, input): |
| | | out = self.relu(input) |
| | | return out |
| | | |
| | | |
| | | class FSMNBlock(nn.Module): |
| | | |
| | | def __init__( |
| | | self, |
| | | input_dim: int, |
| | | output_dim: int, |
| | | lorder=None, |
| | | rorder=None, |
| | | lstride=1, |
| | | rstride=1, |
| | | ): |
| | | super(FSMNBlock, self).__init__() |
| | | |
| | | self.dim = input_dim |
| | | |
| | | if lorder is None: |
| | | return |
| | | |
| | | self.lorder = lorder |
| | | self.rorder = rorder |
| | | self.lstride = lstride |
| | | self.rstride = rstride |
| | | |
| | | self.conv_left = nn.Conv2d( |
| | | self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False) |
| | | |
| | | if self.rorder > 0: |
| | | self.conv_right = nn.Conv2d( |
| | | self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False) |
| | | else: |
| | | self.conv_right = None |
| | | |
| | | def forward(self, input: torch.Tensor, cache: torch.Tensor): |
| | | x = torch.unsqueeze(input, 1) |
| | | x_per = x.permute(0, 3, 2, 1) # B D T C |
| | | |
| | | cache = cache.to(x_per.device) |
| | | y_left = torch.cat((cache, x_per), dim=2) |
| | | cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :] |
| | | y_left = self.conv_left(y_left) |
| | | out = x_per + y_left |
| | | |
| | | if self.conv_right is not None: |
| | | # maybe need to check |
| | | y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride]) |
| | | y_right = y_right[:, :, self.rstride:, :] |
| | | y_right = self.conv_right(y_right) |
| | | out += y_right |
| | | |
| | | out_per = out.permute(0, 3, 2, 1) |
| | | output = out_per.squeeze(1) |
| | | |
| | | return output, cache |
| | | |
| | | |
| | | class BasicBlock_export(nn.Module): |
| | | def __init__(self, |
| | | model, |
| | | ): |
| | | super(BasicBlock_export, self).__init__() |
| | | self.linear = model.linear |
| | | self.fsmn_block = model.fsmn_block |
| | | self.affine = model.affine |
| | | self.relu = model.relu |
| | | |
| | | def forward(self, input: torch.Tensor, in_cache: torch.Tensor): |
| | | x = self.linear(input) # B T D |
| | | # cache_layer_name = 'cache_layer_{}'.format(self.stack_layer) |
| | | # if cache_layer_name not in in_cache: |
| | | # in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1) |
| | | x, out_cache = self.fsmn_block(x, in_cache) |
| | | x = self.affine(x) |
| | | x = self.relu(x) |
| | | return x, out_cache |
| | | |
| | | |
| | | # class FsmnStack(nn.Sequential): |
| | | # def __init__(self, *args): |
| | | # super(FsmnStack, self).__init__(*args) |
| | | # |
| | | # def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]): |
| | | # x = input |
| | | # for module in self._modules.values(): |
| | | # x = module(x, in_cache) |
| | | # return x |
| | | |
| | | |
| | | ''' |
| | | FSMN net for keyword spotting |
| | | input_dim: input dimension |
| | | linear_dim: fsmn input dimensionll |
| | | proj_dim: fsmn projection dimension |
| | | lorder: fsmn left order |
| | | rorder: fsmn right order |
| | | num_syn: output dimension |
| | | fsmn_layers: no. of sequential fsmn layers |
| | | ''' |
| | | |
| | | |
| | | class FSMN(nn.Module): |
| | | def __init__( |
| | | self, model, |
| | | ): |
| | | super(FSMN, self).__init__() |
| | | |
| | | # self.input_dim = input_dim |
| | | # self.input_affine_dim = input_affine_dim |
| | | # self.fsmn_layers = fsmn_layers |
| | | # self.linear_dim = linear_dim |
| | | # self.proj_dim = proj_dim |
| | | # self.output_affine_dim = output_affine_dim |
| | | # self.output_dim = output_dim |
| | | # |
| | | # self.in_linear1 = AffineTransform(input_dim, input_affine_dim) |
| | | # self.in_linear2 = AffineTransform(input_affine_dim, linear_dim) |
| | | # self.relu = RectifiedLinear(linear_dim, linear_dim) |
| | | # self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in |
| | | # range(fsmn_layers)]) |
| | | # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim) |
| | | # self.out_linear2 = AffineTransform(output_affine_dim, output_dim) |
| | | # self.softmax = nn.Softmax(dim=-1) |
| | | self.in_linear1 = model.in_linear1 |
| | | self.in_linear2 = model.in_linear2 |
| | | self.relu = model.relu |
| | | # self.fsmn = model.fsmn |
| | | self.out_linear1 = model.out_linear1 |
| | | self.out_linear2 = model.out_linear2 |
| | | self.softmax = model.softmax |
| | | self.fsmn = model.fsmn |
| | | for i, d in enumerate(model.fsmn): |
| | | if isinstance(d, BasicBlock): |
| | | self.fsmn[i] = BasicBlock_export(d) |
| | | |
| | | def fuse_modules(self): |
| | | pass |
| | | |
| | | def forward( |
| | | self, |
| | | input: torch.Tensor, |
| | | *args, |
| | | ): |
| | | """ |
| | | Args: |
| | | input (torch.Tensor): Input tensor (B, T, D) |
| | | in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs, |
| | | {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame |
| | | """ |
| | | |
| | | x = self.in_linear1(input) |
| | | x = self.in_linear2(x) |
| | | x = self.relu(x) |
| | | # x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn |
| | | out_caches = list() |
| | | for i, d in enumerate(self.fsmn): |
| | | in_cache = args[i] |
| | | x, out_cache = d(x, in_cache) |
| | | out_caches.append(out_cache) |
| | | x = self.out_linear1(x) |
| | | x = self.out_linear2(x) |
| | | x = self.softmax(x) |
| | | |
| | | return x, out_caches |
| | | |
| | | |
| | | ''' |
| | | one deep fsmn layer |
| | | dimproj: projection dimension, input and output dimension of memory blocks |
| | | dimlinear: dimension of mapping layer |
| | | lorder: left order |
| | | rorder: right order |
| | | lstride: left stride |
| | | rstride: right stride |
| | | ''' |
| | | |
| | | |
| | | class DFSMN(nn.Module): |
| | | |
| | | def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1): |
| | | super(DFSMN, self).__init__() |
| | | |
| | | self.lorder = lorder |
| | | self.rorder = rorder |
| | | self.lstride = lstride |
| | | self.rstride = rstride |
| | | |
| | | self.expand = AffineTransform(dimproj, dimlinear) |
| | | self.shrink = LinearTransform(dimlinear, dimproj) |
| | | |
| | | self.conv_left = nn.Conv2d( |
| | | dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False) |
| | | |
| | | if rorder > 0: |
| | | self.conv_right = nn.Conv2d( |
| | | dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False) |
| | | else: |
| | | self.conv_right = None |
| | | |
| | | def forward(self, input): |
| | | f1 = F.relu(self.expand(input)) |
| | | p1 = self.shrink(f1) |
| | | |
| | | x = torch.unsqueeze(p1, 1) |
| | | x_per = x.permute(0, 3, 2, 1) |
| | | |
| | | y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0]) |
| | | |
| | | if self.conv_right is not None: |
| | | y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) |
| | | y_right = y_right[:, :, self.rstride:, :] |
| | | out = x_per + self.conv_left(y_left) + self.conv_right(y_right) |
| | | else: |
| | | out = x_per + self.conv_left(y_left) |
| | | |
| | | out1 = out.permute(0, 3, 2, 1) |
| | | output = input + out1.squeeze(1) |
| | | |
| | | return output |
| | | |
| | | |
| | | ''' |
| | | build stacked dfsmn layers |
| | | ''' |
| | | |
| | | |
| | | def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6): |
| | | repeats = [ |
| | | nn.Sequential( |
| | | DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1)) |
| | | for i in range(fsmn_layers) |
| | | ] |
| | | |
| | | return nn.Sequential(*repeats) |
| | | |
| | | |
| | | if __name__ == '__main__': |
| | | fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599) |
| | | print(fsmn) |
| | | |
| | | num_params = sum(p.numel() for p in fsmn.parameters()) |
| | | print('the number of model params: {}'.format(num_params)) |
| | | x = torch.zeros(128, 200, 400) # batch-size * time * dim |
| | | y, _ = fsmn(x) # batch-size * time * dim |
| | | print('input shape: {}'.format(x.shape)) |
| | | print('output shape: {}'.format(y.shape)) |
| | | |
| | | print(fsmn.to_kaldi_net()) |
| | |
| | | from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward |
| | | from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export |
| | | |
| | | |
| | | class SANMEncoder(nn.Module): |
| | | def __init__( |
| | | self, |
| | |
| | | } |
| | | |
| | | } |
| | | |
| | | |
| | | class SANMVadEncoder(nn.Module): |
| | | def __init__( |
| | | self, |
| | | model, |
| | | max_seq_len=512, |
| | | feats_dim=560, |
| | | model_name='encoder', |
| | | onnx: bool = True, |
| | | ): |
| | | super().__init__() |
| | | self.embed = model.embed |
| | | self.model = model |
| | | self.feats_dim = feats_dim |
| | | self._output_size = model._output_size |
| | | |
| | | if onnx: |
| | | self.make_pad_mask = MakePadMask(max_seq_len, flip=False) |
| | | else: |
| | | self.make_pad_mask = sequence_mask(max_seq_len, flip=False) |
| | | |
| | | if hasattr(model, 'encoders0'): |
| | | for i, d in enumerate(self.model.encoders0): |
| | | if isinstance(d.self_attn, MultiHeadedAttentionSANM): |
| | | d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn) |
| | | if isinstance(d.feed_forward, PositionwiseFeedForward): |
| | | d.feed_forward = PositionwiseFeedForward_export(d.feed_forward) |
| | | self.model.encoders0[i] = EncoderLayerSANM_export(d) |
| | | |
| | | for i, d in enumerate(self.model.encoders): |
| | | if isinstance(d.self_attn, MultiHeadedAttentionSANM): |
| | | d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn) |
| | | if isinstance(d.feed_forward, PositionwiseFeedForward): |
| | | d.feed_forward = PositionwiseFeedForward_export(d.feed_forward) |
| | | self.model.encoders[i] = EncoderLayerSANM_export(d) |
| | | |
| | | self.model_name = model_name |
| | | self.num_heads = model.encoders[0].self_attn.h |
| | | self.hidden_size = model.encoders[0].self_attn.linear_out.out_features |
| | | |
| | | def prepare_mask(self, mask, sub_masks): |
| | | mask_3d_btd = mask[:, :, None] |
| | | mask_4d_bhlt = (1 - sub_masks) * -10000.0 |
| | | |
| | | return mask_3d_btd, mask_4d_bhlt |
| | | |
| | | def forward(self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | vad_masks: torch.Tensor, |
| | | sub_masks: torch.Tensor, |
| | | ): |
| | | speech = speech * self._output_size ** 0.5 |
| | | mask = self.make_pad_mask(speech_lengths) |
| | | vad_masks = self.prepare_mask(mask, vad_masks) |
| | | mask = self.prepare_mask(mask, sub_masks) |
| | | |
| | | if self.embed is None: |
| | | xs_pad = speech |
| | | else: |
| | | xs_pad = self.embed(speech) |
| | | |
| | | encoder_outs = self.model.encoders0(xs_pad, mask) |
| | | xs_pad, masks = encoder_outs[0], encoder_outs[1] |
| | | |
| | | # encoder_outs = self.model.encoders(xs_pad, mask) |
| | | for layer_idx, encoder_layer in enumerate(self.model.encoders): |
| | | if layer_idx == len(self.model.encoders) - 1: |
| | | mask = vad_masks |
| | | encoder_outs = encoder_layer(xs_pad, mask) |
| | | xs_pad, masks = encoder_outs[0], encoder_outs[1] |
| | | |
| | | xs_pad = self.model.after_norm(xs_pad) |
| | | |
| | | return xs_pad, speech_lengths |
| | | |
| | | def get_output_size(self): |
| | | return self.model.encoders[0].size |
| | | |
| | | # def get_dummy_inputs(self): |
| | | # feats = torch.randn(1, 100, self.feats_dim) |
| | | # return (feats) |
| | | # |
| | | # def get_input_names(self): |
| | | # return ['feats'] |
| | | # |
| | | # def get_output_names(self): |
| | | # return ['encoder_out', 'encoder_out_lens', 'predictor_weight'] |
| | | # |
| | | # def get_dynamic_axes(self): |
| | | # return { |
| | | # 'feats': { |
| | | # 1: 'feats_length' |
| | | # }, |
| | | # 'encoder_out': { |
| | | # 1: 'enc_out_length' |
| | | # }, |
| | | # 'predictor_weight': { |
| | | # 1: 'pre_out_length' |
| | | # } |
| | | # |
| | | # } |
| New file |
| | |
| | | from typing import Tuple |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | from funasr.models.encoder.sanm_encoder import SANMEncoder |
| | | from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export |
| | | from funasr.models.encoder.sanm_encoder import SANMVadEncoder |
| | | from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export |
| | | |
| | | class CT_Transformer(nn.Module): |
| | | |
| | | def __init__( |
| | | self, |
| | | model, |
| | | max_seq_len=512, |
| | | model_name='punc_model', |
| | | **kwargs, |
| | | ): |
| | | super().__init__() |
| | | onnx = False |
| | | if "onnx" in kwargs: |
| | | onnx = kwargs["onnx"] |
| | | self.embed = model.embed |
| | | self.decoder = model.decoder |
| | | # self.model = model |
| | | self.feats_dim = self.embed.embedding_dim |
| | | self.num_embeddings = self.embed.num_embeddings |
| | | self.model_name = model_name |
| | | |
| | | if isinstance(model.encoder, SANMEncoder): |
| | | self.encoder = SANMEncoder_export(model.encoder, onnx=onnx) |
| | | else: |
| | | assert False, "Only support samn encode." |
| | | |
| | | def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]: |
| | | """Compute loss value from buffer sequences. |
| | | |
| | | Args: |
| | | input (torch.Tensor): Input ids. (batch, len) |
| | | hidden (torch.Tensor): Target ids. (batch, len) |
| | | |
| | | """ |
| | | x = self.embed(inputs) |
| | | # mask = self._target_mask(input) |
| | | h, _ = self.encoder(x, text_lengths) |
| | | y = self.decoder(h) |
| | | return y |
| | | |
| | | def get_dummy_inputs(self): |
| | | length = 120 |
| | | text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)) |
| | | text_lengths = torch.tensor([length-20, length], dtype=torch.int32) |
| | | return (text_indexes, text_lengths) |
| | | |
| | | def get_input_names(self): |
| | | return ['inputs', 'text_lengths'] |
| | | |
| | | def get_output_names(self): |
| | | return ['logits'] |
| | | |
| | | def get_dynamic_axes(self): |
| | | return { |
| | | 'inputs': { |
| | | 0: 'batch_size', |
| | | 1: 'feats_length' |
| | | }, |
| | | 'text_lengths': { |
| | | 0: 'batch_size', |
| | | }, |
| | | 'logits': { |
| | | 0: 'batch_size', |
| | | 1: 'logits_length' |
| | | }, |
| | | } |
| | | |
| | | |
| | | class CT_Transformer_VadRealtime(nn.Module): |
| | | |
| | | def __init__( |
| | | self, |
| | | model, |
| | | max_seq_len=512, |
| | | model_name='punc_model', |
| | | **kwargs, |
| | | ): |
| | | super().__init__() |
| | | onnx = False |
| | | if "onnx" in kwargs: |
| | | onnx = kwargs["onnx"] |
| | | |
| | | self.embed = model.embed |
| | | if isinstance(model.encoder, SANMVadEncoder): |
| | | self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx) |
| | | else: |
| | | assert False, "Only support samn encode." |
| | | self.decoder = model.decoder |
| | | self.model_name = model_name |
| | | |
| | | |
| | | |
| | | def forward(self, inputs: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | vad_indexes: torch.Tensor, |
| | | sub_masks: torch.Tensor, |
| | | ) -> Tuple[torch.Tensor, None]: |
| | | """Compute loss value from buffer sequences. |
| | | |
| | | Args: |
| | | input (torch.Tensor): Input ids. (batch, len) |
| | | hidden (torch.Tensor): Target ids. (batch, len) |
| | | |
| | | """ |
| | | x = self.embed(inputs) |
| | | # mask = self._target_mask(input) |
| | | h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks) |
| | | y = self.decoder(h) |
| | | return y |
| | | |
| | | def with_vad(self): |
| | | return True |
| | | |
| | | def get_dummy_inputs(self): |
| | | length = 120 |
| | | text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)) |
| | | text_lengths = torch.tensor([length], dtype=torch.int32) |
| | | vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :] |
| | | sub_masks = torch.ones(length, length, dtype=torch.float32) |
| | | sub_masks = torch.tril(sub_masks).type(torch.float32) |
| | | return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :]) |
| | | |
| | | def get_input_names(self): |
| | | return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks'] |
| | | |
| | | def get_output_names(self): |
| | | return ['logits'] |
| | | |
| | | def get_dynamic_axes(self): |
| | | return { |
| | | 'inputs': { |
| | | 1: 'feats_length' |
| | | }, |
| | | 'vad_masks': { |
| | | 2: 'feats_length1', |
| | | 3: 'feats_length2' |
| | | }, |
| | | 'sub_masks': { |
| | | 2: 'feats_length1', |
| | | 3: 'feats_length2' |
| | | }, |
| | | 'logits': { |
| | | 1: 'logits_length' |
| | | }, |
| | | } |
| New file |
| | |
| | | import onnxruntime |
| | | import numpy as np |
| | | |
| | | |
| | | if __name__ == '__main__': |
| | | onnx_path = "../damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.onnx" |
| | | sess = onnxruntime.InferenceSession(onnx_path) |
| | | input_name = [nd.name for nd in sess.get_inputs()] |
| | | output_name = [nd.name for nd in sess.get_outputs()] |
| | | |
| | | def _get_feed_dict(text_length): |
| | | return {'inputs': np.ones((1, text_length), dtype=np.int64), 'text_lengths': np.array([text_length,], dtype=np.int32)} |
| | | |
| | | def _run(feed_dict): |
| | | output = sess.run(output_name, input_feed=feed_dict) |
| | | for name, value in zip(output_name, output): |
| | | print('{}: {}'.format(name, value)) |
| | | _run(_get_feed_dict(10)) |
| New file |
| | |
| | | import onnxruntime |
| | | import numpy as np |
| | | |
| | | |
| | | if __name__ == '__main__': |
| | | onnx_path = "./export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/model.onnx" |
| | | sess = onnxruntime.InferenceSession(onnx_path) |
| | | input_name = [nd.name for nd in sess.get_inputs()] |
| | | output_name = [nd.name for nd in sess.get_outputs()] |
| | | |
| | | def _get_feed_dict(text_length): |
| | | return {'inputs': np.ones((1, text_length), dtype=np.int64), |
| | | 'text_lengths': np.array([text_length,], dtype=np.int32), |
| | | 'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32), |
| | | 'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32) |
| | | } |
| | | |
| | | def _run(feed_dict): |
| | | output = sess.run(output_name, input_feed=feed_dict) |
| | | for name, value in zip(output_name, output): |
| | | print('{}: {}'.format(name, value)) |
| | | _run(_get_feed_dict(10)) |
| New file |
| | |
| | | import onnxruntime |
| | | import numpy as np |
| | | |
| | | |
| | | if __name__ == '__main__': |
| | | onnx_path = "/mnt/workspace/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx" |
| | | sess = onnxruntime.InferenceSession(onnx_path) |
| | | input_name = [nd.name for nd in sess.get_inputs()] |
| | | output_name = [nd.name for nd in sess.get_outputs()] |
| | | |
| | | def _get_feed_dict(feats_length): |
| | | |
| | | return {'speech': np.random.rand(1, feats_length, 400).astype(np.float32), |
| | | 'in_cache0': np.random.rand(1, 128, 19, 1).astype(np.float32), |
| | | 'in_cache1': np.random.rand(1, 128, 19, 1).astype(np.float32), |
| | | 'in_cache2': np.random.rand(1, 128, 19, 1).astype(np.float32), |
| | | 'in_cache3': np.random.rand(1, 128, 19, 1).astype(np.float32), |
| | | } |
| | | |
| | | def _run(feed_dict): |
| | | output = sess.run(output_name, input_feed=feed_dict) |
| | | for name, value in zip(output_name, output): |
| | | print('{}: {}'.format(name, value.shape)) |
| | | |
| | | _run(_get_feed_dict(100)) |
| | | _run(_get_feed_dict(200)) |
| | |
| | | import torch |
| | | |
| | | from funasr.modules.scorers.scorer_interface import BatchScorerInterface |
| | | from typing import Dict |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | |
| | | class AbsLM(torch.nn.Module, BatchScorerInterface, ABC): |
| | | """The abstract LM class |
| | |
| | | self, input: torch.Tensor, hidden: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | raise NotImplementedError |
| | | |
| | | |
class LanguageModel(AbsESPnetModel):
    """Wrap an ``AbsLM`` and score text with a next-token cross-entropy loss.

    Each sequence ``w1 .. wn`` is fed as ``<sos> w1 .. wn`` and scored against
    ``w1 .. wn <eos>``, with ``<sos>`` fixed to id 1 and ``<eos>`` to id 2.
    """

    def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
        """Create the wrapper.

        Args:
            lm: language model, called as ``lm(x, None)``; the first element of
                its result is used as (Batch, Length, NVocab) scores.
            vocab_size: accepted for interface compatibility; not used here.
            ignore_id: value used for the extra target column appended after
                each sequence (before ``<eos>`` is written into it).
        """
        assert check_argument_types()
        super().__init__()
        self.lm = lm
        self.sos = 1
        self.eos = 2

        # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
        self.ignore_id = ignore_id

    def nll(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
            max_length: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the per-token negative log likelihood (nll).

        Normally, this function is called in batchify_nll.

        Args:
            text: (Batch, Length) token ids.
            text_lengths: (Batch,) number of real tokens per row.
            max_length: if given, pad/truncate every row to this length so
                results of different chunks can be concatenated; otherwise
                ``text_lengths.max()`` is used.

        Returns:
            ``(nll, x_lengths)``: nll of shape (Batch, Length + 1) with padded
            positions set to 0, and ``text_lengths + 1`` (tokens incl. <eos>).
        """
        batch_size = text.size(0)
        # For data parallel
        if max_length is None:
            text = text[:, : text_lengths.max()]
        else:
            text = text[:, :max_length]

        # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
        # text: (Batch, Length) -> x, t: (Batch, Length + 1)
        x = F.pad(text, [1, 0], "constant", self.sos)
        t = F.pad(text, [0, 1], "constant", self.ignore_id)
        # Write <eos> right after each row's last real token in a single
        # indexed assignment rather than a per-sample Python loop.
        rows = torch.arange(batch_size, device=t.device)
        t[rows, text_lengths.to(device=t.device, dtype=torch.long)] = self.eos
        x_lengths = text_lengths + 1

        # 2. Forward Language model
        # x: (Batch, Length) -> y: (Batch, Length, NVocab)
        y, _ = self.lm(x, None)

        # 3. Calc negative log likelihood
        # nll: (BxL,)
        nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
        # Zero the loss at positions beyond each row's length.
        if max_length is None:
            nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
        else:
            nll.masked_fill_(
                make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
                0.0,
            )
        # nll: (BxL,) -> (B, L)
        nll = nll.view(batch_size, -1)
        return nll, x_lengths

    def batchify_nll(
            self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the nll in chunks of at most ``batch_size`` rows.

        To avoid OOM, this function separates the input into batches, calls
        ``nll`` for each batch and concatenates the results.

        Args:
            text: (Batch, Length) token ids.
            text_lengths: (Batch,) number of real tokens per row.
            batch_size: number of rows per ``nll`` call; lower it to reduce
                memory use or raise it for throughput.

        Returns:
            ``(nll, x_lengths)`` as returned by ``nll`` for the whole input.

        Raises:
            ValueError: if ``batch_size`` is 0 and chunking is required.
        """
        total_num = text.size(0)
        if total_num <= batch_size:
            nll, x_lengths = self.nll(text, text_lengths)
        else:
            nlls = []
            x_lengths = []
            # Every chunk is padded to the global max length so the per-chunk
            # (B, L + 1) results can be concatenated along the batch axis.
            max_length = text_lengths.max()
            for start_idx in range(0, total_num, batch_size):
                end_idx = min(start_idx + batch_size, total_num)
                # batch_nll: (end_idx - start_idx, max_length + 1)
                batch_nll, batch_x_lengths = self.nll(
                    text[start_idx:end_idx, :],
                    text_lengths[start_idx:end_idx],
                    max_length=max_length,
                )
                nlls.append(batch_nll)
                x_lengths.append(batch_x_lengths)
            nll = torch.cat(nlls)
            x_lengths = torch.cat(x_lengths)
        assert nll.size(0) == total_num
        assert x_lengths.size(0) == total_num
        return nll, x_lengths

    def forward(
            self, text: torch.Tensor, text_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Return ``(loss, stats, weight)``.

        ``loss`` is the summed nll divided by the number of scored tokens
        (``text_lengths + 1`` per row, i.e. including <eos>); ``weight`` is
        that token count.
        """
        nll, y_lengths = self.nll(text, text_lengths)
        ntokens = y_lengths.sum()
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight

    def collect_feats(
            self, text: torch.Tensor, text_lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """No features to collect for a text-only LM; returns an empty dict."""
        return {}
| | |
| | | |
| | | |
| | | class E2EVadModel(nn.Module): |
| | | def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]): |
| | | def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None): |
| | | super(E2EVadModel, self).__init__() |
| | | self.vad_opts = VADXOptions(**vad_post_args) |
| | | self.windows_detector = WindowDetector(self.vad_opts.window_size_ms, |
| | |
| | | self.data_buf_all = None |
| | | self.waveform = None |
| | | self.ResetDetection() |
| | | self.frontend = frontend |
| | | |
| | | def AllResetDetection(self): |
| | | self.is_final = False |
| | |
| | | ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]: |
| | | self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres |
| | | self.waveform = waveform # compute decibel for each frame |
| | | self.ComputeDecibel() |
| | | |
| | | self.ComputeScores(feats, in_cache) |
| | | self.ComputeDecibel() |
| | | if not is_final: |
| | | self.DetectCommonFrames() |
| | | else: |
| | |
| | | from typeguard import check_argument_types |
| | | import numpy as np |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM |
| | | from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask |
| | | from funasr.modules.embedding import SinusoidalPositionEncoder |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | from funasr.modules.multi_layer_conv import Conv1dLinear |
| | |
| | | from funasr.modules.subsampling import check_short_utt |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | |
| | | from funasr.modules.mask import subsequent_mask, vad_mask |
| | | |
| | | class EncoderLayerSANM(nn.Module): |
| | | def __init__( |
| | |
| | | var_dict_tf[name_tf].shape)) |
| | | |
| | | return var_dict_torch_update |
| | | |
| | | |
class SANMVadEncoder(AbsEncoder):
    """SAN-M encoder whose last layer attends under a VAD-dependent mask.

    author: Speech Lab, Alibaba Group, China

    Layer layout:
      * ``encoders0``: one ``EncoderLayerSANM`` mapping ``input_size`` to
        ``output_size``;
      * ``encoders``: ``num_blocks - 1`` further ``EncoderLayerSANM`` layers.

    Every layer attends with a causal (no-future) mask except the last layer
    of ``encoders``, whose mask is built per utterance with ``vad_mask`` from
    ``vad_indexes``.
    """

    def __init__(
            self,
            input_size: int,
            output_size: int = 256,
            attention_heads: int = 4,
            linear_units: int = 2048,
            num_blocks: int = 6,
            dropout_rate: float = 0.1,
            positional_dropout_rate: float = 0.1,
            attention_dropout_rate: float = 0.0,
            input_layer: Optional[str] = "conv2d",
            pos_enc_class=SinusoidalPositionEncoder,
            normalize_before: bool = True,
            concat_after: bool = False,
            positionwise_layer_type: str = "linear",
            positionwise_conv_kernel_size: int = 1,
            padding_idx: int = -1,
            interctc_layer_idx: Optional[List[int]] = None,
            interctc_use_conditioning: bool = False,
            kernel_size: int = 11,
            sanm_shfit: int = 0,
            selfattention_layer_type: str = "sanm",
    ):
        """Build the input embedding and the SAN-M layer stacks.

        Args:
            input_size: input feature dimension (vocabulary size for "embed").
            output_size: model dimension of every layer output.
            num_blocks: total number of encoder layers (1 in ``encoders0``
                plus ``num_blocks - 1`` in ``encoders``).
            input_layer: "linear", "conv2d", "conv2d2", "conv2d6", "conv2d8",
                "embed", "pe" or None.
            positionwise_layer_type: "linear", "conv1d" or "conv1d-linear".
            interctc_layer_idx: layer indices for intermediate CTC; None means
                no intermediate CTC.
            kernel_size, sanm_shfit: forwarded to the SAN-M attention layers.
            selfattention_layer_type: only "sanm" is supported.

        Raises:
            ValueError: unknown ``input_layer`` or unsupported
                ``selfattention_layer_type``.
            NotImplementedError: unknown ``positionwise_layer_type``.
        """
        assert check_argument_types()
        super().__init__()
        # A fresh list per instance instead of a shared mutable default.
        if interctc_layer_idx is None:
            interctc_layer_idx = []
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        if selfattention_layer_type != "sanm":
            # The layers below are built from self.encoder_selfattn_layer and
            # encoder_selfattn_layer_args0, which only the "sanm" configuration
            # defines; any other type (including "selfattn") cannot be built.
            raise ValueError(
                f"unsupported selfattention_layer_type: {selfattention_layer_type}"
            )

        self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
        # The first layer maps input_size -> output_size.
        encoder_selfattn_layer_args0 = (
            attention_heads,
            input_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )

        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)

    def output_size(self) -> int:
        """Return the model dimension of the encoder output."""
        return self._output_size

    def forward(
            self,
            xs_pad: torch.Tensor,
            ilens: torch.Tensor,
            vad_indexes: torch.Tensor,
            prev_states: torch.Tensor = None,
            ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode a padded batch.

        Args:
            xs_pad: input tensor (B, L, D); not modified.
            ilens: input lengths (B).
            vad_indexes: per-utterance position passed to ``vad_mask`` to
                build the last layer's attention mask.
            prev_states: not used now.
            ctc: not used now.

        Returns:
            ``(xs_pad, olens, None)``: encoded tensor, output lengths, and a
            placeholder for states.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
        # (B, 1, T) & (1, T, T) -> (B, T, T): padding mask combined with causal mask.
        no_future_masks = masks & sub_masks
        # Scale out of place so the caller's tensor is left untouched.
        xs_pad = xs_pad * self.output_size() ** 0.5
        if self.embed is None:
            pass
        elif isinstance(
                self.embed,
                (Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8),
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
                    f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            # NOTE(review): no_future_masks was built for the pre-subsampling
            # length; confirm the conv2d* front-ends are used with this encoder.
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        # xs_pad = self.dropout(xs_pad)
        mask_tup0 = [masks, no_future_masks]
        encoder_outs = self.encoders0(xs_pad, mask_tup0)
        xs_pad, _ = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []

        for layer_idx, encoder_layer in enumerate(self.encoders):
            if layer_idx + 1 == len(self.encoders):
                # Last layer: replace the causal mask with a per-utterance VAD mask.
                corner_mask = torch.ones(masks.size(0),
                                         masks.size(-1),
                                         masks.size(-1),
                                         device=xs_pad.device,
                                         dtype=torch.bool)
                for batch_idx in range(len(ilens)):
                    corner_mask[batch_idx, :, :] = vad_mask(masks.size(-1),
                                                            vad_indexes[batch_idx],
                                                            device=xs_pad.device)
                layer_mask = masks & corner_mask
            else:
                layer_mask = no_future_masks
            mask_tup1 = [masks, layer_mask]
            encoder_outs = encoder_layer(xs_pad, mask_tup1)
            xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
| File was renamed from funasr/punctuation/target_delay_transformer.py |
| | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | from funasr.modules.embedding import PositionalEncoding |
| | | from funasr.modules.embedding import SinusoidalPositionEncoder |
| | | #from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder |
| | | from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder |
| | | from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder |
| | | #from funasr.modules.mask import subsequent_n_mask |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.train.abs_model import AbsPunctuation |
| | | |
| | | |
| | | class TargetDelayTransformer(AbsPunctuation): |
| File was renamed from funasr/punctuation/vad_realtime_transformer.py |
| | |
| | | import torch.nn as nn |
| | | |
| | | from funasr.modules.embedding import SinusoidalPositionEncoder |
| | | from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.models.encoder.sanm_encoder import SANMVadEncoder as Encoder |
| | | from funasr.train.abs_model import AbsPunctuation |
| | | |
| | | |
| | | class VadRealtimeTransformer(AbsPunctuation): |
| | |
| | | |
| | | |
| | | @functools.lru_cache() |
| | | def get_logger(name='torch_paraformer'): |
| | | def get_logger(name='funasr_torch'): |
| | | """Initialize and get a logger by name. |
| | | If the logger has not been initialized, this method will initialize the |
| | | logger by adding one or two handlers, otherwise the initialized logger will |
| | |
| | | from funasr_onnx import Paraformer |
| | | |
| | | |
| | | model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" |
| | | |
| | | model = Paraformer(model_dir, batch_size=2, plot_timestamp_to="./", pred_bias=0) # cpu |
| | |
| | | |
| | | # when using paraformer-large-vad-punc model, you can set plot_timestamp_to="./xx.png" to get figure of alignment besides timestamps |
| | | # model = Paraformer(model_dir, batch_size=1, plot_timestamp_to="test.png") |
| | | |
| | | |
| | | wav_path = "YourPath/xx.wav" |
| | | |
| New file |
| | |
# Offline punctuation demo: restore punctuation for an unpunctuated Chinese paragraph.
from funasr_onnx import CT_Transformer

# Exported model directory, resolved relative to the current working directory.
# CT_Transformer loads model.onnx and punc.yaml from it.
model_dir = "../../../export/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
model = CT_Transformer(model_dir)

text_in="跨境河流是养育沿岸人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切愿意进一步完善双方联合工作机制凡是中方能做的我们都会去做而且会做得更好我请印度朋友们放心中国在上游的任何开发利用都会经过科学规划和论证兼顾上下游的利益"
# The model returns (punctuated_text, punctuation_ids); only the text is printed.
result = model(text_in)
print(result[0])
| New file |
| | |
# Streaming punctuation demo: feed VAD segments one by one, carrying context in param_dict.
from funasr_onnx import CT_Transformer_VadRealtime

model_dir = "../../../export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
model = CT_Transformer_VadRealtime(model_dir)

# "|" marks simulated VAD segment boundaries inside the paragraph.
text_in = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"

vads = text_in.split("|")
rec_result_all = ""
# The model stores the words it has not finalised yet in param_dict["cache"]
# and prepends them to the next segment.
param_dict = {"cache": []}
for vad in vads:
    result = model(vad, param_dict=param_dict)
    rec_result_all += result[0]

print(rec_result_all)
| New file |
| | |
# FSMN-VAD demo: offline (whole file) and online (chunked) voice activity detection.
import soundfile
from funasr_onnx import Fsmn_vad


# Local paths of the exported model and an example wav; replace with your own.
model_dir = "/Users/zhifu/Downloads/speech_fsmn_vad_zh-cn-16k-common-pytorch"
wav_path = "/Users/zhifu/Downloads/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav"
model = Fsmn_vad(model_dir)

#offline vad
# result = model(wav_path)
# print(result)

#online vad
speech, sample_rate = soundfile.read(wav_path)
speech_length = speech.shape[0]

step = 160 * 10  # samples per chunk (100 ms if the audio is 16 kHz)
param_dict = {'in_cache': []}
sample_offset = 0
while sample_offset < speech_length:
    # The final chunk absorbs any short tail, so exactly one call is made with
    # is_final=True and no chunk is sent after it. An empty file sends nothing.
    if sample_offset + step >= speech_length - 1:
        chunk_end = speech_length
        is_final = True
    else:
        chunk_end = sample_offset + step
        is_final = False
    param_dict['is_final'] = is_final
    segments_result = model(audio_in=speech[sample_offset: chunk_end],
                            param_dict=param_dict)
    print(segments_result)
    sample_offset = chunk_end
| | | |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | from .paraformer_bin import Paraformer |
| | | from .vad_bin import Fsmn_vad |
| | | from .punc_bin import CT_Transformer |
| | | from .punc_bin import CT_Transformer_VadRealtime |
| New file |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | |
| | | import os.path |
| | | from pathlib import Path |
| | | from typing import List, Union, Tuple |
| | | import numpy as np |
| | | |
| | | from .utils.utils import (ONNXRuntimeError, |
| | | OrtInferSession, get_logger, |
| | | read_yaml) |
| | | from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words) |
| | | logging = get_logger() |
| | | |
| | | |
class CT_Transformer():
    """Offline punctuation restoration with an exported CT-Transformer ONNX model.

    ``model_dir`` must contain ``model.onnx`` (``model_quant.onnx`` when
    ``quantize`` is set) and ``punc.yaml`` with ``token_list`` and ``punc_list``.
    """

    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4
                 ):
        """Load the ONNX session, token converter and punctuation list.

        Args:
            model_dir: directory with the exported model and ``punc.yaml``.
            batch_size: accepted for API compatibility; inference runs one
                sequence at a time (``self.batch_size`` is always 1).
            device_id: forwarded to ``OrtInferSession``.
            quantize: load ``model_quant.onnx`` instead of ``model.onnx``.
            intra_op_num_threads: forwarded to ``OrtInferSession``.

        Raises:
            FileNotFoundError: if ``model_dir`` does not exist.
        """
        if not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')

        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        config_file = os.path.join(model_dir, 'punc.yaml')
        config = read_yaml(config_file)

        self.converter = TokenIDConverter(config['token_list'])
        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = 1
        self.punc_list = config['punc_list']
        # Index of the sentence-final period in punc_list (0 if absent).
        self.period = 0
        for i, punc in enumerate(self.punc_list):
            # NOTE(review): as written these replacements map each character to
            # itself; upstream presumably normalises full-width/ASCII
            # punctuation here - confirm the intended characters.
            if punc == ",":
                self.punc_list[i] = ","
            elif punc == "?":
                self.punc_list[i] = "?"
            elif punc == "。":
                self.period = i

    def __call__(self, text: Union[list, str], split_size=20):
        """Restore punctuation for ``text``.

        The text is split into mini-sentences of ``split_size`` words. Words
        after the last sentence end of one mini-sentence are carried into the
        next so sentences are not cut in the middle.

        Args:
            text: input passed to ``code_mix_split_words``.
            split_size: words per mini-sentence fed to the model.

        Returns:
            ``(punctuated_text, punc_ids)`` where ``punc_ids`` holds one index
            into ``self.punc_list`` per word. Empty input yields ``("", [])``.

        Raises:
            ONNXRuntimeError: if model inference fails.
        """
        split_text = code_mix_split_words(text)
        if not split_text:
            # Nothing to punctuate; indexing the empty result below would fail.
            return "", []
        split_text_id = self.converter.tokens2ids(split_text)
        mini_sentences = split_to_mini_sentence(split_text, split_size)
        mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
        assert len(mini_sentences) == len(mini_sentences_id)
        cache_sent = []
        cache_sent_id = []
        new_mini_sentence = ""
        new_mini_sentence_punc = []
        cache_pop_trigger_limit = 200
        last_index = len(mini_sentences) - 1
        for mini_sentence_i, (mini_sentence, mini_sentence_id) in enumerate(
                zip(mini_sentences, mini_sentences_id)):
            mini_sentence = cache_sent + mini_sentence
            mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64')
            data = {
                "text": mini_sentence_id[None, :],
                "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
            }
            try:
                outputs = self.infer(data['text'], data['text_lengths'])
            except ONNXRuntimeError:
                # Without model output there are no labels for this chunk;
                # continuing would use undefined or stale punctuation.
                logging.exception("punctuation model inference failed")
                raise
            punctuations = np.argmax(outputs[0], axis=-1)[0]
            assert punctuations.size == len(mini_sentence)

            # Search for the last Period/QuestionMark as cache
            if mini_sentence_i < last_index:
                sentenceEnd = -1
                last_comma_index = -1
                for i in range(len(punctuations) - 2, 1, -1):
                    if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?":
                        sentenceEnd = i
                        break
                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
                        last_comma_index = i

                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                    # The sentence is too long, cut off at a comma.
                    sentenceEnd = last_comma_index
                    punctuations[sentenceEnd] = self.period
                cache_sent = mini_sentence[sentenceEnd + 1:]
                cache_sent_id = mini_sentence_id[sentenceEnd + 1:].tolist()
                mini_sentence = mini_sentence[0:sentenceEnd + 1]
                punctuations = punctuations[0:sentenceEnd + 1]

            new_mini_sentence_punc += [int(x) for x in punctuations]
            words_with_punc = []
            for i in range(len(mini_sentence)):
                # Separate two consecutive single-byte (e.g. ASCII) words by a space.
                if i > 0:
                    if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
                        mini_sentence[i] = " " + mini_sentence[i]
                words_with_punc.append(mini_sentence[i])
                if self.punc_list[punctuations[i]] != "_":
                    words_with_punc.append(self.punc_list[punctuations[i]])
            new_mini_sentence += "".join(words_with_punc)

        # Add Period for the end of the sentence
        new_mini_sentence_out = new_mini_sentence
        new_mini_sentence_punc_out = new_mini_sentence_punc
        if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
            new_mini_sentence_out = new_mini_sentence[:-1] + "。"
            new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
        elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?":
            new_mini_sentence_out = new_mini_sentence + "。"
            new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
        return new_mini_sentence_out, new_mini_sentence_punc_out

    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Run the ONNX session on token ids and lengths; return its raw outputs."""
        outputs = self.ort_infer([feats, feats_len])
        return outputs
| | | |
| | | |
class CT_Transformer_VadRealtime(CT_Transformer):
    """Streaming punctuation restoration for VAD-segmented text.

    Words after the last sentence-final mark of one call are returned as a
    cache, stored in ``param_dict["cache"]``, and prepended to the next call's
    text so sentences spanning VAD segments are punctuated together.
    """

    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4
                 ):
        """Load the model exactly like ``CT_Transformer``."""
        super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads)

    def __call__(self, text: str, param_dict: dict, split_size=20):
        """Punctuate one VAD segment.

        Args:
            text: text of the current segment.
            param_dict: must contain ``"cache"`` (words carried over from the
                previous call, or empty/None); updated in place with the new cache.
            split_size: words per mini-sentence fed to the model.

        Returns:
            ``(sentence_out, sentence_punc_list_out, cache_out)``: punctuated
            text for this segment (a trailing punctuation mark is dropped), the
            punctuation symbols emitted for it, and the carried-over words.

        Raises:
            ONNXRuntimeError: if model inference fails.
        """
        cache_key = "cache"
        assert cache_key in param_dict
        cache = param_dict[cache_key]
        if cache is not None and len(cache) > 0:
            precache = "".join(cache)
        else:
            precache = ""
            cache = []
        full_text = precache + text
        split_text = code_mix_split_words(full_text)
        if not split_text:
            # Empty segment and empty cache: nothing to punctuate.
            param_dict[cache_key] = []
            return "", [], []
        split_text_id = self.converter.tokens2ids(split_text)
        mini_sentences = split_to_mini_sentence(split_text, split_size)
        mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
        assert len(mini_sentences) == len(mini_sentences_id)

        cache_sent = []
        cache_sent_id = np.array([], dtype='int32')
        sentence_punc_list = []
        sentence_words_list = []
        cache_pop_trigger_limit = 200
        skip_num = 0
        last_index = len(mini_sentences) - 1
        for mini_sentence_i, (mini_sentence, mini_sentence_id) in enumerate(
                zip(mini_sentences, mini_sentences_id)):
            mini_sentence = cache_sent + mini_sentence
            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
            text_length = len(mini_sentence_id)
            data = {
                "input": mini_sentence_id[None, :],
                "text_lengths": np.array([text_length], dtype='int32'),
                "vad_mask": self.vad_mask(text_length, len(cache) - 1)[None, None, :, :].astype(np.float32),
                "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
            }
            try:
                outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
            except ONNXRuntimeError:
                # Without model output there are no labels for this chunk;
                # continuing would use undefined or stale punctuation.
                logging.exception("punctuation model inference failed")
                raise
            punctuations = np.argmax(outputs[0], axis=-1)[0]
            assert punctuations.size == len(mini_sentence)

            # Search for the last Period/QuestionMark as cache
            if mini_sentence_i < last_index:
                sentenceEnd = -1
                last_comma_index = -1
                for i in range(len(punctuations) - 2, 1, -1):
                    if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?":
                        sentenceEnd = i
                        break
                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
                        last_comma_index = i

                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                    # The sentence is too long, cut off at a comma.
                    sentenceEnd = last_comma_index
                    punctuations[sentenceEnd] = self.period
                cache_sent = mini_sentence[sentenceEnd + 1:]
                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
                mini_sentence = mini_sentence[0:sentenceEnd + 1]
                punctuations = punctuations[0:sentenceEnd + 1]

            sentence_punc_list += [self.punc_list[int(x)] for x in punctuations]
            sentence_words_list += mini_sentence

        assert len(sentence_punc_list) == len(sentence_words_list)
        words_with_punc = []
        sentence_punc_list_out = []
        for i in range(0, len(sentence_words_list)):
            # Separate two consecutive single-byte (e.g. ASCII) words by a space.
            if i > 0:
                if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
                    sentence_words_list[i] = " " + sentence_words_list[i]
            # The first len(cache) words were already emitted by the previous
            # call; only the punctuation after the last of them is emitted now.
            if skip_num < len(cache):
                skip_num += 1
            else:
                words_with_punc.append(sentence_words_list[i])
            if skip_num >= len(cache):
                sentence_punc_list_out.append(sentence_punc_list[i])
                if sentence_punc_list[i] != "_":
                    words_with_punc.append(sentence_punc_list[i])
        sentence_out = "".join(words_with_punc)

        # Words after the last sentence end become the cache for the next call.
        sentenceEnd = -1
        for i in range(len(sentence_punc_list) - 2, 1, -1):
            if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "?":
                sentenceEnd = i
                break
        cache_out = sentence_words_list[sentenceEnd + 1:]
        # Drop a trailing mark; the next segment decides the punctuation here.
        if sentence_out and sentence_out[-1] in self.punc_list:
            sentence_out = sentence_out[:-1]
            sentence_punc_list_out[-1] = "_"
        param_dict[cache_key] = cache_out
        return sentence_out, sentence_punc_list_out, cache_out

    def vad_mask(self, size, vad_pos, dtype=bool):
        """Build the (size, size) attention mask for a VAD boundary.

        The mask is all ones except zeros in rows ``[0, vad_pos - 1)`` and
        columns ``[vad_pos, size)``. When ``vad_pos <= 0`` or
        ``vad_pos >= size`` it is all ones.

        :param int size: size of mask
        :param int vad_pos: index of vad index
        :param dtype: numpy result dtype (builtin ``bool`` by default; the
            removed ``np.bool`` alias made the class fail at definition time
            on NumPy >= 1.24)
        :rtype: np.ndarray (size, size)
        """
        ret = np.ones((size, size), dtype=dtype)
        if vad_pos <= 0 or vad_pos >= size:
            return ret
        sub_corner = np.zeros(
            (vad_pos - 1, size - vad_pos), dtype=dtype)
        ret[0:vad_pos - 1, vad_pos:] = sub_corner
        return ret

    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray,
              vad_mask: np.ndarray,
              sub_masks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Run the ONNX session on ids, lengths and both masks; return raw outputs."""
        outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
        return outputs
| | | |
| New file |
| | |
import math
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
| | | |
class VadStateMachine(Enum):
    """States of the VAD state machine (start point pending, in speech, end point found)."""
    kVadInStateStartPointNotDetected = 1
    kVadInStateInSpeechSegment = 2
    kVadInStateEndPointDetected = 3


class FrameState(Enum):
    """Per-frame speech/silence label; -1 marks an invalid frame."""
    kFrameStateInvalid = -1
    kFrameStateSpeech = 1
    kFrameStateSil = 0


# final voice/unvoice state per frame
class AudioChangeState(Enum):
    """Transition between consecutive frame states (speech/silence)."""
    kChangeStateSpeech2Speech = 0
    kChangeStateSpeech2Sil = 1
    kChangeStateSil2Sil = 2
    kChangeStateSil2Speech = 3
    kChangeStateNoBegin = 4
    kChangeStateInvalid = 5


class VadDetectMode(Enum):
    """Whether detection stops after one utterance or continues for several."""
    kVadSingleUtteranceDetectMode = 0
    kVadMutipleUtteranceDetectMode = 1


class VADXOptions:
    """Plain container for VAD configuration.

    Every constructor argument is stored unchanged as the attribute of the same
    name. The one exception is ``sil_pdf_ids``: when omitted, each instance gets
    its own ``[0]`` list rather than sharing one mutable default.
    """

    def __init__(
            self,
            sample_rate: int = 16000,
            detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
            snr_mode: int = 0,
            max_end_silence_time: int = 800,
            max_start_silence_time: int = 3000,
            do_start_point_detection: bool = True,
            do_end_point_detection: bool = True,
            window_size_ms: int = 200,
            sil_to_speech_time_thres: int = 150,
            speech_to_sil_time_thres: int = 150,
            speech_2_noise_ratio: float = 1.0,
            do_extend: int = 1,
            lookback_time_start_point: int = 200,
            lookahead_time_end_point: int = 100,
            max_single_segment_time: int = 60000,
            nn_eval_block_size: int = 8,
            dcd_block_size: int = 4,
            snr_thres: float = -100.0,
            noise_frame_num_used_for_snr: int = 100,
            decibel_thres: float = -100.0,
            speech_noise_thres: float = 0.6,
            fe_prior_thres: float = 1e-4,
            silence_pdf_num: int = 1,
            sil_pdf_ids: Optional[List[int]] = None,
            speech_noise_thresh_low: float = -0.1,
            speech_noise_thresh_high: float = 0.3,
            output_frame_probs: bool = False,
            frame_in_ms: int = 10,
            frame_length_ms: int = 25,
    ):
        self.sample_rate = sample_rate
        self.detect_mode = detect_mode
        self.snr_mode = snr_mode
        self.max_end_silence_time = max_end_silence_time
        self.max_start_silence_time = max_start_silence_time
        self.do_start_point_detection = do_start_point_detection
        self.do_end_point_detection = do_end_point_detection
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time_thres = sil_to_speech_time_thres
        self.speech_to_sil_time_thres = speech_to_sil_time_thres
        self.speech_2_noise_ratio = speech_2_noise_ratio
        self.do_extend = do_extend
        self.lookback_time_start_point = lookback_time_start_point
        self.lookahead_time_end_point = lookahead_time_end_point
        self.max_single_segment_time = max_single_segment_time
        self.nn_eval_block_size = nn_eval_block_size
        self.dcd_block_size = dcd_block_size
        self.snr_thres = snr_thres
        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
        self.decibel_thres = decibel_thres
        self.speech_noise_thres = speech_noise_thres
        self.fe_prior_thres = fe_prior_thres
        self.silence_pdf_num = silence_pdf_num
        # A new [0] per instance; a shared default list would leak mutations
        # from one options object into every other.
        self.sil_pdf_ids = [0] if sil_pdf_ids is None else sil_pdf_ids
        self.speech_noise_thresh_low = speech_noise_thresh_low
        self.speech_noise_thresh_high = speech_noise_thresh_high
        self.output_frame_probs = output_frame_probs
        self.frame_in_ms = frame_in_ms
        self.frame_length_ms = frame_length_ms
| | | |
| | | |
class E2EVadSpeechBufWithDoa(object):
    """Bookkeeping record for one speech segment (boundaries in milliseconds)."""

    def __init__(self):
        # Construction and Reset() produce exactly the same field values.
        self.Reset()

    def Reset(self):
        """Put every field back to its empty default."""
        self.start_ms, self.end_ms = 0, 0
        self.buffer = []
        self.contain_seg_start_point = self.contain_seg_end_point = False
        self.doa = 0
| | | |
| | | |
class E2EVadFrameProb(object):
    """Per-frame probability record (filled by E2EVadModel when output_frame_probs is set)."""

    def __init__(self):
        self.noise_prob, self.speech_prob, self.score = 0.0, 0.0, 0.0
        self.frame_id, self.frm_state = 0, 0
| | | |
| | | |
class WindowDetector(object):
    """Smooth per-frame speech/silence decisions with a sliding vote window.

    Each of the last ``win_size_frame`` frames contributes one vote
    (1 = speech, 0 = silence). A silence->speech transition is reported once
    the speech-vote count reaches ``sil_to_speech_frmcnt_thres``; a
    speech->silence transition once it drops to ``speech_to_sil_frmcnt_thres``
    or below. Millisecond arguments are turned into frame counts by dividing
    by ``frame_size_ms``.
    """

    def __init__(self, window_size_ms: int, sil_to_speech_time: int,
                 speech_to_sil_time: int, frame_size_ms: int):
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time = sil_to_speech_time
        self.speech_to_sil_time = speech_to_sil_time
        self.frame_size_ms = frame_size_ms

        # The same durations expressed in frames.
        self.win_size_frame = int(window_size_ms / frame_size_ms)
        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)

        # Start with an all-silence window (initialize the window).
        self.Reset()

    def Reset(self) -> None:
        """Empty the vote window and return to the silence state."""
        self.win_state = [0] * self.win_size_frame
        self.win_sum = 0
        self.cur_win_pos = 0
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        self.voice_last_frame_count = 0
        self.noise_last_frame_count = 0
        self.hydre_frame_count = 0

    def GetWinSize(self) -> int:
        """Window length in frames."""
        return int(self.win_size_frame)

    def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
        """Push one frame decision into the window and report the transition.

        ``frame_count`` is accepted but not used. A ``frameState`` that is
        neither speech nor silence returns ``kChangeStateInvalid`` without
        touching the window.
        """
        if frameState == FrameState.kFrameStateSpeech:
            vote = 1
        elif frameState == FrameState.kFrameStateSil:
            vote = 0
        else:
            return AudioChangeState.kChangeStateInvalid

        # Replace the oldest vote in the circular buffer, keeping win_sum in sync.
        slot = self.cur_win_pos
        self.win_sum += vote - self.win_state[slot]
        self.win_state[slot] = vote
        self.cur_win_pos = (slot + 1) % self.win_size_frame

        in_silence = self.pre_frame_state == FrameState.kFrameStateSil
        in_speech = self.pre_frame_state == FrameState.kFrameStateSpeech

        if in_silence:
            if self.win_sum >= self.sil_to_speech_frmcnt_thres:
                self.pre_frame_state = FrameState.kFrameStateSpeech
                return AudioChangeState.kChangeStateSil2Speech
            return AudioChangeState.kChangeStateSil2Sil
        if in_speech:
            if self.win_sum <= self.speech_to_sil_frmcnt_thres:
                self.pre_frame_state = FrameState.kFrameStateSil
                return AudioChangeState.kChangeStateSpeech2Sil
            return AudioChangeState.kChangeStateSpeech2Speech
        return AudioChangeState.kChangeStateInvalid

    def FrameSizeMs(self) -> int:
        """Frame size in milliseconds, as given to the constructor."""
        return int(self.frame_size_ms)
| | | |
| | | |
class E2EVadModel():
    """Streaming VAD post-processor.

    Turns per-frame model scores plus the raw waveform into speech segments.
    State persists across calls so audio can be fed chunk by chunk; a call
    with ``is_final=True`` flushes the last frames and then resets everything
    via ``AllResetDetection``.
    """

    def __init__(self, vad_post_args: Dict[str, Any]):
        """
        Args:
            vad_post_args: keyword arguments forwarded to ``VADXOptions``.
        """
        super(E2EVadModel, self).__init__()
        self.vad_opts = VADXOptions(**vad_post_args)
        self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
                                               self.vad_opts.sil_to_speech_time_thres,
                                               self.vad_opts.speech_to_sil_time_thres,
                                               self.vad_opts.frame_in_ms)
        # self.encoder = encoder
        # init variables
        self.is_final = False
        self.data_buf_start_frame = 0
        self.frm_cnt = 0
        self.latest_confirmed_speech_frame = 0
        self.lastest_confirmed_silence_frame = -1
        self.continous_silence_frame_count = 0
        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
        self.confirmed_start_frame = -1
        self.confirmed_end_frame = -1
        self.number_end_time_detected = 0
        self.sil_frame = 0
        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
        # -100.0 acts as "no noise estimate yet" (see the < -99.9 test in GetFrameState).
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.next_seg = True

        self.output_data_buf = []
        self.output_data_buf_offset = 0
        self.frame_probs = []
        self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
        self.speech_noise_thres = self.vad_opts.speech_noise_thres
        self.scores = None
        self.max_time_out = False
        self.decibel = []
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
        self.ResetDetection()

    def AllResetDetection(self):
        """Reset every piece of streaming state (same values as __init__)."""
        self.is_final = False
        self.data_buf_start_frame = 0
        self.frm_cnt = 0
        self.latest_confirmed_speech_frame = 0
        self.lastest_confirmed_silence_frame = -1
        self.continous_silence_frame_count = 0
        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
        self.confirmed_start_frame = -1
        self.confirmed_end_frame = -1
        self.number_end_time_detected = 0
        self.sil_frame = 0
        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.next_seg = True

        self.output_data_buf = []
        self.output_data_buf_offset = 0
        self.frame_probs = []
        self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
        self.speech_noise_thres = self.vad_opts.speech_noise_thres
        self.scores = None
        self.max_time_out = False
        self.decibel = []
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
        self.ResetDetection()

    def ResetDetection(self):
        """Reset only the per-utterance detection state (buffers and scores are kept)."""
        self.continous_silence_frame_count = 0
        self.latest_confirmed_speech_frame = 0
        self.lastest_confirmed_silence_frame = -1
        self.confirmed_start_frame = -1
        self.confirmed_end_frame = -1
        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
        self.windows_detector.Reset()
        self.sil_frame = 0
        self.frame_probs = []

    def ComputeDecibel(self) -> None:
        """Append the energy (dB) of every full frame of ``self.waveform[0]`` to ``self.decibel``.

        Also appends the chunk to ``data_buf_all``.
        """
        frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
        frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        if self.data_buf_all is None:
            self.data_buf_all = self.waveform[0]  # self.data_buf is pointed to self.waveform[0]
            self.data_buf = self.data_buf_all
        else:
            # NOTE(review): data_buf is not refreshed here; it only picks up the new
            # samples the next time PopDataBufTillFrame re-slices data_buf_all — confirm intended.
            self.data_buf_all = np.concatenate((self.data_buf_all, self.waveform[0]))
        # Frames are taken per chunk only; samples straddling chunk boundaries are not framed.
        for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
            self.decibel.append(
                10 * math.log10(np.square((self.waveform[0][offset: offset + frame_sample_length])).sum() + \
                                0.000001))

    def ComputeScores(self, scores: np.ndarray) -> None:
        """Accumulate model scores along the frame axis (axis 1).

        Elsewhere ``self.scores`` is indexed as ``[batch][frame][pdf]``.
        """
        # scores = self.encoder(feats, in_cache)  # return B * T * D
        # The block size follows however many frames this chunk produced.
        self.vad_opts.nn_eval_block_size = scores.shape[1]
        self.frm_cnt += scores.shape[1]  # count total frames
        if self.scores is None:
            self.scores = scores  # the first calculation
        else:
            self.scores = np.concatenate((self.scores, scores), axis=1)

    def PopDataBufTillFrame(self, frame_idx: int) -> None:  # need check again
        """Advance ``data_buf`` (a tail slice of ``data_buf_all``) so it starts at ``frame_idx``."""
        # NOTE(review): if data_buf is shorter than one frame shift while
        # data_buf_start_frame < frame_idx, nothing changes and this loop never ends — confirm.
        while self.data_buf_start_frame < frame_idx:
            if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
                self.data_buf_start_frame += 1
                self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
                    self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]

    def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
                           last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
        """Extend the current output segment by ``frm_cnt`` frames starting at ``start_frm``.

        Opens a new ``E2EVadSpeechBufWithDoa`` when there is none yet or when
        ``first_frm_is_start_point`` is set, then updates ``end_ms`` and the
        start/end-point flags. Sample copying is not implemented (see the
        commented-out C++ lines); only counters are advanced.
        """
        self.PopDataBufTillFrame(start_frm)
        expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
        if last_frm_is_end_point:
            # The final frame is longer than one shift: add (frame_length - frame_shift) samples.
            extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
                                      self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
            expected_sample_number += int(extra_sample)
        if end_point_is_sent_end:
            expected_sample_number = max(expected_sample_number, len(self.data_buf))
        if len(self.data_buf) < expected_sample_number:
            print('error in calling pop data_buf\n')

        if len(self.output_data_buf) == 0 or first_frm_is_start_point:
            self.output_data_buf.append(E2EVadSpeechBufWithDoa())
            self.output_data_buf[-1].Reset()
            self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
            self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
            self.output_data_buf[-1].doa = 0
        cur_seg = self.output_data_buf[-1]
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print('warning\n')
        out_pos = len(cur_seg.buffer)  # cur_seg.buffer is currently never written to
        data_to_pop = 0
        if end_point_is_sent_end:
            data_to_pop = expected_sample_number
        else:
            data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        if data_to_pop > len(self.data_buf):
            print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
            data_to_pop = len(self.data_buf)
            expected_sample_number = len(self.data_buf)

        cur_seg.doa = 0
        # Placeholders for the sample copy of the C++ original; they only bump out_pos.
        for sample_cpy_out in range(0, data_to_pop):
            # cur_seg.buffer[out_pos ++] = data_buf_.back();
            out_pos += 1
        for sample_cpy_out in range(data_to_pop, expected_sample_number):
            # cur_seg.buffer[out_pos++] = data_buf_.back()
            out_pos += 1
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print('Something wrong with the VAD algorithm\n')
        self.data_buf_start_frame += frm_cnt
        cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
        if first_frm_is_start_point:
            cur_seg.contain_seg_start_point = True
        if last_frm_is_end_point:
            cur_seg.contain_seg_end_point = True

    def OnSilenceDetected(self, valid_frame: int):
        """Record a confirmed silence frame; before speech starts, drop buffered audio up to it."""
        self.lastest_confirmed_silence_frame = valid_frame
        if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            self.PopDataBufTillFrame(valid_frame)
        # silence_detected_callback_
        # pass

    def OnVoiceDetected(self, valid_frame: int) -> None:
        """Record a confirmed speech frame and append it to the current segment."""
        self.latest_confirmed_speech_frame = valid_frame
        self.PopDataToOutputBuf(valid_frame, 1, False, False, False)

    def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
        """Confirm a segment start at ``start_frame``.

        When ``fake_result`` is False, the start frame is also pushed to the output buffer.
        """
        if self.vad_opts.do_start_point_detection:
            pass
        if self.confirmed_start_frame != -1:
            print('not reset vad properly\n')
        else:
            self.confirmed_start_frame = start_frame

        if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)

    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
        """Confirm a segment end at ``end_frame``.

        Frames between the last confirmed speech frame and ``end_frame`` are
        filled in as speech first.
        """
        for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
            self.OnVoiceDetected(t)
        if self.vad_opts.do_end_point_detection:
            pass
        if self.confirmed_end_frame != -1:
            print('not reset vad properly\n')
        else:
            self.confirmed_end_frame = end_frame
        if not fake_result:
            self.sil_frame = 0
            self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
        self.number_end_time_detected += 1

    def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
        """On the final frame, close the open segment at ``cur_frm_idx``."""
        if is_final_frame:
            self.OnVoiceEnd(cur_frm_idx, False, True)
            self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected

    def GetLatency(self) -> int:
        """Start-point latency in milliseconds."""
        return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)

    def LatencyFrmNumAtStartPoint(self) -> int:
        """Start-point latency in frames: window size plus the look-back when ``do_extend`` is set."""
        vad_latency = self.windows_detector.GetWinSize()
        if self.vad_opts.do_extend:
            vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
        return vad_latency

    def GetFrameState(self, t: int) -> FrameState:
        """Classify frame ``t`` as speech or silence from its energy and silence-pdf scores.

        Side effects: may append to ``frame_probs`` and update ``noise_average_decibel``.
        """
        frame_state = FrameState.kFrameStateInvalid
        cur_decibel = self.decibel[t]
        cur_snr = cur_decibel - self.noise_average_decibel
        # for each frame, calc log posterior probability of each state
        if cur_decibel < self.vad_opts.decibel_thres:
            frame_state = FrameState.kFrameStateSil
            # NOTE(review): DetectCommonFrames/DetectLastFrames also call DetectOneFrame
            # with this frame after GetFrameState returns, so low-energy frames appear
            # to be fed to the window detector twice — confirm intended.
            self.DetectOneFrame(frame_state, t, False)
            return frame_state

        sum_score = 0.0
        noise_prob = 0.0
        assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
        if len(self.sil_pdf_ids) > 0:
            assert len(self.scores) == 1  # only batch_size = 1 is supported
            sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
            sum_score = sum(sil_pdf_scores)
            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
            total_score = 1.0
            # Speech score is the complement of the summed silence scores.
            sum_score = total_score - sum_score
        # NOTE(review): math.log raises ValueError if sum_score is 0 (e.g. a silence score of exactly 1.0).
        speech_prob = math.log(sum_score)
        if self.vad_opts.output_frame_probs:
            frame_prob = E2EVadFrameProb()
            frame_prob.noise_prob = noise_prob
            frame_prob.speech_prob = speech_prob
            frame_prob.score = sum_score
            frame_prob.frame_id = t
            self.frame_probs.append(frame_prob)
        if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
            if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
                frame_state = FrameState.kFrameStateSpeech
            else:
                frame_state = FrameState.kFrameStateSil
        else:
            frame_state = FrameState.kFrameStateSil
            # Running average of the noise level over noise_frame_num_used_for_snr frames;
            # the first noise frame seeds it.
            if self.noise_average_decibel < -99.9:
                self.noise_average_decibel = cur_decibel
            else:
                self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * (
                        self.vad_opts.noise_frame_num_used_for_snr
                        - 1)) / self.vad_opts.noise_frame_num_used_for_snr

        return frame_state


    def __call__(self, score: np.ndarray, waveform: np.ndarray,
                 is_final: bool = False, max_end_sil: int = 800
                 ):
        """Process one chunk and return the segments that became available.

        Args:
            score: model scores for the new frames (frames on axis 1).
            waveform: samples of the same chunk; row 0 is used.
            is_final: True for the last chunk; state is fully reset afterwards.
            max_end_sil: trailing-silence limit (compared in ms in DetectOneFrame).

        Returns:
            ``[[ [start_ms, end_ms], ... ]]`` (one inner list per batch item,
            omitted when empty). ``start_ms`` is -1 for a segment whose start was
            already reported in an earlier call; ``end_ms`` is -1 while the
            segment is still open.
        """
        self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
        self.waveform = waveform  # compute decibel for each frame
        self.ComputeDecibel()
        self.ComputeScores(score)
        if not is_final:
            self.DetectCommonFrames()
        else:
            self.DetectLastFrames()
        segments = []
        for batch_num in range(0, score.shape[0]):  # only support batch_size = 1 now
            segment_batch = []
            if len(self.output_data_buf) > 0:
                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
                    if not self.output_data_buf[i].contain_seg_start_point:
                        continue
                    # An already-reported open segment is emitted again only once it has ended.
                    if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
                        continue
                    start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
                    if self.output_data_buf[i].contain_seg_end_point:
                        end_ms = self.output_data_buf[i].end_ms
                        self.next_seg = True
                        self.output_data_buf_offset += 1
                    else:
                        end_ms = -1
                        self.next_seg = False
                    segment = [start_ms, end_ms]
                    segment_batch.append(segment)
            if segment_batch:
                segments.append(segment_batch)
        if is_final:
            # reset class variables and clear the dict for the next query
            self.AllResetDetection()
        return segments

    def DetectCommonFrames(self) -> int:
        """Run detection over the newest ``nn_eval_block_size`` frames (oldest first)."""
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
            self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)

        return 0

    def DetectLastFrames(self) -> int:
        """Like DetectCommonFrames, but marks the very last frame as final."""
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
            if i != 0:
                self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
            else:
                self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)

        return 0

    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
        """Advance the VAD state machine by one frame.

        The frame decision is smoothed by ``windows_detector``. The resulting
        transition (sil->speech, speech->sil, speech->speech, sil->sil) drives:
        - segment start/end confirmation;
        - the max-segment-length cut;
        - the start-silence timeout;
        - the trailing-silence end point.
        In multiple-utterance mode, detection state is reset after each end point.
        """
        tmp_cur_frm_state = FrameState.kFrameStateInvalid
        if cur_frm_state == FrameState.kFrameStateSpeech:
            # math.fabs(1.0) is a constant 1.0, so this is speech unless fe_prior_thres >= 1.0.
            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
                tmp_cur_frm_state = FrameState.kFrameStateSpeech
            else:
                tmp_cur_frm_state = FrameState.kFrameStateSil
        elif cur_frm_state == FrameState.kFrameStateSil:
            tmp_cur_frm_state = FrameState.kFrameStateSil
        state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
        frm_shift_in_ms = self.vad_opts.frame_in_ms
        if AudioChangeState.kChangeStateSil2Speech == state_change:
            silence_frame_count = self.continous_silence_frame_count
            self.continous_silence_frame_count = 0
            self.pre_end_silence_detected = False
            start_frame = 0
            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                # Back-date the start by the detection latency, but not before buffered audio.
                start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint())
                self.OnVoiceStart(start_frame)
                self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
                for t in range(start_frame + 1, cur_frm_idx + 1):
                    self.OnVoiceDetected(t)
            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
                    self.OnVoiceDetected(t)
                if cur_frm_idx - self.confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
            else:
                pass
        elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
            self.continous_silence_frame_count = 0
            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                pass
            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if cur_frm_idx - self.confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
            else:
                pass
        elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
            self.continous_silence_frame_count = 0
            if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if cur_frm_idx - self.confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.max_time_out = True
                    self.OnVoiceEnd(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
            else:
                pass
        elif AudioChangeState.kChangeStateSil2Sil == state_change:
            self.continous_silence_frame_count += 1
            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                # silence timeout, return zero length decision
                if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
                        self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
                        or (is_final_frame and self.number_end_time_detected == 0):
                    for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
                        self.OnSilenceDetected(t)
                    self.OnVoiceStart(0, True)
                    self.OnVoiceEnd(0, True, False);
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                else:
                    if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
                        self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                # Trailing silence exceeded the limit: place the end point back in time,
                # keeping lookahead_time_end_point of silence when do_extend is set.
                if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh:
                    lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
                    if self.vad_opts.do_extend:
                        lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
                        lookback_frame -= 1
                        lookback_frame = max(0, lookback_frame)
                    self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif cur_frm_idx - self.confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif self.vad_opts.do_extend and not is_final_frame:
                    if self.continous_silence_frame_count <= int(
                            self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
                        self.OnVoiceDetected(cur_frm_idx)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
            else:
                pass

        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
                self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
            self.ResetDetection()
| | |
| | | input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray: |
| | | input_dict = dict(zip(self.get_input_names(), input_content)) |
| | | try: |
| | | return self.session.run(None, input_dict) |
| | | return self.session.run(self.get_output_names(), input_dict) |
| | | except Exception as e: |
| | | raise ONNXRuntimeError('ONNXRuntime inferece failed.') from e |
| | | |
| | |
| | | if not model_path.is_file(): |
| | | raise FileExistsError(f'{model_path} is not a file.') |
| | | |
def split_to_mini_sentence(words: list, word_limit: int = 20):
    """Chunk ``words`` into consecutive pieces of at most ``word_limit`` items.

    An input no longer than ``word_limit`` (including an empty one) is
    returned unchanged as the only element of the result; longer inputs are
    sliced, and only the last piece may be shorter than ``word_limit``.
    """
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    return [words[start:start + word_limit]
            for start in range(0, len(words), word_limit)]
| | | |
def code_mix_split_words(text: str):
    """Split mixed ASCII / non-ASCII text into word tokens.

    The text is first split on whitespace. Within each whitespace-free piece,
    runs of consecutive ASCII characters are kept together as one word and
    every non-ASCII character (for example a CJK character) becomes its own
    word.

    Example: ``"hello世界 foo"`` -> ``["hello", "世", "界", "foo"]``.
    """
    words = []
    for seg in text.split():
        current_word = ""
        for c in seg:
            # ord() test instead of ``len(c.encode()) == 1``: identical for every
            # encodable character, but it does not raise UnicodeEncodeError on
            # lone surrogates and does not allocate a bytes object per character.
            if ord(c) < 0x80:
                current_word += c
            else:
                # Any non-ASCII character (not only Chinese) is a word by itself.
                if current_word:
                    words.append(current_word)
                    current_word = ""
                words.append(c)
        if current_word:
            words.append(current_word)
    return words
| | | |
| | | def read_yaml(yaml_path: Union[str, Path]) -> Dict: |
| | | if not Path(yaml_path).exists(): |
| | |
| | | |
| | | |
| | | @functools.lru_cache() |
| | | def get_logger(name='rapdi_paraformer'): |
| | | def get_logger(name='funasr_onnx'): |
| | | """Initialize and get a logger by name. |
| | | If the logger has not been initialized, this method will initialize the |
| | | logger by adding one or two handlers, otherwise the initialized logger will |
| New file |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | |
import copy
import os.path
import traceback
from pathlib import Path
from typing import List, Union, Tuple

import librosa
import numpy as np

from .utils.utils import (ONNXRuntimeError,
                          OrtInferSession, get_logger,
                          read_yaml)
from .utils.frontend import WavFrontend
from .utils.e2e_vad import E2EVadModel
| | | |
# Module-level logger from the package's get_logger(). NOTE: the name matches the
# stdlib ``logging`` module but refers to this logger object, not the module.
logging = get_logger()
| | | |
| | | |
class Fsmn_vad():
    """Streaming FSMN voice-activity detector running on ONNX Runtime.

    Loads ``model.onnx`` (or ``model_quant.onnx``), ``vad.yaml`` and
    ``vad.mvn`` from ``model_dir``. Features come from ``WavFrontend``, the
    ONNX model scores them, and ``E2EVadModel`` turns the scores into
    speech segments.
    """

    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 max_end_sil: int = None,
                 ):
        """
        Args:
            model_dir: directory holding the model, ``vad.yaml`` and ``vad.mvn``.
            batch_size: stored on the instance; not otherwise used in this class.
            device_id: forwarded to ``OrtInferSession``.
            quantize: load ``model_quant.onnx`` instead of ``model.onnx``.
            intra_op_num_threads: forwarded to ``OrtInferSession``.
            max_end_sil: trailing-silence limit passed to the VAD scorer on
                every call; defaults to ``vad_post_conf.max_end_silence_time``
                from ``vad.yaml``.

        Raises:
            FileNotFoundError: if ``model_dir`` does not exist.
        """
        if not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')

        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        config_file = os.path.join(model_dir, 'vad.yaml')
        cmvn_file = os.path.join(model_dir, 'vad.mvn')
        config = read_yaml(config_file)

        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.vad_scorer = E2EVadModel(config["vad_post_conf"])
        self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"]
        self.encoder_conf = config["encoder_conf"]

    def prepare_cache(self, in_cache: list = None):
        """Return the FSMN layer caches for the next inference step.

        A non-empty ``in_cache`` (the ``out_caches`` of the previous step) is
        returned unchanged. Otherwise one float32 zero tensor of shape
        ``(1, proj_dim, lorder - 1, 1)`` per FSMN layer is appended to
        ``in_cache`` (a fresh list when it is ``None``), which is returned.
        """
        # ``None`` instead of a ``[]`` default: the shared default list used to be
        # filled on the first call and then handed back, aliased, to every later
        # caller that omitted the argument.
        if in_cache is None:
            in_cache = []
        if len(in_cache) > 0:
            return in_cache
        fsmn_layers = self.encoder_conf["fsmn_layers"]
        proj_dim = self.encoder_conf["proj_dim"]
        lorder = self.encoder_conf["lorder"]
        for _ in range(fsmn_layers):
            in_cache.append(np.zeros((1, proj_dim, lorder - 1, 1), dtype=np.float32))
        return in_cache

    def __call__(self, audio_in: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        """Run one streaming VAD step on the chunk ``audio_in``.

        Although typed more broadly, the body treats ``audio_in`` as a 1-D
        ``np.ndarray`` of samples: it is concatenated and indexed with
        ``[None, :]``.

        Streaming state is read from and written back into
        ``kwargs['param_dict']``:
        - ``audio_in_cache``: all audio received so far;
        - ``in_cache``: the FSMN caches;
        - ``beg_idx``: the next feature frame to score.
        Set ``param_dict['is_final'] = True`` for the last chunk.

        Returns:
            The segments produced by ``E2EVadModel``, or ``[]`` when ONNX
            inference fails.
        """
        param_dict = kwargs.get('param_dict', dict())
        is_final = param_dict.get('is_final', False)
        audio_in_cache = param_dict.get('audio_in_cache', None)
        audio_in_cum = audio_in
        if audio_in_cache is not None:
            audio_in_cum = np.concatenate((audio_in_cache, audio_in_cum))
        param_dict['audio_in_cache'] = audio_in_cum
        # Features are recomputed over all audio received so far; only the
        # frames from ``beg_idx`` on are scored below.
        feats, _ = self.extract_feat([audio_in_cum])

        in_cache = param_dict.get('in_cache', list())
        in_cache = self.prepare_cache(in_cache)
        beg_idx = param_dict.get('beg_idx', 0)
        # NOTE(review): at most 8 new frames are scored per call; the origin of
        # this fixed window is not visible here — confirm.
        feats = feats[:, beg_idx:beg_idx + 8, :]
        param_dict['beg_idx'] = beg_idx + feats.shape[1]
        try:
            inputs = [feats]
            inputs.extend(in_cache)
            scores, out_caches = self.infer(inputs)
            param_dict['in_cache'] = out_caches
            segments = self.vad_scorer(scores, audio_in[None, :], is_final=is_final, max_end_sil=self.max_end_sil)
            # E2EVadModel reports end_ms == -1 for a still-open segment; once the
            # first segment has a real end, reset the frontend's status.
            if len(segments) == 1 and segments[0][0][1] != -1:
                self.frontend.reset_status()
        except ONNXRuntimeError:
            logging.warning(traceback.format_exc())
            logging.warning("input wav is silence or noise")
            segments = []

        return segments

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        """Normalize ``wav_content`` into a list of waveforms.

        Arrays are wrapped as-is; paths are read with ``librosa.load(path, sr=fs)``.

        Raises:
            TypeError: for any other input type.
        """
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]

        if isinstance(wav_content, str):
            return [load_wav(wav_content)]

        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]

        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def extract_feat(self,
                     waveform_list: List[np.ndarray]
                     ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute fbank -> LFR/CMVN features for each waveform.

        Returns:
            ``(feats, feats_len)``: feats zero-padded to the longest item and
            stacked as float32, plus the int32 original lengths.
        """
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)

        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        """Zero-pad each 2-D feature matrix along axis 0 to ``max_feat_len`` and stack as float32."""
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
        """Run the ONNX session.

        Returns:
            ``(scores, out_caches)``: the first model output, and the list of
            the remaining outputs (the updated FSMN caches).
        """
        outputs = self.ort_infer(feats)
        scores, out_caches = outputs[0], outputs[1:]
        return scores, out_caches
| | | |
| | |
| | | |
| | | |
| | | MODULE_NAME = 'funasr_onnx' |
| | | VERSION_NUM = '0.0.2' |
| | | VERSION_NUM = '0.0.3' |
| | | |
| | | setuptools.setup( |
| | | name=MODULE_NAME, |
| | |
| | | from funasr.datasets.collate_fn import CommonCollateFn |
| | | from funasr.datasets.preprocessor import CommonPreprocessor |
| | | from funasr.lm.abs_model import AbsLM |
| | | from funasr.lm.espnet_model import ESPnetLanguageModel |
| | | from funasr.lm.abs_model import LanguageModel |
| | | from funasr.lm.seq_rnn_lm import SequentialRNNLM |
| | | from funasr.lm.transformer_lm import TransformerLM |
| | | from funasr.tasks.abs_task import AbsTask |
| | |
| | | group.add_argument( |
| | | "--model_conf", |
| | | action=NestedDictAction, |
| | | default=get_default_kwargs(ESPnetLanguageModel), |
| | | default=get_default_kwargs(LanguageModel), |
| | | help="The keyword arguments for model class.", |
| | | ) |
| | | |
| | |
| | | return retval |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel: |
| | | def build_model(cls, args: argparse.Namespace) -> LanguageModel: |
| | | assert check_argument_types() |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |
| | |
| | | |
| | | # 2. Build ESPnetModel |
| | | # Assume the last-id is sos_and_eos |
| | | model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf) |
| | | model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf) |
| | | |
| | | # 3. Initialize |
| | | if args.init is not None: |
| | |
| | | |
| | | from funasr.datasets.collate_fn import CommonCollateFn |
| | | from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.punctuation.espnet_model import ESPnetPunctuationModel |
| | | from funasr.punctuation.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer |
| | | from funasr.train.abs_model import AbsPunctuation |
| | | from funasr.train.abs_model import PunctuationModel |
| | | from funasr.models.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.models.vad_realtime_transformer import VadRealtimeTransformer |
| | | from funasr.tasks.abs_task import AbsTask |
| | | from funasr.text.phoneme_tokenizer import g2p_choices |
| | | from funasr.torch_utils.initialize import initialize |
| | |
| | | group.add_argument( |
| | | "--model_conf", |
| | | action=NestedDictAction, |
| | | default=get_default_kwargs(ESPnetPunctuationModel), |
| | | default=get_default_kwargs(PunctuationModel), |
| | | help="The keyword arguments for model class.", |
| | | ) |
| | | |
| | |
| | | return retval |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace) -> ESPnetPunctuationModel: |
| | | def build_model(cls, args: argparse.Namespace) -> PunctuationModel: |
| | | assert check_argument_types() |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |
| | |
| | | # Assume the last-id is sos_and_eos |
| | | if "punc_weight" in args.model_conf: |
| | | args.model_conf.pop("punc_weight") |
| | | model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf) |
| | | model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf) |
| | | |
| | | # FIXME(kamo): Should be done in model? |
| | | # 3. Initialize |
| | |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.models.frontend.default import DefaultFrontend |
| | | from funasr.models.frontend.fused import FusedFrontends |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline |
| | | from funasr.models.frontend.s3prl import S3prlFrontend |
| | | from funasr.models.frontend.windowing import SlidingWindow |
| | | from funasr.models.postencoder.abs_postencoder import AbsPostEncoder |
| | |
| | | s3prl=S3prlFrontend, |
| | | fused=FusedFrontends, |
| | | wav_frontend=WavFrontend, |
| | | wav_frontend_online=WavFrontendOnline, |
| | | ), |
| | | type_check=AbsFrontend, |
| | | default="default", |
| | |
| | | model_class = model_choices.get_class(args.model) |
| | | except AttributeError: |
| | | model_class = model_choices.get_class("e2evad") |
| | | model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf) |
| | | |
| | | # 1. frontend |
| | | if args.input_size is None: |
| | | # Extract features in the model |
| | | frontend_class = frontend_choices.get_class(args.frontend) |
| | | if args.frontend == 'wav_frontend': |
| | | frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) |
| | | else: |
| | | frontend = frontend_class(**args.frontend_conf) |
| | | input_size = frontend.output_size() |
| | | else: |
| | | # Give features from data-loader |
| | | args.frontend = None |
| | | args.frontend_conf = {} |
| | | frontend = None |
| | | input_size = args.input_size |
| | | |
| | | model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend) |
| | | |
| | | return model |
| | | |
| | |
| | | cls, |
| | | config_file: Union[Path, str] = None, |
| | | model_file: Union[Path, str] = None, |
| | | cmvn_file: Union[Path, str] = None, |
| | | device: str = "cpu", |
| | | ): |
| | | """Build model from the files. |
| | |
| | | |
| | | with config_file.open("r", encoding="utf-8") as f: |
| | | args = yaml.safe_load(f) |
| | | if cmvn_file is not None: |
| | | args["cmvn_file"] = cmvn_file |
| | | args = argparse.Namespace(**args) |
| | | model = cls.build_model(args) |
| | | model.to(device) |
| File was renamed from funasr/punctuation/espnet_model.py |
| | |
| | | from abc import ABC |
| | | from abc import abstractmethod |
| | | |
| | | |
| | | from typing import Dict |
| | | from typing import Optional |
| | | from typing import Tuple |
| | |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | |
| | | from funasr.modules.scorers.scorer_interface import BatchScorerInterface |
| | | |
| | | class ESPnetPunctuationModel(AbsESPnetModel): |
| | | |
| | | class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC): |
| | | """The abstract class |
| | | |
| | | To share the loss calculation way among different models, |
| | | We uses delegate pattern here: |
| | | The instance of this class should be passed to "LanguageModel" |
| | | |
| | | This "model" is one of mediator objects for "Task" class. |
| | | |
| | | """ |
| | | |
| | | @abstractmethod |
| | | def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | raise NotImplementedError |
| | | |
| | | @abstractmethod |
| | | def with_vad(self) -> bool: |
| | | raise NotImplementedError |
| | | |
| | | |
| | | class PunctuationModel(AbsESPnetModel): |
| | | |
| | | def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | |
| | | self.punc_weight = torch.Tensor(punc_weight) |
| | | self.sos = 1 |
| | | self.eos = 2 |
| | | |
| | | |
| | | # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR. |
| | | self.ignore_id = ignore_id |
| | | #if self.punc_model.with_vad(): |
| | | # if self.punc_model.with_vad(): |
| | | # print("This is a vad puncuation model.") |
| | | |
| | | |
| | | def nll( |
| | | self, |
| | | text: torch.Tensor, |
| | |
| | | else: |
| | | text = text[:, :max_length] |
| | | punc = punc[:, :max_length] |
| | | |
| | | |
| | | if self.punc_model.with_vad(): |
| | | # Should be VadRealtimeTransformer |
| | | assert vad_indexes is not None |
| | |
| | | else: |
| | | # Should be TargetDelayTransformer, |
| | | y, _ = self.punc_model(text, text_lengths) |
| | | |
| | | |
| | | # Calc negative log likelihood |
| | | # nll: (BxL,) |
| | | if self.training == False: |
| | |
| | | return nll, text_lengths |
| | | else: |
| | | self.punc_weight = self.punc_weight.to(punc.device) |
| | | nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id) |
| | | nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", |
| | | ignore_index=self.ignore_id) |
| | | # nll: (BxL,) -> (BxL,) |
| | | if max_length is None: |
| | | nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0) |
| | |
| | | # nll: (BxL,) -> (B, L) |
| | | nll = nll.view(batch_size, -1) |
| | | return nll, text_lengths |
| | | |
| | | |
| | | def batchify_nll(self, |
| | | text: torch.Tensor, |
| | | punc: torch.Tensor, |
| | |
| | | nlls = [] |
| | | x_lengths = [] |
| | | max_length = text_lengths.max() |
| | | |
| | | |
| | | start_idx = 0 |
| | | while True: |
| | | end_idx = min(start_idx + batch_size, total_num) |
| | |
| | | assert nll.size(0) == total_num |
| | | assert x_lengths.size(0) == total_num |
| | | return nll, x_lengths |
| | | |
| | | |
| | | def forward( |
| | | self, |
| | | text: torch.Tensor, |
| | |
| | | ntokens = y_lengths.sum() |
| | | loss = nll.sum() / ntokens |
| | | stats = dict(loss=loss.detach()) |
| | | |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device) |
| | | return loss, stats, weight |
| | | |
| | | |
| | | def collect_feats(self, text: torch.Tensor, punc: torch.Tensor, |
| | | text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]: |
| | | return {} |
| | | |
| | | |
| | | def inference(self, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |