From 276c443ba4b6574df682acd330a18a3830c9ef2c Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期五, 05 一月 2024 15:47:06 +0800
Subject: [PATCH] update monotonic aligner
---
/dev/null | 170 ------------------
funasr/models/monotonic_aligner/template.yaml | 115 ++++++++++++
examples/industrial_data_pretraining/monotonic_aligner/README_zh.md | 0
funasr/models/seaco_paraformer/template.yaml | 7
funasr/models/monotonic_aligner/__init__.py | 0
examples/industrial_data_pretraining/monotonic_aligner/demo.py | 0
funasr/models/monotonic_aligner/model.py | 202 ++++++++++++++++++++++
examples/industrial_data_pretraining/monotonic_aligner/infer.sh | 15 +
8 files changed, 338 insertions(+), 171 deletions(-)
diff --git a/examples/industrial_data_pretraining/tp_aligner/README_zh.md b/examples/industrial_data_pretraining/monotonic_aligner/README_zh.md
similarity index 100%
rename from examples/industrial_data_pretraining/tp_aligner/README_zh.md
rename to examples/industrial_data_pretraining/monotonic_aligner/README_zh.md
diff --git a/examples/industrial_data_pretraining/tp_aligner/demo.py b/examples/industrial_data_pretraining/monotonic_aligner/demo.py
similarity index 100%
rename from examples/industrial_data_pretraining/tp_aligner/demo.py
rename to examples/industrial_data_pretraining/monotonic_aligner/demo.py
diff --git a/examples/industrial_data_pretraining/monotonic_aligner/infer.sh b/examples/industrial_data_pretraining/monotonic_aligner/infer.sh
new file mode 100644
index 0000000..dc9d823
--- /dev/null
+++ b/examples/industrial_data_pretraining/monotonic_aligner/infer.sh
@@ -0,0 +1,15 @@
+
+# download model
+local_path_root=../modelscope_models
+mkdir -p ${local_path_root}
+local_path=${local_path_root}/speech_timestamp_prediction-v1-16k-offline
+# git clone https://www.modelscope.cn/damo/speech_timestamp_prediction-v1-16k-offline.git ${local_path}
+
+
+python funasr/bin/inference.py \
++model="${local_path}" \
++input='["/Users/shixian/code/modelscope_models/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav", "娆㈣繋澶у鏉ュ埌榄旀惌绀惧尯杩涜浣撻獙"]' \
++data_type='["sound", "text"]' \
++output_dir="../outputs/debug" \
++device="cpu" \
++batch_size=2
diff --git a/examples/industrial_data_pretraining/tp_aligner/infer.sh b/examples/industrial_data_pretraining/tp_aligner/infer.sh
deleted file mode 100644
index ded296f..0000000
--- a/examples/industrial_data_pretraining/tp_aligner/infer.sh
+++ /dev/null
@@ -1,17 +0,0 @@
-
-# download model
-local_path_root=../modelscope_models
-mkdir -p ${local_path_root}
-local_path=${local_path_root}/speech_timestamp_prediction-v1-16k-offline
-git clone https://www.modelscope.cn/damo/speech_timestamp_prediction-v1-16k-offline.git ${local_path}
-
-
-python funasr/bin/inference.py \
-+model="${local_path}" \
-+input='["/Users/zhifu/funasr_github/test_local/wav.scp", "/Users/zhifu/funasr_github/test_local/text.txt"]' \
-+data_type='["sound", "text"]' \
-+output_dir="./outputs/debug" \
-+device="cpu" \
-+batch_size=2 \
-+debug="true"
-
diff --git a/funasr/models/tp_aligner/__init__.py b/funasr/models/monotonic_aligner/__init__.py
similarity index 100%
rename from funasr/models/tp_aligner/__init__.py
rename to funasr/models/monotonic_aligner/__init__.py
diff --git a/funasr/models/monotonic_aligner/model.py b/funasr/models/monotonic_aligner/model.py
new file mode 100644
index 0000000..c0ac136
--- /dev/null
+++ b/funasr/models/monotonic_aligner/model.py
@@ -0,0 +1,202 @@
+import time
+import copy
+import torch
+from torch.cuda.amp import autocast
+from typing import Union, Dict, List, Tuple, Optional
+
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank, load_audio_and_text_image_video
+
+
+
+@tables.register("model_classes", "monotonicaligner")
+class MonotonicAligner(torch.nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Achieving timestamp prediction while recognizing with non-autoregressive end-to-end ASR model
+ https://arxiv.org/abs/2301.12343
+ """
+
+ def __init__(
+ self,
+ input_size: int = 80,
+ specaug: Optional[str] = None,
+ specaug_conf: Optional[Dict] = None,
+ normalize: str = None,
+ normalize_conf: Optional[Dict] = None,
+ encoder: str = None,
+ encoder_conf: Optional[Dict] = None,
+ predictor: str = None,
+ predictor_conf: Optional[Dict] = None,
+ predictor_bias: int = 0,
+ length_normalized_loss: bool = False,
+ **kwargs,
+ ):
+
+ super().__init__()
+
+ if specaug is not None:
+ specaug_class = tables.specaug_classes.get(specaug.lower())
+ specaug = specaug_class(**specaug_conf)
+ if normalize is not None:
+ normalize_class = tables.normalize_classes.get(normalize.lower())
+ normalize = normalize_class(**normalize_conf)
+ encoder_class = tables.encoder_classes.get(encoder.lower())
+ encoder = encoder_class(input_size=input_size, **encoder_conf)
+ encoder_output_size = encoder.output_size()
+ predictor_class = tables.predictor_classes.get(predictor.lower())
+ predictor = predictor_class(**predictor_conf)
+ self.specaug = specaug
+ self.normalize = normalize
+ self.encoder = encoder
+ self.predictor = predictor
+ self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
+ self.predictor_bias = predictor_bias
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ assert text_lengths.dim() == 1, text_lengths.shape
+ # Check that batch_size is unified
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+ batch_size = speech.shape[0]
+ # for data-parallel
+ text = text[:, : text_lengths.max()]
+ speech = speech[:, :speech_lengths.max()]
+
+ # 1. Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, text = add_sos_eos(text, 1, 2, -1)
+ text_lengths = text_lengths + self.predictor_bias
+ _, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
+
+ # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+ loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
+
+ loss = loss_pre
+ stats = dict()
+
+ # Collect Attn branch stats
+ stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
+ encoder_out_mask,
+ token_num)
+ return ds_alphas, ds_cif_peak, us_alphas, us_peaks
+
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ ind: int
+ """
+ with autocast(False):
+
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+ # Forward encoder
+ encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ return encoder_out, encoder_out_lens
+
+ def generate(self,
+ data_in,
+ data_lengths=None,
+ key: list=None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+
+ meta_data = {}
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_list, text_token_int_list = load_audio_and_text_image_video(data_in,
+ fs=frontend.fs,
+ audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer)
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+ speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ # predictor
+ text_lengths = torch.tensor([len(i)+1 for i in text_token_int_list]).to(encoder_out.device)
+ _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens, token_num=text_lengths)
+
+ results = []
+ ibest_writer = None
+ if ibest_writer is None and kwargs.get("output_dir") is not None:
+ writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = writer["tp_res"]
+ for i, (us_alpha, us_peak, token_int) in enumerate(zip(us_alphas, us_peaks, text_token_int_list)):
+ token = tokenizer.ids2tokens(token_int)
+ timestamp_str, timestamp = ts_prediction_lfr6_standard(us_alpha[:encoder_out_lens[i] * 3],
+ us_peak[:encoder_out_lens[i] * 3],
+ copy.copy(token))
+ text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
+ token, timestamp)
+ result_i = {"key": key[i], "text": text_postprocessed,
+ "timestamp": time_stamp_postprocessed,
+ }
+ # ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["timestamp_list"][key[i]] = time_stamp_postprocessed
+ ibest_writer["timestamp_str"][key[i]] = timestamp_str
+ results.append(result_i)
+ return results, meta_data
\ No newline at end of file
diff --git a/funasr/models/monotonic_aligner/template.yaml b/funasr/models/monotonic_aligner/template.yaml
new file mode 100644
index 0000000..b1379de
--- /dev/null
+++ b/funasr/models/monotonic_aligner/template.yaml
@@ -0,0 +1,115 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: MonotonicAligner
+model_conf:
+ length_normalized_loss: False
+ predictor_bias: 1
+
+# encoder
+encoder: SANMEncoder
+encoder_conf:
+ output_size: 320
+ attention_heads: 4
+ linear_units: 1280
+ num_blocks: 30
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.1
+ input_layer: pe
+ pos_enc_class: SinusoidalPositionEncoder
+ normalize_before: true
+ kernel_size: 11
+ sanm_shfit: 0
+ selfattention_layer_type: sanm
+
+predictor: CifPredictorV3
+predictor_conf:
+ idim: 320
+ threshold: 1.0
+ l_order: 1
+ r_order: 1
+ tail_threshold: 0.45
+ smooth_factor2: 0.25
+ noise_threshold2: 0.01
+ upsample_times: 3
+ use_cif1_cnn: false
+ upsample_type: cnn_blstm
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 7
+ lfr_n: 6
+
+specaug: SpecAugLFR
+specaug_conf:
+ apply_time_warp: false
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ lfr_rate: 6
+ num_freq_mask: 1
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 12
+ num_time_mask: 1
+
+train_conf:
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 150
+ val_scheduler_criterion:
+ - valid
+ - acc
+ best_model_criterion:
+ - - valid
+ - acc
+ - max
+ keep_nbest_models: 10
+ log_interval: 50
+
+optim: adam
+optim_conf:
+ lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+ index_ds: IndexDSJsonl
+ batch_sampler: DynamicBatchLocalShuffleSampler
+ batch_type: example # example or length
+ batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+ max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
+ buffer_size: 500
+ shuffle: True
+ num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+ unk_symbol: <unk>
+ split_with_space: true
+
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin
+ reduce: true
+ ignore_nan_grad: true
+
+normalize: null
diff --git a/funasr/models/seaco_paraformer/template.yaml b/funasr/models/seaco_paraformer/template.yaml
index 52654ac..ab2301a 100644
--- a/funasr/models/seaco_paraformer/template.yaml
+++ b/funasr/models/seaco_paraformer/template.yaml
@@ -68,13 +68,18 @@
use_output_layer: false
wo_input_layer: true
-predictor: CifPredictorV2
+predictor: CifPredictorV3
predictor_conf:
idim: 512
threshold: 1.0
l_order: 1
r_order: 1
tail_threshold: 0.45
+ smooth_factor2: 0.25
+ noise_threshold2: 0.01
+ upsample_times: 3
+ use_cif1_cnn: false
+ upsample_type: cnn_blstm
# frontend related
frontend: WavFrontend
diff --git a/funasr/models/tp_aligner/e2e_tp.py b/funasr/models/tp_aligner/e2e_tp.py
deleted file mode 100644
index c675b0e..0000000
--- a/funasr/models/tp_aligner/e2e_tp.py
+++ /dev/null
@@ -1,170 +0,0 @@
-import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-
-import torch
-import numpy as np
-
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.frontends.abs_frontend import AbsFrontend
-from funasr.models.paraformer.cif_predictor import mae_loss
-from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.train_utils.device_funcs import force_gatherable
-from funasr.models.base_model import FunASRModel
-from funasr.models.paraformer.cif_predictor import CifPredictorV3
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
- from torch.cuda.amp import autocast
-else:
- # Nothing to do if torch<1.6.0
- @contextmanager
- def autocast(enabled=True):
- yield
-
-
-class TimestampPredictor(FunASRModel):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- """
-
- def __init__(
- self,
- frontend: Optional[AbsFrontend],
- encoder: AbsEncoder,
- predictor: CifPredictorV3,
- predictor_bias: int = 0,
- token_list=None,
- ):
-
- super().__init__()
- # note that eos is the same as sos (equivalent ID)
-
- self.frontend = frontend
- self.encoder = encoder
- self.encoder.interctc_use_conditioning = False
-
- self.predictor = predictor
- self.predictor_bias = predictor_bias
- self.criterion_pre = mae_loss()
- self.token_list = token_list
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Frontend + Encoder + Decoder + Calc loss
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- text: (Batch, Length)
- text_lengths: (Batch,)
- """
- assert text_lengths.dim() == 1, text_lengths.shape
- # Check that batch_size is unified
- assert (
- speech.shape[0]
- == speech_lengths.shape[0]
- == text.shape[0]
- == text_lengths.shape[0]
- ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
- batch_size = speech.shape[0]
- # for data-parallel
- text = text[:, : text_lengths.max()]
- speech = speech[:, :speech_lengths.max()]
-
- # 1. Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- if self.predictor_bias == 1:
- _, text = add_sos_eos(text, 1, 2, -1)
- text_lengths = text_lengths + self.predictor_bias
- _, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
-
- # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
- loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
-
- loss = loss_pre
- stats = dict()
-
- # Collect Attn branch stats
- stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
- stats["loss"] = torch.clone(loss.detach())
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
- def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Frontend + Encoder. Note that this method is used by asr_inference.py
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- """
- with autocast(False):
- # 1. Extract feats
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-
- # 4. Forward encoder
- # feats: (Batch, Length, Dim)
- # -> encoder_out: (Batch, Length2, Dim2)
- encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
-
- return encoder_out, encoder_out_lens
-
- def _extract_feats(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- assert speech_lengths.dim() == 1, speech_lengths.shape
-
- # for data-parallel
- speech = speech[:, : speech_lengths.max()]
- if self.frontend is not None:
- # Frontend
- # e.g. STFT and Feature extract
- # data_loader may send time-domain signal in this case
- # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
- feats, feats_lengths = self.frontend(speech, speech_lengths)
- else:
- # No frontend and no feature extract
- feats, feats_lengths = speech, speech_lengths
- return feats, feats_lengths
-
- def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
- encoder_out_mask,
- token_num)
- return ds_alphas, ds_cif_peak, us_alphas, us_peaks
-
- def collect_feats(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- ) -> Dict[str, torch.Tensor]:
- if self.extract_feats_in_collect_stats:
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
- else:
- # Generate dummy stats if extract_feats_in_collect_stats is False
- logging.warning(
- "Generating dummy stats for feats and feats_lengths, "
- "because encoder_conf.extract_feats_in_collect_stats is "
- f"{self.extract_feats_in_collect_stats}"
- )
- feats, feats_lengths = speech, speech_lengths
- return {"feats": feats, "feats_lengths": feats_lengths}
--
Gitblit v1.9.1