Merge pull request #761 from alibaba-damo-academy/dev_add_english_paraformer
Dev: add English Paraformer
| New file |
| | |
| | | ../../TEMPLATE/README.md |
| New file |
| | |
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Demo: transcribe a sample English recording with the English Paraformer model.
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    # ModelScope ids are "<namespace>/<model>"; the "damo" namespace must appear only once.
    model='damo/speech_paraformer_asr-en-16k-vocab4199-pytorch',
    model_revision="v1.0.1",
)

audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav'
rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)
| New file |
| | |
| | | import os |
| | | |
| | | from modelscope.metainfo import Trainers |
| | | from modelscope.trainers import build_trainer |
| | | |
| | | from funasr.datasets.ms_dataset import MsDataset |
| | | from funasr.utils.modelscope_param import modelscope_args |
| | | |
| | | |
def modelscope_finetune(params):
    """Fine-tune a ModelScope ASR model using the settings carried by ``params``.

    ``params`` must provide ``model``, ``data_path``, ``output_dir``,
    ``dataset_type``, ``batch_bins``, ``max_epoch`` and ``lr``.
    """
    # exist_ok=True already tolerates an existing directory, so no separate
    # (racy) existence check is needed.
    os.makedirs(params.output_dir, exist_ok=True)
    # Dataset splits: ["train", "validation"]
    ds_dict = MsDataset.load(params.data_path)
    kwargs = dict(
        model=params.model,
        data_dir=ds_dict,
        dataset_type=params.dataset_type,
        work_dir=params.output_dir,
        batch_bins=params.batch_bins,
        max_epoch=params.max_epoch,
        lr=params.lr,
    )
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
    trainer.train()
| | | |
| | | |
if __name__ == '__main__':
    params = modelscope_args(model="damo/speech_paraformer_asr-en-16k-vocab4199-pytorch", data_path="./data")

    # Directory where fine-tuned checkpoints are written.
    params.output_dir = "./checkpoint"
    # Training data location (replaces the data_path given to modelscope_args above).
    params.data_path = "./example_data/"
    # Use "small" for small corpora; switch to "large" for more than 1000 hours of data.
    params.dataset_type = "small"
    # Batch size: counted in fbank feature frames when dataset_type="small",
    # and in milliseconds when dataset_type="large".
    params.batch_bins = 2000
    # Maximum number of training epochs.
    params.max_epoch = 50
    # Learning rate.
    params.lr = 0.00005

    modelscope_finetune(params)
| New file |
| | |
| | | ../../TEMPLATE/infer.py |
| New file |
| | |
#!/usr/bin/env bash

set -e
set -u
set -o pipefail

stage=1
stop_stage=2
model="damo/speech_paraformer_asr-en-16k-vocab4199-pytorch"
data_dir="./data/test"
output_dir="./results"
batch_size=64
gpu_inference=true  # whether to perform gpu decoding
gpuid_list="0,1"    # set gpus, e.g., gpuid_list="0,1"
njob=64             # the number of jobs for CPU decoding, if gpu_inference=false, use CPU decoding, please set njob
checkpoint_dir=
checkpoint_name="valid.cer_ctc.ave.pb"

. utils/parse_options.sh || exit 1;

# Compare the flag as a string rather than executing its value as a command
# (`if ${gpu_inference} == "true"` only worked because `true`/`false` are builtins).
if [ "${gpu_inference}" = "true" ]; then
    # One decoding job per GPU id in the comma-separated list.
    nj=$(echo "${gpuid_list}" | awk -F "," '{print NF}')
else
    nj=${njob}
    batch_size=1
    # CPU decoding: every job is given gpuid -1.
    gpuid_list=""
    for JOB in $(seq "${nj}"); do
        gpuid_list="${gpuid_list}-1,"
    done
fi

# Split wav.scp into one shard per job.
mkdir -p "${output_dir}/split"
split_scps=""
for JOB in $(seq "${nj}"); do
    split_scps="${split_scps} ${output_dir}/split/wav.${JOB}.scp"
done
# split_scps is intentionally unquoted: it must expand to one argument per shard.
perl utils/split_scp.pl "${data_dir}/wav.scp" ${split_scps}

# When checkpoint_dir is set, prepare it and decode from ${checkpoint_dir}/${model}.
if [ -n "${checkpoint_dir}" ]; then
    python utils/prepare_checkpoint.py "${model}" "${checkpoint_dir}" "${checkpoint_name}"
    model="${checkpoint_dir}/${model}"
