shixian.shi
2024-01-05 276c443ba4b6574df682acd330a18a3830c9ef2c
update monotonic aligner
1个文件已修改
3个文件已添加
3 文件已重命名
2个文件已删除
526 ■■■■■ 已修改文件
examples/industrial_data_pretraining/monotonic_aligner/README_zh.md 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/monotonic_aligner/demo.py 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/monotonic_aligner/infer.sh 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/tp_aligner/infer.sh 17 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/monotonic_aligner/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/monotonic_aligner/model.py 202 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/monotonic_aligner/template.yaml 115 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/seaco_paraformer/template.yaml 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/tp_aligner/e2e_tp.py 170 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/monotonic_aligner/README_zh.md
examples/industrial_data_pretraining/monotonic_aligner/demo.py
examples/industrial_data_pretraining/monotonic_aligner/infer.sh
New file
@@ -0,0 +1,15 @@
# Timestamp-prediction (forced alignment) demo: aligns one wav file with its transcript.

# download model
local_path_root=../modelscope_models
mkdir -p "${local_path_root}"
local_path=${local_path_root}/speech_timestamp_prediction-v1-16k-offline
# git clone https://www.modelscope.cn/damo/speech_timestamp_prediction-v1-16k-offline.git ${local_path}

# Audio file to align. Set WAV_PATH in the environment to use your own file;
# the default is the path the script originally hard-coded.
wav_path="${WAV_PATH:-/Users/shixian/code/modelscope_models/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav}"

