From 276c443ba4b6574df682acd330a18a3830c9ef2c Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期五, 05 一月 2024 15:47:06 +0800
Subject: [PATCH] update monotonic aligner

---
 /dev/null                                                           |  170 ------------------
 funasr/models/monotonic_aligner/template.yaml                       |  115 ++++++++++++
 examples/industrial_data_pretraining/monotonic_aligner/README_zh.md |    0 
 funasr/models/seaco_paraformer/template.yaml                        |    7 
 funasr/models/monotonic_aligner/__init__.py                         |    0 
 examples/industrial_data_pretraining/monotonic_aligner/demo.py      |    0 
 funasr/models/monotonic_aligner/model.py                            |  202 ++++++++++++++++++++++
 examples/industrial_data_pretraining/monotonic_aligner/infer.sh     |   15 +
 8 files changed, 338 insertions(+), 171 deletions(-)

diff --git a/examples/industrial_data_pretraining/tp_aligner/README_zh.md b/examples/industrial_data_pretraining/monotonic_aligner/README_zh.md
similarity index 100%
rename from examples/industrial_data_pretraining/tp_aligner/README_zh.md
rename to examples/industrial_data_pretraining/monotonic_aligner/README_zh.md
diff --git a/examples/industrial_data_pretraining/tp_aligner/demo.py b/examples/industrial_data_pretraining/monotonic_aligner/demo.py
similarity index 100%
rename from examples/industrial_data_pretraining/tp_aligner/demo.py
rename to examples/industrial_data_pretraining/monotonic_aligner/demo.py
diff --git a/examples/industrial_data_pretraining/monotonic_aligner/infer.sh b/examples/industrial_data_pretraining/monotonic_aligner/infer.sh
new file mode 100644
index 0000000..dc9d823
--- /dev/null
+++ b/examples/industrial_data_pretraining/monotonic_aligner/infer.sh
@@ -0,0 +1,15 @@
+
+# download model
+local_path_root=../modelscope_models
+mkdir -p ${local_path_root}
+local_path=${local_path_root}/speech_timestamp_prediction-v1-16k-offline
+# git clone https://www.modelscope.cn/damo/speech_timestamp_prediction-v1-16k-offline.git ${local_path}
+
+
+python funasr/bin/inference.py \
++model="${local_path}" \
++input='["/Users/shixian/code/modelscope_models/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav", "娆㈣繋澶у鏉ュ埌榄旀惌绀惧尯杩涜浣撻獙"]' \
++data_type='["sound", "text"]' \
++output_dir="../outputs/debug" \
++device="cpu" \
++batch_size=2 
diff --git a/examples/industrial_data_pretraining/tp_aligner/infer.sh b/examples/industrial_data_pretraining/tp_aligner/infer.sh
deleted file mode 100644
index ded296f..0000000
--- a/examples/industrial_data_pretraining/tp_aligner/infer.sh
+++ /dev/null
@@ -1,17 +0,0 @@
-
-# download model
-local_path_root=../modelscope_models
-mkdir -p ${local_path_root}
-local_path=${local_path_root}/speech_timestamp_prediction-v1-16k-offline
-git clone https://www.modelscope.cn/damo/speech_timestamp_prediction-v1-16k-offline.git ${local_path}
-
-
-python funasr/bin/inference.py \
-+model="${local_path}" \
-+input='["/Users/zhifu/funasr_github/test_local/wav.scp", "/Users/zhifu/funasr_github/test_local/text.txt"]' \
-+data_type='["sound", "text"]' \
-+output_dir="./outputs/debug" \
-+device="cpu" \
-+batch_size=2 \
-+debug="true"
-
diff --git a/funasr/models/tp_aligner/__init__.py b/funasr/models/monotonic_aligner/__init__.py
similarity index 100%
rename from funasr/models/tp_aligner/__init__.py
rename to funasr/models/monotonic_aligner/__init__.py
diff --git a/funasr/models/monotonic_aligner/model.py b/funasr/models/monotonic_aligner/model.py
new file mode 100644
index 0000000..c0ac136
--- /dev/null
+++ b/funasr/models/monotonic_aligner/model.py
@@ -0,0 +1,202 @@
+import time
+import copy
+import torch
+from torch.cuda.amp import autocast
+from typing import Union, Dict, List, Tuple, Optional
+
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank, load_audio_and_text_image_video
+
+
+
+@tables.register("model_classes", "monotonicaligner")
+class MonotonicAligner(torch.nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Achieving timestamp prediction while recognizing with non-autoregressive end-to-end ASR model
+    https://arxiv.org/abs/2301.12343
+    """
+
+    def __init__(
+        self,
+        input_size: int = 80,
+        specaug: Optional[str] = None,
+        specaug_conf: Optional[Dict] = None,
+        normalize: str = None,
+        normalize_conf: Optional[Dict] = None,
+        encoder: str = None,
+        encoder_conf: Optional[Dict] = None,
+        predictor: str = None,
+        predictor_conf: Optional[Dict] = None,
+        predictor_bias: int = 0,
+        length_normalized_loss: bool = False,
+        **kwargs,
+    ):
+
+        super().__init__()
+
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug.lower())
+            specaug = specaug_class(**specaug_conf)
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize.lower())
+            normalize = normalize_class(**normalize_conf)
+        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder = encoder_class(input_size=input_size, **encoder_conf)
+        encoder_output_size = encoder.output_size()
+        predictor_class = tables.predictor_classes.get(predictor.lower())
+        predictor = predictor_class(**predictor_conf)
+        self.specaug = specaug
+        self.normalize = normalize
+        self.encoder = encoder
+        self.predictor = predictor
+        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
+        self.predictor_bias = predictor_bias
+
+    def forward(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Frontend + Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        # Check that batch_size is unified
+        assert (
+                speech.shape[0]
+                == speech_lengths.shape[0]
+                == text.shape[0]
+                == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+        batch_size = speech.shape[0]
+        # for data-parallel
+        text = text[:, : text_lengths.max()]
+        speech = speech[:, :speech_lengths.max()]
+
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        if self.predictor_bias == 1:
+            _, text = add_sos_eos(text, 1, 2, -1)
+            text_lengths = text_lengths + self.predictor_bias
+        _, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
+
+        # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+        loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
+
+        loss = loss_pre
+        stats = dict()
+
+        # Collect Attn branch stats
+        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+        stats["loss"] = torch.clone(loss.detach())
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
+                                                                                            encoder_out_mask,
+                                                                                            token_num)
+        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
+
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                ind: int
+        """
+        with autocast(False):
+
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+            
+            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                speech, speech_lengths = self.normalize(speech, speech_lengths)
+        
+        # Forward encoder
+        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        return encoder_out, encoder_out_lens
+    
+    def generate(self,
+             data_in,
+             data_lengths=None,
+             key: list=None,
+             tokenizer=None,
+             frontend=None,
+             **kwargs,
+             ):
+        
+        meta_data = {}
+        # extract fbank feats
+        time1 = time.perf_counter()
+        audio_list, text_token_int_list = load_audio_and_text_image_video(data_in, 
+                                                                            fs=frontend.fs, 
+                                                                            audio_fs=kwargs.get("fs", 16000), 
+                                                                            data_type=kwargs.get("data_type", "sound"), 
+                                                                            tokenizer=tokenizer)
+        time2 = time.perf_counter()
+        meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        speech, speech_lengths = extract_fbank(audio_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend)
+        time3 = time.perf_counter()
+        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+        meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+            
+        speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+
+        # Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        
+        # predictor
+        text_lengths = torch.tensor([len(i)+1 for i in text_token_int_list]).to(encoder_out.device)
+        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens, token_num=text_lengths)
+        
+        results = []
+        ibest_writer = None
+        if ibest_writer is None and kwargs.get("output_dir") is not None:
+            writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = writer["tp_res"]
+        for i, (us_alpha, us_peak, token_int) in enumerate(zip(us_alphas, us_peaks, text_token_int_list)):
+            token = tokenizer.ids2tokens(token_int)
+            timestamp_str, timestamp = ts_prediction_lfr6_standard(us_alpha[:encoder_out_lens[i] * 3],
+                                                                   us_peak[:encoder_out_lens[i] * 3],
+                                                                   copy.copy(token))
+            text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
+                token, timestamp)
+            result_i = {"key": key[i], "text": text_postprocessed,
+                                "timestamp": time_stamp_postprocessed,
+                                }    
+            # ibest_writer["token"][key[i]] = " ".join(token)
+            ibest_writer["timestamp_list"][key[i]] = time_stamp_postprocessed
+            ibest_writer["timestamp_str"][key[i]] = timestamp_str
+            results.append(result_i)
+        return results, meta_data
\ No newline at end of file
diff --git a/funasr/models/monotonic_aligner/template.yaml b/funasr/models/monotonic_aligner/template.yaml
new file mode 100644
index 0000000..b1379de
--- /dev/null
+++ b/funasr/models/monotonic_aligner/template.yaml
@@ -0,0 +1,115 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: MonotonicAligner
+model_conf:
+    length_normalized_loss: False
+    predictor_bias: 1
+
+# encoder
+encoder: SANMEncoder
+encoder_conf:
+    output_size: 320
+    attention_heads: 4
+    linear_units: 1280
+    num_blocks: 30
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    input_layer: pe
+    pos_enc_class: SinusoidalPositionEncoder
+    normalize_before: true
+    kernel_size: 11
+    sanm_shfit: 0
+    selfattention_layer_type: sanm
+
+predictor: CifPredictorV3
+predictor_conf:
+    idim: 320
+    threshold: 1.0
+    l_order: 1
+    r_order: 1
+    tail_threshold: 0.45
+    smooth_factor2: 0.25
+    noise_threshold2: 0.01
+    upsample_times: 3
+    use_cif1_cnn: false
+    upsample_type: cnn_blstm
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 7
+    lfr_n: 6
+
+specaug: SpecAugLFR
+specaug_conf:
+    apply_time_warp: false
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    lfr_rate: 6
+    num_freq_mask: 1
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 12
+    num_time_mask: 1
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 150
+  val_scheduler_criterion:
+      - valid
+      - acc
+  best_model_criterion:
+  -   - valid
+      - acc
+      - max
+  keep_nbest_models: 10
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_type: example # example or length
+    batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
+    buffer_size: 500
+    shuffle: True
+    num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+  split_with_space: true
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin
+    reduce: true
+    ignore_nan_grad: true
+    
+normalize: null
diff --git a/funasr/models/seaco_paraformer/template.yaml b/funasr/models/seaco_paraformer/template.yaml
index 52654ac..ab2301a 100644
--- a/funasr/models/seaco_paraformer/template.yaml
+++ b/funasr/models/seaco_paraformer/template.yaml
@@ -68,13 +68,18 @@
     use_output_layer: false
     wo_input_layer: true
 