fi
| | | |
if [ "${stage}" -le 1 ] && [ "${stop_stage}" -ge 1 ]; then
    echo "Decoding ..."
    # Turn the comma-separated id list into an array (expansion unquoted on purpose).
    gpuid_list_array=(${gpuid_list//,/ })
    pids=()
    for JOB in $(seq "${nj}"); do
        {
            id=$((JOB - 1))
            gpuid=${gpuid_list_array[$id]}
            mkdir -p "${output_dir}/output.${JOB}"
            python infer.py \
                --model "${model}" \
                --audio_in "${output_dir}/split/wav.${JOB}.scp" \
                --output_dir "${output_dir}/output.${JOB}" \
                --batch_size "${batch_size}" \
                --gpuid "${gpuid}"
        } &
        pids+=($!)
    done
    # Wait on each job individually: a failing shard now aborts the script
    # (set -e) instead of being silently merged as if it had succeeded.
    for pid in "${pids[@]}"; do
        wait "${pid}"
    done

    # Concatenate the per-job 1-best outputs and sort them on the first column
    # (presumably the utterance id).
    mkdir -p "${output_dir}/1best_recog"
    for f in token score text; do
        if [ -f "${output_dir}/output.1/1best_recog/${f}" ]; then
            for i in $(seq "${nj}"); do
                cat "${output_dir}/output.${i}/1best_recog/${f}"
            done | sort -k1 > "${output_dir}/1best_recog/${f}"
        fi
    done
fi

if [ "${stage}" -le 2 ] && [ "${stop_stage}" -ge 2 ]; then
    echo "Computing WER ..."
    cp "${output_dir}/1best_recog/text" "${output_dir}/1best_recog/text.proc"
    cp "${data_dir}/text" "${output_dir}/1best_recog/text.ref"
    python utils/compute_wer.py "${output_dir}/1best_recog/text.ref" "${output_dir}/1best_recog/text.proc" "${output_dir}/1best_recog/text.cer"
    tail -n 3 "${output_dir}/1best_recog/text.cer"
fi

if [ "${stage}" -le 3 ] && [ "${stop_stage}" -ge 3 ]; then
    # NOTE(review): textnorm_zh and `--tokenizer char` look Mandarin-oriented;
    # confirm this stage is meaningful for the English model.
    echo "SpeechIO TIOBE textnorm"
    echo "$0 --> Normalizing REF text ..."
    ./utils/textnorm_zh.py \
        --has_key --to_upper \
        "${data_dir}/text" \
        "${output_dir}/1best_recog/ref.txt"

    echo "$0 --> Normalizing HYP text ..."
    ./utils/textnorm_zh.py \
        --has_key --to_upper \
        "${output_dir}/1best_recog/text.proc" \
        "${output_dir}/1best_recog/rec.txt"
    # Drop hypothesis lines that end in a tab (empty transcripts).
    grep -v $'\t$' "${output_dir}/1best_recog/rec.txt" > "${output_dir}/1best_recog/rec_non_empty.txt"

    echo "$0 --> computing WER/CER and alignment ..."
    ./utils/error_rate_zh \
        --tokenizer char \
        --ref "${output_dir}/1best_recog/ref.txt" \
        --hyp "${output_dir}/1best_recog/rec_non_empty.txt" \
        "${output_dir}/1best_recog/DETAILS.txt" | tee "${output_dir}/1best_recog/RESULTS.txt"
    rm -rf "${output_dir}/1best_recog/rec.txt" "${output_dir}/1best_recog/rec_non_empty.txt"
fi
| | | |
| New file |
| | |
| | | ../../../../egs/aishell/transformer/utils |
| New file |
| | |
| | | import json |
| | | from typing import Union, Dict |
| | | from pathlib import Path |
| | | |
| | | import os |
| | | import logging |
| | | import torch |
| | | |
| | | from funasr.export.models import get_model |
| | | import numpy as np |
| | | import random |
| | | from funasr.utils.types import str2bool, str2triple_str |
| | | # torch_version = float(".".join(torch.__version__.split(".")[:2])) |
| | | # assert torch_version > 1.9 |
| | | |
class ModelExport:
    """Build an ASR model from a local model directory and export it to ONNX.

    Only ``conformer*`` modes are handled by :meth:`export`. The wrapped model
    is written into ``cache_dir``. ``device``, ``quant``, ``fallback_num``,
    ``audio_in`` and ``model_revision`` are stored but not read by the
    methods of this class.
    """

    def __init__(
        self,
        cache_dir: Union[Path, str] = None,
        onnx: bool = True,
        device: str = "cpu",
        quant: bool = True,
        fallback_num: int = 0,
        audio_in: str = None,
        calib_num: int = 200,
        model_revision: str = None,
    ):
        # Seed python/numpy/torch RNGs so repeated exports are reproducible.
        self.set_all_random_seed(0)

        self.cache_dir = cache_dir
        # Passed to get_model() when wrapping the model for export.
        self.export_config = dict(
            feats_dim=560,
            onnx=False,
        )

        self.onnx = onnx
        self.device = device
        self.quant = quant
        self.fallback_num = fallback_num
        self.frontend = None
        self.audio_in = audio_in
        self.calib_num = calib_num
        self.model_revision = model_revision

    def _export(
        self,
        model,
        model_dir: str = None,
        verbose: bool = False,
    ):
        """Wrap ``model`` with ``get_model`` and export it into ``model_dir``.

        Args:
            model: the trained model to convert.
            model_dir: output directory; created if missing.
            verbose: forwarded to the wrapper's ``_export_onnx``.
        """
        export_dir = model_dir
        os.makedirs(export_dir, exist_ok=True)

        self.export_config["model_name"] = "model"
        model = get_model(
            model,
            self.export_config,
        )
        model.eval()

        if self.onnx:
            self._export_onnx(model, verbose, export_dir)

        print("output dir: {}".format(export_dir))

    def _export_onnx(self, model, verbose, path):
        """Delegate ONNX export to the wrapped model."""
        model._export_onnx(verbose, path)

    def set_all_random_seed(self, seed: int):
        """Seed the ``random``, ``numpy`` and ``torch`` generators with ``seed``."""
        random.seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)

    def parse_audio_in(self, audio_in):
        """Return ``(wav_list, name_list)`` for a wav path or a ``.scp`` list.

        For ``.scp`` input, at most ``self.calib_num`` lines are read; each
        non-blank line must be ``<name> <path>``. Any other input is treated as
        a single wav named ``"test"``.
        """
        wav_list, name_list = [], []
        if audio_in.endswith(".scp"):
            # Context manager so the file handle is always closed.
            with open(audio_in, 'r') as f:
                lines = f.readlines()[:self.calib_num]
            for line in lines:
                if not line.strip():
                    # Skip blank lines instead of failing to unpack them.
                    continue
                name, path = line.strip().split()
                name_list.append(name)
                wav_list.append(path)
        else:
            wav_list = [audio_in]
            name_list = ["test"]
        return wav_list, name_list

    def load_feats(self, audio_in: str = None):
        """Compute ``self.frontend`` features for every wav named by ``audio_in``.

        ``self.frontend`` must already be set (``export`` assigns it). Audio
        whose sampling rate differs from ``self.frontend.fs`` is resampled first.

        Returns:
            ``(feats, feats_len)``: per-utterance frontend outputs.
        """
        import torchaudio

        wav_list, _ = self.parse_audio_in(audio_in)
        feats = []
        feats_len = []
        for line in wav_list:
            path = line.strip()
            waveform, sampling_rate = torchaudio.load(path)
            if sampling_rate != self.frontend.fs:
                waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
                                                          new_freq=self.frontend.fs)(waveform)
            fbank, fbank_len = self.frontend(waveform, [waveform.size(1)])
            feats.append(fbank)
            feats_len.append(fbank_len)
        return feats, feats_len

    def export(self,
               mode: str = None,
               model_dir: str = None,
               ):
        """Build the model described by ``model_dir`` and export it to ``self.cache_dir``.

        Args:
            mode: model family; only values starting with ``"conformer"`` are supported.
            model_dir: directory holding ``config.yaml``, ``model.pb`` and ``am.mvn``.
                Defaults to ``mode``, because the ``__main__`` entry point passes
                the model name positionally.

        Raises:
            ValueError: if ``mode`` is not given.
            NotImplementedError: for modes other than ``conformer*``.
        """
        if mode is None:
            raise ValueError("export mode must be specified")
        if model_dir is None:
            # NOTE(review): model_dir was previously referenced without ever being
            # defined; falling back to `mode` matches how __main__ calls export().
            model_dir = mode
        if not mode.startswith('conformer'):
            raise NotImplementedError("export mode '{}' is not supported".format(mode))

        from funasr.tasks.asr import ASRTask
        config = os.path.join(model_dir, 'config.yaml')
        model_file = os.path.join(model_dir, 'model.pb')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        model, asr_train_args = ASRTask.build_model_from_file(
            config, model_file, cmvn_file, 'cpu'
        )
        self.frontend = model.frontend
        self.export_config["feats_dim"] = 560

        self._export(model, self.cache_dir)