# +input is a JSON-style list: [audio path, transcript]; +data_type names each entry in order.
python funasr/bin/inference.py \
+model="${local_path}" \
+input='["'"${wav_path}"'", "欢迎大家来到魔搭社区进行体验"]' \
+data_type='["sound", "text"]' \
+output_dir="../outputs/debug" \
+device="cpu" \
+batch_size=2
examples/industrial_data_pretraining/tp_aligner/infer.sh
File was deleted
funasr/models/monotonic_aligner/__init__.py
funasr/models/monotonic_aligner/model.py
New file
@@ -0,0 +1,202 @@
import time
import copy
import torch
from torch.cuda.amp import autocast
from typing import Union, Dict, List, Tuple, Optional
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank, load_audio_and_text_image_video
@tables.register("model_classes", "monotonicaligner")
class MonotonicAligner(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Achieving timestamp prediction while recognizing with non-autoregressive end-to-end ASR model
    https://arxiv.org/abs/2301.12343

    The model is an encoder followed by a predictor. Training (``forward``)
    optimises only the predictor's token-count loss; inference (``generate``)
    turns the predictor's upsampled outputs into per-token timestamps for a
    given (audio, text) pair.
    """

    def __init__(
        self,
        input_size: int = 80,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        predictor: str = None,
        predictor_conf: Optional[Dict] = None,
        predictor_bias: int = 0,
        length_normalized_loss: bool = False,
        **kwargs,
    ):
        """Build the sub-modules from their registered class names.

        Args:
            input_size: Feature dimension passed to the encoder as ``input_size``.
            specaug: Registered name of the SpecAug class, or None to disable.
            specaug_conf: Keyword arguments for the SpecAug class.
            normalize: Registered name of the normalization class, or None to disable.
            normalize_conf: Keyword arguments for the normalization class.
            encoder: Registered name of the encoder class (required).
            encoder_conf: Keyword arguments for the encoder class.
            predictor: Registered name of the predictor class (required).
            predictor_conf: Keyword arguments for the predictor class.
            predictor_bias: When 1, ``forward`` appends an eos token to the text
                and adds the bias to the target lengths.
            length_normalized_loss: Forwarded to ``mae_loss`` as ``normalize_length``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        # A missing *_conf means "no extra keyword arguments"; unpacking None
        # directly would raise a TypeError.
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug.lower())
            specaug = specaug_class(**(specaug_conf or {}))
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize.lower())
            normalize = normalize_class(**(normalize_conf or {}))

        encoder_class = tables.encoder_classes.get(encoder.lower())
        encoder = encoder_class(input_size=input_size, **(encoder_conf or {}))

        predictor_class = tables.predictor_classes.get(predictor.lower())
        predictor = predictor_class(**(predictor_conf or {}))

        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.predictor = predictor
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        self.predictor_bias = predictor_bias

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Predictor + Calc loss

        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)

        Returns:
                (loss, stats, weight) as produced by ``force_gatherable``; the
                loss is the predictor length loss and the weight is built from
                the batch size.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
                speech.shape[0]
                == speech_lengths.shape[0]
                == text.shape[0]
                == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel: drop padding beyond the longest sample in this shard
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        # True on valid frames, with an extra singleton axis in the middle.
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)

        # 2. Predictor
        if self.predictor_bias == 1:
            # Use the eos-appended targets, so the target length grows by the bias.
            _, text = add_sos_eos(text, 1, 2, -1)
            text_lengths = text_lengths + self.predictor_bias
        _, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)

        # 3. Loss between target lengths and the predictor's fifth output
        loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
        loss = loss_pre

        stats = dict()
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        """Run the predictor's upsampled-timestamp branch.

        Args:
                encoder_out: Encoder output, as returned by ``encode``.
                encoder_out_lens: Valid length of each encoder output sequence.
                token_num: Number of tokens per utterance, passed through to
                    ``predictor.get_upsample_timestamp``.

        Returns:
                The four values returned by ``predictor.get_upsample_timestamp``:
                (ds_alphas, ds_cif_peak, us_alphas, us_peaks).
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
                                                                                            encoder_out_mask,
                                                                                            token_num)
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply augmentation/normalization (if configured) and run the encoder.

        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                **kwargs: Accepted and ignored.

        Returns:
                (encoder_out, encoder_out_lens)
        """
        # Feature-level preprocessing runs with autocast disabled.
        with autocast(False):
            # Data augmentation (training only)
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        # Some encoders return a tuple as their first output; keep its first element.
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        return encoder_out, encoder_out_lens

    def generate(self,
             data_in,
             data_lengths=None,
             key: list=None,
             tokenizer=None,
             frontend=None,
             **kwargs,
             ):
        """Predict per-token timestamps for (audio, text) inputs.

        Args:
                data_in: Raw inputs handed to ``load_audio_and_text_image_video``.
                data_lengths: Unused.
                key: Utterance ids, indexed in parallel with the inputs.
                tokenizer: Tokenizer used to load the text and to map ids back
                    to tokens (``ids2tokens``).
                frontend: Feature frontend; its ``fs``, ``frame_shift`` and
                    ``lfr_n`` attributes are read here.
                **kwargs: Reads ``device`` (required), and optionally ``fs``,
                    ``data_type`` and ``output_dir``.

        Returns:
                (results, meta_data): ``results`` is a list of dicts with keys
                "key", "text" and "timestamp"; ``meta_data`` holds timing info.
        """
        meta_data = {}
        data_type = kwargs.get("data_type", "sound")

        # load audio and tokenized text
        time1 = time.perf_counter()
        audio_list, text_token_int_list = load_audio_and_text_image_video(data_in,
                                                                            fs=frontend.fs,
                                                                            audio_fs=kwargs.get("fs", 16000),
                                                                            data_type=data_type,
                                                                            tokenizer=tokenizer)
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"

        # extract fbank feats
        speech, speech_lengths = extract_fbank(audio_list, data_type=data_type, frontend=frontend)
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        # presumably the batch audio duration in seconds (frames * ms-per-frame / 1000) — TODO confirm units
        meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

        # Tensor.to() returns a new tensor, so the result must be re-bound;
        # otherwise the features silently stay on their original device.
        device = kwargs["device"]
        speech = speech.to(device=device)
        speech_lengths = speech_lengths.to(device=device)

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # predictor
        # NOTE(review): the +1 here is fixed, whereas forward() adds self.predictor_bias — confirm intended.
        text_lengths = torch.tensor([len(i)+1 for i in text_token_int_list]).to(encoder_out.device)
        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens, token_num=text_lengths)

        # Results are written to disk only when an output directory is given.
        ibest_writer = None
        output_dir = kwargs.get("output_dir")
        if output_dir is not None:
            ibest_writer = DatadirWriter(output_dir)["tp_res"]

        results = []
        for i, (us_alpha, us_peak, token_int) in enumerate(zip(us_alphas, us_peaks, text_token_int_list)):
            token = tokenizer.ids2tokens(token_int)
            # presumably 3 is the predictor's upsample factor — TODO confirm against predictor_conf
            num_us_frames = encoder_out_lens[i] * 3
            # A copy of the token list is passed; the original is reused below.
            timestamp_str, timestamp = ts_prediction_lfr6_standard(us_alpha[:num_us_frames],
                                                                   us_peak[:num_us_frames],
                                                                   copy.copy(token))
            text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
                token, timestamp)
            result_i = {"key": key[i], "text": text_postprocessed,
                                "timestamp": time_stamp_postprocessed,
                                }
            # Without this guard a missing output_dir crashed on None["timestamp_list"].
            if ibest_writer is not None:
                ibest_writer["timestamp_list"][key[i]] = time_stamp_postprocessed
                ibest_writer["timestamp_str"][key[i]] = timestamp_str
            results.append(result_i)

        return results, meta_data