-predictor: CifPredictorV2
+predictor: CifPredictorV3
 predictor_conf:
     idim: 512
     threshold: 1.0
     l_order: 1
     r_order: 1
     tail_threshold: 0.45
+    smooth_factor2: 0.25
+    noise_threshold2: 0.01
+    upsample_times: 3
+    use_cif1_cnn: false
+    upsample_type: cnn_blstm
 
 # frontend related
 frontend: WavFrontend
diff --git a/funasr/models/tp_aligner/e2e_tp.py b/funasr/models/tp_aligner/e2e_tp.py
deleted file mode 100644
index c675b0e..0000000
--- a/funasr/models/tp_aligner/e2e_tp.py
+++ /dev/null
@@ -1,170 +0,0 @@
-import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-
-import torch
-import numpy as np
-
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.frontends.abs_frontend import AbsFrontend
-from funasr.models.paraformer.cif_predictor import mae_loss
-from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.train_utils.device_funcs import force_gatherable
-from funasr.models.base_model import FunASRModel
-from funasr.models.paraformer.cif_predictor import CifPredictorV3
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
-    from torch.cuda.amp import autocast
-else:
-    # Nothing to do if torch<1.6.0
-    @contextmanager
-    def autocast(enabled=True):
-        yield
-
-
-class TimestampPredictor(FunASRModel):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    """
-
-    def __init__(
-            self,
-            frontend: Optional[AbsFrontend],
-            encoder: AbsEncoder,
-            predictor: CifPredictorV3,
-            predictor_bias: int = 0,
-            token_list=None,
-    ):
-
-        super().__init__()
-        # note that eos is the same as sos (equivalent ID)
-
-        self.frontend = frontend
-        self.encoder = encoder
-        self.encoder.interctc_use_conditioning = False
-
-        self.predictor = predictor
-        self.predictor_bias = predictor_bias
-        self.criterion_pre = mae_loss()
-        self.token_list = token_list
-
-    def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-        """Frontend + Encoder + Decoder + Calc loss
-        Args:
-                speech: (Batch, Length, ...)
-                speech_lengths: (Batch, )
-                text: (Batch, Length)
-                text_lengths: (Batch,)
-        """
-        assert text_lengths.dim() == 1, text_lengths.shape
-        # Check that batch_size is unified
-        assert (
-                speech.shape[0]
-                == speech_lengths.shape[0]
-                == text.shape[0]
-                == text_lengths.shape[0]
-        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
-        batch_size = speech.shape[0]
-        # for data-parallel
-        text = text[:, : text_lengths.max()]
-        speech = speech[:, :speech_lengths.max()]
-
-        # 1. Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
-        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-            encoder_out.device)
-        if self.predictor_bias == 1:
-            _, text = add_sos_eos(text, 1, 2, -1)
-            text_lengths = text_lengths + self.predictor_bias
-        _, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
-
-        # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-        loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
-
-        loss = loss_pre
-        stats = dict()
-
-        # Collect Attn branch stats
-        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
-        stats["loss"] = torch.clone(loss.detach())
-
-        # force_gatherable: to-device and to-tensor if scalar for DataParallel
-        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-        return loss, stats, weight
-
-    def encode(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Frontend + Encoder. Note that this method is used by asr_inference.py
-        Args:
-                speech: (Batch, Length, ...)
-                speech_lengths: (Batch, )
-        """
-        with autocast(False):
-            # 1. Extract feats
-            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-
-        # 4. Forward encoder
-        # feats: (Batch, Length, Dim)
-        # -> encoder_out: (Batch, Length2, Dim2)
-        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
-
-        return encoder_out, encoder_out_lens
-
-    def _extract_feats(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        assert speech_lengths.dim() == 1, speech_lengths.shape
-
-        # for data-parallel
-        speech = speech[:, : speech_lengths.max()]
-        if self.frontend is not None:
-            # Frontend
-            #  e.g. STFT and Feature extract
-            #       data_loader may send time-domain signal in this case
-            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
-            feats, feats_lengths = self.frontend(speech, speech_lengths)
-        else:
-            # No frontend and no feature extract
-            feats, feats_lengths = speech, speech_lengths
-        return feats, feats_lengths
-
-    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
-        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-            encoder_out.device)
-        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
-                                                                                            encoder_out_mask,
-                                                                                            token_num)
-        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
-
-    def collect_feats(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
-    ) -> Dict[str, torch.Tensor]:
-        if self.extract_feats_in_collect_stats:
-            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-        else:
-            # Generate dummy stats if extract_feats_in_collect_stats is False
-            logging.warning(
-                "Generating dummy stats for feats and feats_lengths, "
-                "because encoder_conf.extract_feats_in_collect_stats is "
-                f"{self.extract_feats_in_collect_stats}"
-            )
-            feats, feats_lengths = speech, speech_lengths
-        return {"feats": feats, "feats_lengths": feats_lengths}

--
Gitblit v1.9.1