| | | |
if __name__ == '__main__':
    import argparse

    arg_parser = argparse.ArgumentParser()
    # --model-name may be repeated to export several models in one run.
    arg_parser.add_argument('--model-name', type=str, action="append", required=True, default=[])
    arg_parser.add_argument('--export-dir', type=str, required=True)
    arg_parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
    arg_parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
    arg_parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
    arg_parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
    # Calibration audio options.
    arg_parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
    arg_parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
    arg_parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
    cli = arg_parser.parse_args()

    exporter = ModelExport(
        cache_dir=cli.export_dir,
        onnx=(cli.type == 'onnx'),
        device=cli.device,
        quant=cli.quantize,
        fallback_num=cli.fallback_num,
        audio_in=cli.audio_in,
        calib_num=cli.calib_num,
        model_revision=cli.model_revision,
    )
    for name in cli.model_name:
        print(f"export model: {name}")
        exporter.export(name)
| | |
| | | from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer |
| | | from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export |
| | | from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export |
| | | from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export |
| | | |
| | | from funasr.models.e2e_vad import E2EVadModel |
| | | from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export |
| | | from funasr.models.target_delay_transformer import TargetDelayTransformer |
| | |
| | | return BiCifParaformer_export(model, **export_config) |
| | | elif isinstance(model, Paraformer): |
| | | return Paraformer_export(model, **export_config) |
| | | elif isinstance(model, Conformer_export): |
| | | return Conformer_export(model, **export_config) |
| | | elif isinstance(model, E2EVadModel): |
| | | return E2EVadModel_export(model, **export_config) |
| | | elif isinstance(model, PunctuationModel): |
| New file |
| | |
| | | import os |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | from funasr.modules.attention import MultiHeadedAttention |
| | | |
| | | from funasr.export.models.modules.decoder_layer import DecoderLayer as OnnxDecoderLayer |
| | | from funasr.export.models.language_models.embed import Embedding |
| | | from funasr.export.models.modules.multihead_att import \ |
| | | OnnxMultiHeadedAttention |
| | | |
| | | from funasr.export.utils.torch_function import MakePadMask, subsequent_mask |
| | | |
class XformerDecoder(nn.Module):
    """Export wrapper for a Transformer-style decoder.

    Each decoder layer's ``MultiHeadedAttention`` (self- and source-attention)
    is replaced by ``OnnxMultiHeadedAttention``, and the layer is wrapped in
    ``OnnxDecoderLayer``; the passed decoder is modified in place.

    Args:
        model: the decoder to wrap.
        max_seq_len: maximum sequence length for the embedding and mask helpers.
        model_name: base name of the exported ``.onnx`` file.
        onnx: select the ONNX padding-mask helper (``MakePadMask``).
    """

    def __init__(self,
                 model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True,):
        super().__init__()
        self.embed = Embedding(model.embed, max_seq_len)
        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            # `subsequent_mask` is the causal-mask function used in forward(),
            # not a padding-mask module; use `sequence_mask`, as the Conformer
            # export wrapper does for the same switch.
            from funasr.export.utils.torch_function import sequence_mask
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        if isinstance(self.model.decoders[0].self_attn, MultiHeadedAttention):
            self.num_heads = self.model.decoders[0].self_attn.h
            self.hidden_size = self.model.decoders[0].self_attn.linear_out.out_features

        # Replace the multi-head attention modules with their exportable versions.
        for i, d in enumerate(self.model.decoders):
            # d is DecoderLayer
            if isinstance(d.self_attn, MultiHeadedAttention):
                d.self_attn = OnnxMultiHeadedAttention(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttention):
                d.src_attn = OnnxMultiHeadedAttention(d.src_attn)
            self.model.decoders[i] = OnnxDecoderLayer(d)

        self.model_name = model_name

    def prepare_mask(self, mask):
        """Return ``(mask_3d_btd, mask_4d_bhlt)`` derived from a 2-D or 3-D mask.

        The 4-D mask is additive: entries equal to 1 become 0 and entries equal
        to 0 become -10000.

        Raises:
            ValueError: if ``mask`` is neither 2-D nor 3-D.
        """
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        else:
            raise ValueError(
                "prepare_mask expects a 2-D or 3-D mask, got shape {}".format(tuple(mask.shape))
            )
        mask_3d_btd = mask[:, :, None]

        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt

    def forward(self,
                tgt,
                memory,
                cache):
        """Run one decoding step.

        Args:
            tgt: target token ids (the dummy input is a ``(1, 1)`` LongTensor).
            memory: encoder output.
            cache: one tensor per decoder layer.

        Returns:
            ``(y, new_cache)``: ``y`` is computed from the last target position
            (log-softmax over the output layer when the decoder has one), and
            ``new_cache`` holds each layer's output.
        """
        mask = subsequent_mask(tgt.size(-1)).unsqueeze(0)  # (B, T)

        x = self.embed(tgt)
        mask = self.prepare_mask(mask)
        new_cache = []
        for c, decoder in zip(cache, self.model.decoders):
            x, mask = decoder(x, mask, memory, None, c)
            new_cache.append(x)
            x = x[:, 1:, :]

        if self.model.normalize_before:
            y = self.model.after_norm(x[:, -1])
        else:
            y = x[:, -1]

        if self.model.output_layer is not None:
            y = torch.log_softmax(self.model.output_layer(y), dim=-1)
        return y, new_cache

    def get_dummy_inputs(self, enc_size):
        """Return ``(tgt, memory, cache)`` example inputs for ONNX tracing.

        NOTE(review): callers must supply ``enc_size`` (the last dimension of
        ``memory``); ``Conformer._export_model`` calls this without arguments.
        """
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        cache_num = len(self.model.decoders)
        cache = [
            torch.zeros((1, 1, self.model.decoders[0].size))
            for _ in range(cache_num)
        ]
        return (tgt, memory, cache)

    def is_optimizable(self):
        return True

    def get_input_names(self):
        """ONNX input names: ``tgt``, ``memory`` and one ``cache_<i>`` per layer."""
        cache_num = len(self.model.decoders)
        return ["tgt", "memory"] + [
            "cache_%d" % i for i in range(cache_num)
        ]

    def get_output_names(self):
        """ONNX output names: ``y`` and one ``out_cache_<i>`` per layer."""
        cache_num = len(self.model.decoders)
        return ["y"] + ["out_cache_%d" % i for i in range(cache_num)]

    def get_dynamic_axes(self):
        """Dynamic batch/length axes for the ONNX inputs."""
        ret = {
            "tgt": {0: "tgt_batch", 1: "tgt_length"},
            "memory": {0: "memory_batch", 1: "memory_length"},
        }
        cache_num = len(self.model.decoders)
        ret.update(
            {
                "cache_%d" % d: {0: "cache_%d_batch" % d, 2: "cache_%d_length" % d}
                for d in range(cache_num)
            }
        )
        return ret

    def get_model_config(self, path):
        """Describe the exported decoder (file path, layer count, output dim)."""
        return {
            "dec_type": "XformerDecoder",
            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
            "n_layers": len(self.model.decoders),
            "odim": self.model.decoders[0].size,
        }
