From 76fd90d23073f7d41250c50d7c92b423604c6ac1 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期一, 13 三月 2023 16:09:41 +0800
Subject: [PATCH] add class TimestampPredictor in e2e
---
funasr/models/e2e_tp.py | 154 ++++++++++++++++++++++++++++++++++++++
funasr/tasks/asr.py | 81 ++++++++++++++++++++
funasr/bin/build_trainer.py | 4
3 files changed, 238 insertions(+), 1 deletions(-)
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
index 8dee758..94f7262 100644
--- a/funasr/bin/build_trainer.py
+++ b/funasr/bin/build_trainer.py
@@ -28,7 +28,9 @@
elif mode == "uniasr":
from funasr.tasks.asr import ASRTaskUniASR as ASRTask
elif mode == "mfcca":
- from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
+ from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
+ elif mode == "tp":
+ from funasr.tasks.asr import ASRTaskAligner as ASRTask
else:
raise ValueError("Unknown mode: {}".format(mode))
parser = ASRTask.get_parser()
diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py
new file mode 100644
index 0000000..8808008
--- /dev/null
+++ b/funasr/models/e2e_tp.py
@@ -0,0 +1,154 @@
+import logging
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+import numpy as np
+from typeguard import check_argument_types
+
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.predictor.cif import mae_loss
+from funasr.modules.add_sos_eos import add_sos_eos
+from funasr.modules.nets_utils import make_pad_mask, pad_list
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.predictor.cif import CifPredictorV3
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
+
+class TimestampPredictor(AbsESPnetModel):
+ """
+ Author: Speech Lab, Alibaba Group, China
+ """
+
+ def __init__(
+ self,
+ frontend: Optional[AbsFrontend],
+ encoder: AbsEncoder,
+ predictor: CifPredictorV3,
+ predictor_bias: int = 0,
+ ):
+ assert check_argument_types()
+
+ super().__init__()
+ # note that eos is the same as sos (equivalent ID)
+
+ self.frontend = frontend
+ self.encoder = encoder
+ self.encoder.interctc_use_conditioning = False
+
+ self.predictor = predictor
+ self.predictor_bias = predictor_bias
+ self.criterion_pre = mae_loss()
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Decoder + Calc loss
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ assert text_lengths.dim() == 1, text_lengths.shape
+ # Check that batch_size is unified
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+ batch_size = speech.shape[0]
+ # for data-parallel
+ text = text[:, : text_lengths.max()]
+ speech = speech[:, :speech_lengths.max()]
+
+ # 1. Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, text = add_sos_eos(text, 1, 2, -1)
+ text_lengths = text_lengths + self.predictor_bias
+ _, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
+
+ # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+ loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
+
+ loss = loss_pre
+ stats = dict()
+
+ # Collect Attn branch stats
+ stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ """
+ with autocast(False):
+ # 1. Extract feats
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+ # 4. Forward encoder
+ # feats: (Batch, Length, Dim)
+ # -> encoder_out: (Batch, Length2, Dim2)
+ encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
+
+ return encoder_out, encoder_out_lens
+
+ def _extract_feats(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ assert speech_lengths.dim() == 1, speech_lengths.shape
+
+ # for data-parallel
+ speech = speech[:, : speech_lengths.max()]
+ if self.frontend is not None:
+ # Frontend
+ # e.g. STFT and Feature extract
+ # data_loader may send time-domain signal in this case
+ # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
+ feats, feats_lengths = self.frontend(speech, speech_lengths)
+ else:
+ # No frontend and no feature extract
+ feats, feats_lengths = speech, speech_lengths
+ return feats, feats_lengths
+
+ def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out,
+ encoder_out_mask,
+ token_num)
+ return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index bc89744..13898a6 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -40,6 +40,7 @@
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.abs_encoder import AbsEncoder
@@ -124,6 +125,7 @@
bicif_paraformer=BiCifParaformer,
contextual_paraformer=ContextualParaformer,
mfcca=MFCCA,
+ timestamp_predictor=TimestampPredictor,
),
type_check=AbsESPnetModel,
default="asr",
@@ -1245,6 +1247,85 @@
class ASRTaskAligner(ASRTaskParaformer):
+ # If you need more than one optimizers, change this value
+ num_optimizers: int = 1
+
+ # Add variable objects configurations
+ class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --model and --model_conf
+ model_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --decoder and --decoder_conf
+ decoder_choices,
+ ]
+
+ # If you need to modify train() or eval() procedures, change Trainer class here
+ trainer = Trainer
+
+ @classmethod
+ def build_model(cls, args: argparse.Namespace):
+ assert check_argument_types()
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+
+ # Overwriting token_list to keep it as "portable".
+ args.token_list = list(token_list)
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
+ else:
+ raise RuntimeError("token_list must be str or list")
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+
+ # 1. frontend
+ if args.input_size is None:
+ # Extract features in the model
+ frontend_class = frontend_choices.get_class(args.frontend)
+ if args.frontend == 'wav_frontend':
+ frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+ else:
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ # Give features from data-loader
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # 2. Encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+ # 3. Predictor
+ predictor_class = predictor_choices.get_class(args.predictor)
+ predictor = predictor_class(**args.predictor_conf)
+
+ # 10. Build model
+ try:
+ model_class = model_choices.get_class(args.model)
+ except AttributeError:
+ model_class = model_choices.get_class("asr")
+
+ # 8. Build model
+ model = model_class(
+ frontend=frontend,
+ encoder=encoder,
+ predictor=predictor,
+ **args.model_conf,
+ )
+
+ # 11. Initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ assert check_return_type(model)
+ return model
+
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
--
Gitblit v1.9.1