funasr/models/monotonic_aligner/template.yaml
New file
@@ -0,0 +1,115 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.register import tables
# tables.print()
# network architecture
model: MonotonicAligner
model_conf:
    length_normalized_loss: False
    predictor_bias: 1
# encoder
encoder: SANMEncoder
encoder_conf:
    output_size: 320
    attention_heads: 4
    linear_units: 1280
    num_blocks: 30
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: pe
    pos_enc_class: SinusoidalPositionEncoder
    normalize_before: true
    kernel_size: 11
    sanm_shfit: 0
    selfattention_layer_type: sanm
predictor: CifPredictorV3
predictor_conf:
    idim: 320
    threshold: 1.0
    l_order: 1
    r_order: 1
    tail_threshold: 0.45
    smooth_factor2: 0.25
    noise_threshold2: 0.01
    upsample_times: 3
    use_cif1_cnn: false
    upsample_type: cnn_blstm
# frontend related
frontend: WavFrontend
frontend_conf:
    fs: 16000
    window: hamming
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    lfr_m: 7
    lfr_n: 6
specaug: SpecAugLFR
specaug_conf:
    apply_time_warp: false
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    lfr_rate: 6
    num_freq_mask: 1
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 12
    num_time_mask: 1
train_conf:
  accum_grad: 1
  grad_clip: 5
  max_epoch: 150
  val_scheduler_criterion:
      - valid
      - acc
  best_model_criterion:
  -   - valid
      - acc
      - max
  keep_nbest_models: 10
  log_interval: 50
optim: adam
optim_conf:
   lr: 0.0005
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 30000
dataset: AudioDataset
dataset_conf:
    index_ds: IndexDSJsonl
    batch_sampler: DynamicBatchLocalShuffleSampler
    batch_type: example # example or length
    batch_size: 1 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len
    max_token_length: 2048 # filter out samples whose source_token_len+target_token_len > max_token_length
    buffer_size: 500
    shuffle: True
    num_workers: 0
tokenizer: CharTokenizer
tokenizer_conf:
  unk_symbol: <unk>
  split_with_space: true
ctc_conf:
    dropout_rate: 0.0
    ctc_type: builtin
    reduce: true
    ignore_nan_grad: true
normalize: null
funasr/models/seaco_paraformer/template.yaml
@@ -68,13 +68,18 @@
    use_output_layer: false
    wo_input_layer: true
predictor: CifPredictorV2
predictor: CifPredictorV3
predictor_conf:
    idim: 512
    threshold: 1.0
    l_order: 1
    r_order: 1
    tail_threshold: 0.45
    smooth_factor2: 0.25
    noise_threshold2: 0.01
    upsample_times: 3
    use_cif1_cnn: false
    upsample_type: cnn_blstm
# frontend related
frontend: WavFrontend
funasr/models/tp_aligner/e2e_tp.py
File was deleted