| New file |
| | |
| | | import os |
| | | import logging |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | from funasr.export.utils.torch_function import MakePadMask |
| | | from funasr.export.utils.torch_function import sequence_mask |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder |
| | | from funasr.models.decoder.transformer_decoder import TransformerDecoder |
| | | from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export |
| | | from funasr.export.models.decoder.xformer_decoder import XformerDecoder as TransformerDecoder_export |
| | | |
class Conformer(nn.Module):
    """
    export conformer into onnx format

    Wraps ``model.encoder`` (``ConformerEncoder``) and ``model.decoder``
    (``TransformerDecoder``) in their export counterparts; ``_export_onnx``
    writes each one to ``<path>/<sub-model model_name>.onnx``.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='model',
        **kwargs,
    ):
        super().__init__()
        onnx = kwargs.get("onnx", False)
        # `_export_onnx` exports both sub-models, so each must be wrapped
        # independently; an `elif` here skipped the decoder whenever the
        # encoder matched.
        if isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        if isinstance(getattr(model, "decoder", None), TransformerDecoder):
            self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx)

        self.feats_dim = feats_dim
        self.model_name = model_name

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def _export_model(self, model, verbose, path):
        """Export one sub-model to ``<path>/<model.model_name>.onnx`` unless the file exists.

        NOTE(review): the decoder wrapper's ``get_dummy_inputs`` takes an
        ``enc_size`` argument that is not supplied here — confirm before
        exporting the decoder.
        """
        dummy_input = model.get_dummy_inputs()
        model_path = os.path.join(path, f'{model.model_name}.onnx')
        if not os.path.exists(model_path):
            torch.onnx.export(
                model,
                dummy_input,
                model_path,
                verbose=verbose,
                opset_version=14,
                input_names=model.get_input_names(),
                output_names=model.get_output_names(),
                dynamic_axes=model.get_dynamic_axes()
            )

    def _export_encoder_onnx(self, verbose, path):
        """Export the wrapped encoder."""
        self._export_model(self.encoder, verbose, path)

    def _export_decoder_onnx(self, verbose, path):
        """Export the wrapped decoder."""
        self._export_model(self.decoder, verbose, path)

    def _export_onnx(self, verbose, path):
        """Export the encoder, then the decoder, into ``path``."""
        self._export_encoder_onnx(verbose, path)
        self._export_decoder_onnx(verbose, path)
| New file |
| | |
| | | """Positional Encoding Module.""" |
| | | |
| | | import math |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | from funasr.modules.embedding import ( |
| | | LegacyRelPositionalEncoding, PositionalEncoding, RelPositionalEncoding, |
| | | ScaledPositionalEncoding, StreamPositionalEncoding) |
| | | from funasr.modules.subsampling import ( |
| | | Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, |
| | | Conv2dSubsampling8) |
| | | from funasr.modules.subsampling_without_posenc import \ |
| | | Conv2dSubsamplingWOPosEnc |
| | | |
| | | from funasr.export.models.language_models.subsampling import ( |
| | | OnnxConv2dSubsampling, OnnxConv2dSubsampling2, OnnxConv2dSubsampling6, |
| | | OnnxConv2dSubsampling8) |
| | | |
| | | |
def get_pos_emb(pos_emb, max_seq_len=512, use_cache=True):
    """Return the exportable counterpart of a positional-encoding module.

    Supported encodings are wrapped in their ``Onnx*`` class. An empty
    ``nn.Sequential`` or a ``Conv2dSubsamplingWOPosEnc`` is returned unchanged.
    Anything else raises ``ValueError``.
    """
    # Checked in this order; the more specific encodings precede
    # PositionalEncoding (presumably because they subclass it).
    dispatch = (
        (LegacyRelPositionalEncoding, OnnxLegacyRelPositionalEncoding),
        (ScaledPositionalEncoding, OnnxScaledPositionalEncoding),
        (RelPositionalEncoding, OnnxRelPositionalEncoding),
        (PositionalEncoding, OnnxPositionalEncoding),
        (StreamPositionalEncoding, OnnxStreamPositionalEncoding),
    )
    for source_cls, onnx_cls in dispatch:
        if isinstance(pos_emb, source_cls):
            return onnx_cls(pos_emb, max_seq_len, use_cache)

    passthrough = (
        isinstance(pos_emb, nn.Sequential) and len(pos_emb) == 0
    ) or isinstance(pos_emb, Conv2dSubsamplingWOPosEnc)
    if passthrough:
        return pos_emb
    raise ValueError("Embedding model is not supported.")
| | | |
| | | |
class Embedding(nn.Module):
    """Wrap an input embedding / subsampling module for ONNX export.

    ``nn.Embedding`` is used unchanged. ``Conv2dSubsampling*`` modules are
    replaced by their ``OnnxConv2dSubsampling*`` counterparts, and their final
    ``out[-1]`` module is swapped for the result of ``get_pos_emb``. For any
    other container, its last element is swapped instead.

    Args:
        model: module to wrap.
        max_seq_len: maximum sequence length for the positional encoding.
        use_cache: forwarded to ``get_pos_emb`` (it was previously ignored).
    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        super().__init__()
        self.model = model
        if isinstance(model, nn.Embedding):
            return

        subsampling_wrappers = (
            (Conv2dSubsampling, OnnxConv2dSubsampling),
            (Conv2dSubsampling2, OnnxConv2dSubsampling2),
            (Conv2dSubsampling6, OnnxConv2dSubsampling6),
            (Conv2dSubsampling8, OnnxConv2dSubsampling8),
        )
        for source_cls, onnx_cls in subsampling_wrappers:
            if isinstance(model, source_cls):
                self.model = onnx_cls(model)
                self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len, use_cache)
                return

        # Other containers are assumed to end with their positional encoding.
        self.model[-1] = get_pos_emb(model[-1], max_seq_len, use_cache)

    def forward(self, x, mask=None):
        """Apply the wrapped module, passing ``mask`` only when it is given."""
        if mask is None:
            return self.model(x)
        return self.model(x, mask)
| | | |
| | | |
def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Perform pre-hook in load_state_dict for backward compatibility.

    Note:
        We saved self.pe until v.0.5.2 but we have omitted it later.
        Therefore, we remove the item "pe" from `state_dict` for backward compatibility.

    """
    state_dict.pop(prefix + "pe", None)


class OnnxPositionalEncoding(torch.nn.Module):
    """Exportable sinusoidal positional encoding.

    Args:
        model: source encoding module. Its ``d_model`` and ``pe`` are reused,
            and its ``extend_pe`` is called when ``pe`` is not longer than
            ``max_seq_len``.
        max_seq_len (int): Maximum input length.
        reverse (bool): Whether to reverse the input position (used by the
            legacy relative encoding).
        use_cache (bool): slice the precomputed ``pe`` table (True) or compute
            the encoding on the fly (False).
    """

    def __init__(self, model, max_seq_len=512, reverse=False, use_cache=True):
        """Construct an OnnxPositionalEncoding object."""
        super(OnnxPositionalEncoding, self).__init__()
        self.d_model = model.d_model
        self.reverse = reverse
        self.max_seq_len = max_seq_len
        self.xscale = math.sqrt(self.d_model)
        self._register_load_state_dict_pre_hook(_pre_hook)
        self.pe = model.pe
        self.use_cache = use_cache
        self.model = model
        if self.use_cache:
            self.extend_pe()
        else:
            self.div_term = torch.exp(
                torch.arange(0, self.d_model, 2, dtype=torch.float32)
                * -(math.log(10000.0) / self.d_model)
            )

    def extend_pe(self):
        """Truncate ``pe`` to ``max_seq_len`` or ask the source module to extend it."""
        pe_length = self.pe.size(1)
        if self.max_seq_len < pe_length:
            self.pe = self.pe[:, : self.max_seq_len]
        else:
            self.model.extend_pe(torch.tensor(0.0).expand(1, self.max_seq_len))
            self.pe = self.model.pe

    def _sinusoid(self, x):
        """Sinusoidal table with the same shape, dtype and device as ``x``.

        Positions run ``0 .. T-1``, or ``T-1 .. 0`` when ``self.reverse``.
        """
        length = x.size(1)
        if self.reverse:
            position = torch.arange(length - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
        else:
            position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        pe = x.new_zeros(x.shape)
        pe[:, :, 0::2] = torch.sin(position * self.div_term)
        pe[:, :, 1::2] = torch.cos(position * self.div_term)
        return pe

    def _add_pe(self, x):
        """Scale ``x`` by sqrt(d_model) and add the on-the-fly encoding."""
        return x * self.xscale + self._sinusoid(x)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        if self.use_cache:
            return x * self.xscale + self.pe[:, : x.size(1)]
        return self._add_pe(x)


class OnnxScaledPositionalEncoding(OnnxPositionalEncoding):
    """Scaled positional encoding module.

    See Sec. 3.2 https://arxiv.org/abs/1809.08895

    Computes ``x + alpha * pe``: the input is not rescaled, and the learned
    ``alpha`` is copied from the source module (default 1.0 when it has none).
    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        """Initialize class."""
        super().__init__(model, max_seq_len, use_cache=use_cache)
        # Keep the trained scale instead of resetting it to 1.0.
        alpha = getattr(model, "alpha", None)
        if alpha is None:
            alpha = 1.0
        self.alpha = torch.nn.Parameter(torch.as_tensor(alpha).detach().clone())

    def reset_parameters(self):
        """Reset parameters."""
        self.alpha.data = torch.tensor(1.0)

    def _add_pe(self, x):
        """Add ``alpha`` times the on-the-fly encoding (matches the cached path)."""
        return x + self.alpha * self._sinusoid(x)

    def forward(self, x):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).

        """
        if self.use_cache:
            return x + self.alpha * self.pe[:, : x.size(1)]
        return self._add_pe(x)


class OnnxLegacyRelPositionalEncoding(OnnxPositionalEncoding):
    """Relative positional encoding module (old version).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    Positions are reversed (``reverse=True``); ``forward`` returns the scaled
    input and the positional embedding separately.
    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        """Initialize class."""
        super().__init__(model, max_seq_len, reverse=True, use_cache=use_cache)

    def _get_pe(self, x):
        """Compute the (reversed-position) sinusoidal table on the fly."""
        return self._sinusoid(x)

    def forward(self, x):
        """Compute positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
            torch.Tensor: Positional embedding tensor; (1, time, `*`) on the
                cached path, shaped like ``x`` otherwise.

        """
        x = x * self.xscale
        if self.use_cache:
            pos_emb = self.pe[:, : x.size(1)]
        else:
            pos_emb = self._get_pe(x)
        return x, pos_emb
| | | |
| | | |
class OnnxRelPositionalEncoding(torch.nn.Module):
    """Exportable relative positional encoding (new implementation).

    For an input of length ``T`` the returned table has ``2 * T - 1`` rows,
    covering relative offsets ``T - 1`` down to ``-(T - 1)``, as needed by the
    shifting trick of https://arxiv.org/abs/1901.02860 (Appendix B).
    See also https://github.com/espnet/espnet/pull/2816.

    Args:
        model: source module; only its ``d_model`` is read.
        max_seq_len (int): longest input the cached table must cover.
        use_cache (bool): precompute the table once (True) or build it per call (False).
    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        """Construct an OnnxRelPositionalEncoding object."""
        super(OnnxRelPositionalEncoding, self).__init__()
        self.d_model = model.d_model
        self.xscale = math.sqrt(self.d_model)
        self.pe = None
        self.use_cache = use_cache
        if not use_cache:
            self.div_term = torch.exp(
                torch.arange(0, self.d_model, 2, dtype=torch.float32)
                * -(math.log(10000.0) / self.d_model)
            )
        else:
            self.extend_pe(torch.tensor(0.0).expand(1, max_seq_len))

    def extend_pe(self, x):
        """(Re)build ``self.pe`` so it covers inputs as long as ``x.size(1)``."""
        length = x.size(1)
        long_enough = self.pe is not None and self.pe.size(1) >= 2 * length - 1
        if long_enough:
            # Existing table suffices; only align its dtype/device with x.
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return

        # With query position i and key position j, keys to the left (i > j)
        # use positive offsets and keys to the right (i < j) negative ones.
        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angle = position * div_term
        neg_angle = -1 * position * div_term

        positive = torch.zeros(length, self.d_model)
        negative = torch.zeros(length, self.d_model)
        positive[:, 0::2] = torch.sin(angle)
        positive[:, 1::2] = torch.cos(angle)
        negative[:, 0::2] = torch.sin(neg_angle)
        negative[:, 1::2] = torch.cos(neg_angle)

        # Reverse the positive half and drop the duplicated offset 0 from the
        # negative half before concatenating.
        table = torch.cat(
            [torch.flip(positive, [0]).unsqueeze(0), negative[1:].unsqueeze(0)],
            dim=1,
        )
        self.pe = table.to(device=x.device, dtype=x.dtype)

    def _get_pe(self, x):
        """Build the ``(1, 2T - 1, d_model)`` relative table for ``x`` on the fly."""
        length = x.size(1)
        theta = torch.arange(0, length, dtype=torch.float32).unsqueeze(1) * self.div_term
        sin_theta = torch.sin(theta)
        cos_theta = torch.cos(theta)

        positive = torch.zeros(length, self.d_model)
        negative = torch.zeros(length, self.d_model)
        positive[:, 0::2] = sin_theta
        positive[:, 1::2] = cos_theta
        negative[:, 0::2] = -1 * sin_theta
        negative[:, 1::2] = cos_theta

        return torch.cat(
            [torch.flip(positive, [0]).unsqueeze(0), negative[1:].unsqueeze(0)],
            dim=1,
        )

    def forward(self, x: torch.Tensor, use_cache=True):
        """Scale ``x`` by sqrt(d_model) and return it with its relative table.

        ``use_cache`` is accepted for interface compatibility; the instance's
        ``self.use_cache`` decides which path is taken.
        """
        scaled = x * self.xscale
        if not self.use_cache:
            return scaled, self._get_pe(scaled)
        center = self.pe.size(1) // 2
        length = scaled.size(1)
        return scaled, self.pe[:, center - length + 1 : center + length]
| | | |
| | | |
class OnnxStreamPositionalEncoding(torch.nn.Module):
    """Streaming (absolute) positional encoding for ONNX export.

    Copies ``d_model``, ``xscale`` and the ``pe`` table from an existing
    positional-encoding module. ``forward`` returns ``x * xscale`` plus the
    sinusoidal encodings for absolute positions
    ``start_idx .. start_idx + time - 1``.

    Args:
        model (torch.nn.Module): Source module providing ``d_model``,
            ``xscale`` and ``pe`` (indexed as ``pe[:, position]``).
        max_seq_len (int): Number of positions the cached table is
            truncated or extended to.
        use_cache (bool): If True, read encodings from the cached ``pe``
            table; otherwise compute them on the fly in ``_add_pe``.
    """

    def __init__(self, model, max_seq_len=5000, use_cache=True):
        """Construct an OnnxStreamPositionalEncoding object."""
        super().__init__()
        self.d_model = model.d_model
        self.xscale = model.xscale
        self.pe = model.pe
        self.use_cache = use_cache
        self.max_seq_len = max_seq_len
        if self.use_cache:
            self.extend_pe()
        else:
            self.div_term = self._make_div_term(self.d_model)
        # NOTE(review): ``_pre_hook`` must be defined at module level in this
        # file; it is not visible in this excerpt.
        self._register_load_state_dict_pre_hook(_pre_hook)

    @staticmethod
    def _make_div_term(d_model):
        """Return the per-channel-pair frequencies of the sinusoid table."""
        return torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / d_model)
        )

    def extend_pe(self):
        """Resize the cached ``pe`` table to ``max_seq_len`` positions.

        A table that is at least that long is truncated. A shorter one is
        rebuilt here using the same formula as ``_add_pe``: sine on even
        channels and cosine on odd channels. The rebuilt table keeps the
        current table's device and dtype.
        """
        if self.max_seq_len <= self.pe.size(1):
            self.pe = self.pe[:, : self.max_seq_len]
            return
        position = torch.arange(0, self.max_seq_len, dtype=torch.float32).unsqueeze(1)
        angle = position * self._make_div_term(self.d_model)
        pe = torch.zeros(self.max_seq_len, self.d_model)
        pe[:, 0::2] = torch.sin(angle)
        pe[:, 1::2] = torch.cos(angle)
        self.pe = pe.unsqueeze(0).to(device=self.pe.device, dtype=self.pe.dtype)

    def _add_pe(self, x, start_idx):
        """Return ``x * xscale`` plus encodings computed on the fly.

        Positions run from ``start_idx`` to ``start_idx + x.size(1) - 1`` so
        the result has the same length as ``x`` for any ``start_idx``.
        """
        position = torch.arange(
            start_idx, start_idx + x.size(1), dtype=torch.float32, device=x.device
        ).unsqueeze(1)
        angle = position * self.div_term.to(x.device)
        x = x * self.xscale
        x[:, :, 0::2] += torch.sin(angle)
        x[:, :, 1::2] += torch.cos(angle)
        return x

    def forward(self, x: torch.Tensor, start_idx: int = 0):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            start_idx (int): Absolute position of the first frame of ``x``.

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).

        """
        if self.use_cache:
            return x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
        return self._add_pe(x, start_idx)
| New file |
| | |
| | | import os |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
class SequentialRNNLM(nn.Module):
    """ONNX-export wrapper around a sequential RNN language model.

    Borrows ``encoder`` (applied to the token ids), ``rnn`` and ``decoder``
    plus the ``rnn_type`` / ``nlayers`` / ``nhid`` settings from ``model``.
    When ``rnn_type == "LSTM"`` the recurrent state travels as two tensors
    (``hidden1``, ``hidden2``); for any other type a single state tensor is
    used.
    """

    def __init__(self, model, **kwargs):
        super().__init__()
        self.model_name = "seq_rnnlm"
        self.rnn_type = model.rnn_type
        self.nlayers = model.nlayers
        self.nhid = model.nhid
        # Sub-modules are registered in the same order as before so the
        # state_dict key order does not change.
        self.encoder = model.encoder
        self.rnn = model.rnn
        self.decoder = model.decoder

    def forward(self, y, hidden1, hidden2=None):
        """Score a batch of token ids (batch_score).

        Returns:
            ``(logits, hidden1, hidden2)`` for an LSTM, otherwise
            ``(logits, hidden1)``. ``logits`` has shape
            (rnn_out.size(0), rnn_out.size(1), decoder output size).
        """
        is_lstm = self.rnn_type == "LSTM"
        embedded = self.encoder(y)
        if is_lstm:
            rnn_out, (hidden1, hidden2) = self.rnn(embedded, (hidden1, hidden2))
        else:
            rnn_out, hidden1 = self.rnn(embedded, hidden1)

        dim0, dim1, dim2 = rnn_out.size(0), rnn_out.size(1), rnn_out.size(2)
        flat_logits = self.decoder(rnn_out.contiguous().view(dim0 * dim1, dim2))
        logits = flat_logits.view(dim0, dim1, flat_logits.size(1))
        return (logits, hidden1, hidden2) if is_lstm else (logits, hidden1)

    def get_dummy_inputs(self):
        """Return example inputs for tracing: token ids plus state tensor(s)."""
        tgt = torch.LongTensor([0, 1]).unsqueeze(0)
        hidden = torch.randn(self.nlayers, 1, self.nhid)
        return (tgt, hidden, hidden) if self.rnn_type == "LSTM" else (tgt, hidden)

    def get_input_names(self):
        """ONNX input names, matching ``get_dummy_inputs``."""
        names = ["x", "in_hidden1"]
        if self.rnn_type == "LSTM":
            names.append("in_hidden2")
        return names

    def get_output_names(self):
        """ONNX output names, matching ``forward``'s return tuple."""
        names = ["y", "out_hidden1"]
        if self.rnn_type == "LSTM":
            names.append("out_hidden2")
        return names

    def get_dynamic_axes(self):
        """ONNX dynamic-axes mapping for the inputs and outputs."""
        axes = {
            "x": {0: "x_batch", 1: "x_length"},
            "y": {0: "y_batch"},
            "in_hidden1": {1: "hidden1_batch"},
            "out_hidden1": {1: "out_hidden1_batch"},
        }
        if self.rnn_type == "LSTM":
            axes["in_hidden2"] = {1: "hidden2_batch"}
            axes["out_hidden2"] = {1: "out_hidden2_batch"}
        return axes

    def get_model_config(self, path):
        """Runtime config describing the exported model stored under ``path``."""
        model_path = os.path.join(path, f"{self.model_name}.onnx")
        return {
            "use_lm": True,
            "model_path": model_path,
            "lm_type": "SequentialRNNLM",
            "rnn_type": self.rnn_type,
            "nhid": self.nhid,
            "nlayers": self.nlayers,
        }
| New file |
| | |
| | | """Subsampling layer definition.""" |
| | | |
| | | import torch |
| | | |
| | | |
class OnnxConv2dSubsampling(torch.nn.Module):
    """ONNX-export wrapper for a Conv2d subsampling layer (to ~1/4 length).

    Reuses the ``conv`` and ``out`` sub-modules of an existing subsampling
    module; no new parameters are created.

    Args:
        model (torch.nn.Module): Source module providing ``conv`` (applied
            to the (b, 1, t, f) input) and ``out`` (applied to the flattened
            conv output).

    """

    def __init__(self, model):
        """Wrap ``model.conv`` and ``model.out``."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor or None): Input mask whose LAST axis is
                time, e.g. (#batch, time) or (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim), where
                time' is approximately time // 4 (exact value set by
                ``model.conv``).
            torch.Tensor or None: Mask with the same leading axes as
                ``x_mask`` and time' on the last axis; None if ``x_mask`` is
                None.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # Slice the last (time) axis so 2-D and 3-D masks both work.
        # NOTE(review): offsets/strides presumably mirror two kernel-3,
        # stride-2 convs in model.conv; confirm against the source module.
        return x, x_mask[..., :-2:2][..., :-2:2]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]
| | | |
| | | |
class OnnxConv2dSubsampling2(torch.nn.Module):
    """ONNX-export wrapper for a Conv2d subsampling layer (to ~1/2 length).

    Reuses the ``conv`` and ``out`` sub-modules of an existing subsampling
    module; no new parameters are created.

    Args:
        model (torch.nn.Module): Source module providing ``conv`` (applied
            to the (b, 1, t, f) input) and ``out`` (applied to the flattened
            conv output).

    """

    def __init__(self, model):
        """Wrap ``model.conv`` and ``model.out``."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor or None): Input mask whose LAST axis is
                time, e.g. (#batch, time) or (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim), where
                time' is approximately time // 2 (exact value set by
                ``model.conv``).
            torch.Tensor or None: Mask with the same leading axes as
                ``x_mask`` and time' on the last axis; None if ``x_mask`` is
                None.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # Slice the last (time) axis so 2-D and 3-D masks both work.
        # NOTE(review): offsets/strides presumably mirror a kernel-3 stride-2
        # conv followed by a kernel-3 stride-1 conv; confirm against model.conv.
        return x, x_mask[..., :-2:2][..., :-2:1]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]
| | | |
| | | |
class OnnxConv2dSubsampling6(torch.nn.Module):
    """ONNX-export wrapper for a Conv2d subsampling layer (to ~1/6 length).

    Reuses the ``conv`` and ``out`` sub-modules of an existing subsampling
    module; no new parameters are created.

    Args:
        model (torch.nn.Module): Source module providing ``conv`` (applied
            to the (b, 1, t, f) input) and ``out`` (applied to the flattened
            conv output).

    """

    def __init__(self, model):
        """Wrap ``model.conv`` and ``model.out``."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor or None): Input mask whose LAST axis is
                time, e.g. (#batch, time) or (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim), where
                time' is approximately time // 6 (exact value set by
                ``model.conv``).
            torch.Tensor or None: Mask with the same leading axes as
                ``x_mask`` and time' on the last axis; None if ``x_mask`` is
                None.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # Slice the last (time) axis so 2-D and 3-D masks both work.
        # NOTE(review): offsets/strides presumably mirror a kernel-3 stride-2
        # conv followed by a kernel-5 stride-3 conv; confirm against model.conv.
        return x, x_mask[..., :-2:2][..., :-4:3]
| | | |
| | | |
class OnnxConv2dSubsampling8(torch.nn.Module):
    """ONNX-export wrapper for a Conv2d subsampling layer (to ~1/8 length).

    Reuses the ``conv`` and ``out`` sub-modules of an existing subsampling
    module; no new parameters are created.

    Args:
        model (torch.nn.Module): Source module providing ``conv`` (applied
            to the (b, 1, t, f) input) and ``out`` (applied to the flattened
            conv output).

    """

    def __init__(self, model):
        """Wrap ``model.conv`` and ``model.out``."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor or None): Input mask whose LAST axis is
                time, e.g. (#batch, time) or (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim), where
                time' is approximately time // 8 (exact value set by
                ``model.conv``).
            torch.Tensor or None: Mask with the same leading axes as
                ``x_mask`` and time' on the last axis; None if ``x_mask`` is
                None.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        # Slice the last (time) axis so 2-D and 3-D masks both work.
        # NOTE(review): offsets/strides presumably mirror three kernel-3,
        # stride-2 convs in model.conv; confirm against the source module.
        return x, x_mask[..., :-2:2][..., :-2:2][..., :-2:2]
| New file |
| | |
import os

import torch
import torch.nn as nn
from funasr.modules.vgg2l import VGG2L
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.subsampling import (
    Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)

from funasr.export.models.modules.encoder_layer import EncoderLayerConformer as OnnxEncoderLayer
from funasr.export.models.language_models.embed import Embedding
from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention

from funasr.export.utils.torch_function import MakePadMask


# NOTE(review): this class previously also listed ``AbsExportModel`` as a base,
# but that name is never imported in this module (NameError at import time).
# It now derives from nn.Module only, like SequentialRNNLM.
class TransformerLM(nn.Module):
    """ONNX-export wrapper for a Transformer language model.

    Reuses ``embed`` (wrapped in the export ``Embedding``), ``encoder`` and
    ``decoder`` from ``model``. Each encoder layer is wrapped in the export
    encoder layer; layers using ``MultiHeadedAttention`` first have it swapped
    for ``OnnxMultiHeadedAttention``.

    Args:
        model: Source language model providing ``embed``, ``encoder`` and
            ``decoder``.
        max_seq_len (int): Maximum sequence length for the embedding and the
            padding-mask helper.
    """

    def __init__(self, model, max_seq_len=512, **kwargs):
        super().__init__()
        self.embed = Embedding(model.embed, max_seq_len)
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        # Replace attention / encoder layers with export-friendly versions.
        for i, d in enumerate(self.encoder.encoders):
            # d is an EncoderLayer
            if isinstance(d.self_attn, MultiHeadedAttention):
                d.self_attn = OnnxMultiHeadedAttention(d.self_attn)
            self.encoder.encoders[i] = OnnxEncoderLayer(d)

        self.model_name = "transformer_lm"
        self.num_heads = self.encoder.encoders[0].self_attn.h
        self.hidden_size = self.encoder.encoders[0].self_attn.linear_out.out_features

    def prepare_mask(self, mask):
        """Turn a {0, 1} keep-mask into an additive attention bias.

        2-D masks become (B, 1, 1, T); 3-D masks get an axis inserted at
        position 1. Kept positions (1) map to 0 and masked ones (0) to -10000.
        """
        if len(mask.shape) == 2:
            mask = mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask = mask[:, None, :]
        mask = 1 - mask
        return mask * -10000.0

    def forward(self, y, cache):
        """Run one decoding step.

        Args:
            y: Token ids; positions equal to 0 are masked out.
                (Presumably (batch, length); TODO confirm.)
            cache: Sequence with one cached state per encoder layer.

        Returns:
            Tuple of the decoder output for the last position and the list of
            per-layer outputs to use as the next cache.
        """
        feats_length = torch.ones(y.shape).sum(dim=-1).type(torch.long)
        mask = self.make_pad_mask(feats_length)  # (B, T)
        mask = (y != 0) * mask

        xs = self.embed(y)
        # forward_one_step of Encoder
        if isinstance(
            self.encoder.embed,
            (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8, VGG2L),
        ):
            xs, mask = self.encoder.embed(xs, mask)
        else:
            xs = self.encoder.embed(xs)

        new_cache = []
        mask = self.prepare_mask(mask)
        for c, e in zip(cache, self.encoder.encoders):
            xs, mask = e(xs, mask, c)
            new_cache.append(xs)

        if self.encoder.normalize_before:
            xs = self.encoder.after_norm(xs)

        h = self.decoder(xs[:, -1])
        return h, new_cache

    def get_dummy_inputs(self):
        """Return a single-token input and zero caches, one per layer."""
        tgt = torch.LongTensor([1]).unsqueeze(0)
        cache = [
            torch.zeros((1, 1, self.encoder.encoders[0].size))
            for _ in range(len(self.encoder.encoders))
        ]
        return (tgt, cache)

    def is_optimizable(self):
        return True

    def get_input_names(self):
        """ONNX input names: ``tgt`` followed by one cache per layer."""
        return ["tgt"] + ["cache_%d" % i for i in range(len(self.encoder.encoders))]

    def get_output_names(self):
        """ONNX output names: ``y`` followed by one output cache per layer."""
        return ["y"] + ["out_cache_%d" % i for i in range(len(self.encoder.encoders))]

    def get_dynamic_axes(self):
        """ONNX dynamic-axes mapping for ``tgt`` and every cache tensor."""
        ret = {"tgt": {0: "tgt_batch", 1: "tgt_length"}}
        ret.update(
            {
                "cache_%d" % d: {0: "cache_%d_batch" % d, 1: "cache_%d_length" % d}
                for d in range(len(self.encoder.encoders))
            }
        )
        ret.update(
            {
                "out_cache_%d"
                % d: {0: "out_cache_%d_batch" % d, 1: "out_cache_%d_length" % d}
                for d in range(len(self.encoder.encoders))
            }
        )
        return ret

    def get_model_config(self, path):
        """Runtime config describing the exported model stored under ``path``."""
        return {
            "use_lm": True,
            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
            "lm_type": "TransformerLM",
            "odim": self.encoder.encoders[0].size,
            "nlayers": len(self.encoder.encoders),
        }
| New file |
| | |
| | | """VGG2L module definition for custom encoder.""" |
| | | |
| | | from typing import Tuple, Union |
| | | |
| | | import torch |
| | | |
| | | |
class VGG2L(torch.nn.Module):
    """VGG-style front end for a custom encoder.

    Two convolution blocks are followed by a linear projection to ``odim``.
    Each block is two 3x3 convs with ReLU, then max pooling: (3, 2) for the
    first block and (2, 2) for the second. The projection is optionally
    chained with a positional-encoding module.

    Args:
        idim: Input feature dimension.
        odim: Output dimension.
        pos_enc: Optional positional-encoding module applied after the
            projection.

    """

    def __init__(self, idim: int, odim: int, pos_enc: torch.nn.Module = None):
        """Build the convolution stack and the output projection."""
        super().__init__()

        def conv_block(c_in, c_out, pool):
            # conv -> ReLU -> conv -> ReLU -> max-pool
            return [
                torch.nn.Conv2d(c_in, c_out, 3, stride=1, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(c_out, c_out, 3, stride=1, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(pool),
            ]

        # Layer order (and therefore state_dict indices) is fixed here.
        self.vgg2l = torch.nn.Sequential(
            *conv_block(1, 64, (3, 2)), *conv_block(64, 128, (2, 2))
        )

        # Both pools halve the feature axis: idim -> (idim // 2) // 2.
        projection = torch.nn.Linear(128 * ((idim // 2) // 2), odim)
        self.output = (
            projection if pos_enc is None else torch.nn.Sequential(projection, pos_enc)
        )

    def forward(
        self, feats: torch.Tensor, feats_mask: torch.Tensor
    ) -> Union[
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor],
    ]:
        """Run the VGG2L front end.

        Args:
            feats: Feature sequences. (B, F, D_feats)
            feats_mask: Mask of feature sequences (B, 1, F), or None.

        Returns:
            vgg_output: Projected sequences, (B, sub(F), D_out) or, if the
                positional encoding returns a pair,
                ((B, sub(F), D_out), (B, sub(F), D_att)).
            vgg_mask: Subsampled mask (B, 1, sub(F)), or None.

        """
        conv_out = self.vgg2l(feats.unsqueeze(1))
        batch, channels, frames, freq = conv_out.size()
        flat = conv_out.transpose(1, 2).contiguous().view(batch, frames, channels * freq)
        projected = self.output(flat)

        new_mask = feats_mask if feats_mask is None else self.create_new_mask(feats_mask)
        return projected, new_mask

    def create_new_mask(self, feats_mask: torch.Tensor) -> torch.Tensor:
        """Subsample a feature mask the way the two pooling layers shrink time.

        For each time-pooling factor (3, then 2), trailing frames that do not
        fill a full window are dropped and every factor-th frame is kept.

        Args:
            feats_mask: Mask of feature sequences. (B, 1, F)

        Returns:
            Mask of VGG2L output sequences. (B, 1, sub(F))

        """
        mask = feats_mask
        for factor in (3, 2):
            usable = mask.size(2) - (mask.size(2) % factor)
            mask = mask[:, :, :usable][:, :, ::factor]
        return mask
| | |
| | | name="specaug", |
| | | classes=dict( |
| | | specaug=SpecAug, |
| | | specaug_lfr=SpecAugLFR, |
| | | specaug_lfr=FSpecAugLR, |
| | | ), |
| | | type_check=AbsSpecAug, |
| | | default=None, |