From 997374b88fe6b2ae5cb4dcaf47d78cb3eff09fc2 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 11 六月 2024 19:56:52 +0800
Subject: [PATCH] add ctc inference code (#1806)

---
 funasr/models/ctc/model.py                                   |  262 +++++++++++++++++++++++++++++
 examples/industrial_data_pretraining/ctc/infer_from_local.py |   31 +++
 funasr/models/sanm/encoder.py                                |  220 ++++++++++++++++++++++++
 examples/industrial_data_pretraining/ctc/demo.py             |   21 ++
 4 files changed, 534 insertions(+), 0 deletions(-)

diff --git a/examples/industrial_data_pretraining/ctc/demo.py b/examples/industrial_data_pretraining/ctc/demo.py
new file mode 100644
index 0000000..85a748a
--- /dev/null
+++ b/examples/industrial_data_pretraining/ctc/demo.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import sys
+from funasr import AutoModel
+
+model_dir=sys.argv[1]
+input_file=sys.argv[2]
+
+model = AutoModel(
+    model=model_dir,
+)
+
+res = model.generate(
+    input=input_file,
+    cache={},
+)
+
+print(res)
diff --git a/examples/industrial_data_pretraining/ctc/infer_from_local.py b/examples/industrial_data_pretraining/ctc/infer_from_local.py
new file mode 100644
index 0000000..1c863e4
--- /dev/null
+++ b/examples/industrial_data_pretraining/ctc/infer_from_local.py
@@ -0,0 +1,31 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method2, inference from local model
+
+# for more input type, please ref to readme.md
+model_dir=$1
+input_file=$2
+output_dir=$3
+
+# download model
+device="cuda:0" # "cuda:0" for gpu0, "cuda:1" for gpu1, "cpu"
+
+tokens="${model_dir}/tokens.json"
+cmvn_file="${model_dir}/am.mvn"
+
+config="config.yaml"
+init_param="${model_dir}/model.pt"
+
+mkdir -p ${output_dir}
+
+python -m funasr.bin.inference \
+--config-path "${model_dir}" \
+--config-name "${config}" \
+++init_param="${init_param}" \
+++tokenizer_conf.token_list="${tokens}" \
+++frontend_conf.cmvn_file="${cmvn_file}" \
+++input="${input_file}" \
+++output_dir="${output_dir}" \
+++device="${device}" \
+
diff --git a/funasr/models/ctc/model.py b/funasr/models/ctc/model.py
new file mode 100644
index 0000000..e493c3b
--- /dev/null
+++ b/funasr/models/ctc/model.py
@@ -0,0 +1,262 @@
+import logging
+from typing import Union, Dict, List, Tuple, Optional
+
+import time
+import torch
+import torch.nn as nn
+
+from funasr.models.ctc.ctc import CTC
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.register import tables
+from funasr.models.paraformer.search import Hypothesis
+
+
+@tables.register("model_classes", "CTC")
+class Transformer(nn.Module):
+    """CTC-attention hybrid Encoder-Decoder model"""
+
+    def __init__(
+        self,
+        specaug: str = None,
+        specaug_conf: dict = None,
+        normalize: str = None,
+        normalize_conf: dict = None,
+        encoder: str = None,
+        encoder_conf: dict = None,
+        ctc_conf: dict = None,
+        input_size: int = 80,
+        vocab_size: int = -1,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        sos: int = 1,
+        eos: int = 2,
+        length_normalized_loss: bool = False,
+        **kwargs,
+    ):
+
+        super().__init__()
+
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize)
+            normalize = normalize_class(**normalize_conf)
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(input_size=input_size, **encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        if ctc_conf is None:
+            ctc_conf = {}
+        ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)
+
+        self.blank_id = blank_id
+        self.sos = sos if sos is not None else vocab_size - 1
+        self.eos = eos if eos is not None else vocab_size - 1
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.specaug = specaug
+        self.normalize = normalize
+        self.encoder = encoder
+        self.error_calculator = None
+
+        self.ctc = ctc
+
+        self.length_normalized_loss = length_normalized_loss
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+        # import pdb;
+        # pdb.set_trace()
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+
+        batch_size = speech.shape[0]
+
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        loss_ctc, cer_ctc = None, None
+        stats = dict()
+
+        loss_ctc, cer_ctc = self._calc_ctc_loss(
+            encoder_out, encoder_out_lens, text, text_lengths
+        )
+
+        loss = loss_ctc
+
+        # Collect total loss stats
+        stats["loss"] = torch.clone(loss.detach())
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + 1).sum())
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                ind: int
+        """
+
+        # Data augmentation
+        if self.specaug is not None and self.training:
+            speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+        # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+        if self.normalize is not None:
+            speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+        # Forward encoder
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
+
+        return encoder_out, encoder_out_lens
+
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # Calc CTC loss
+        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+        # Calc CER using CTC
+        cer_ctc = None
+        if not self.training and self.error_calculator is not None:
+            ys_hat = self.ctc.argmax(encoder_out).data
+            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+        return loss_ctc, cer_ctc
+
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+
+        if kwargs.get("batch_size", 1) > 1:
+            raise NotImplementedError("batch decoding is not implemented")
+
+        meta_data = {}
+        if (
+            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
+        ):  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(
+                data_in,
+                fs=frontend.fs,
+                audio_fs=kwargs.get("fs", 16000),
+                data_type=kwargs.get("data_type", "sound"),
+                tokenizer=tokenizer,
+            )
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(
+                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+            )
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = (
+                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+            )
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+        # Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        # c. Passed the encoder result and the beam search
+        ctc_logits = self.ctc.log_softmax(encoder_out)
+
+        results = []
+        b, n, d = encoder_out.size()
+        if isinstance(key[0], (list, tuple)):
+            key = key[0]
+        if len(key) < b:
+            key = key * b
+        for i in range(b):
+            x = ctc_logits[i, :encoder_out_lens[i], :]
+            yseq = x.argmax(dim=-1)
+            yseq = torch.unique_consecutive(yseq, dim=-1)
+            yseq = torch.tensor([self.sos] + yseq.tolist() + [self.eos], device=yseq.device)
+            nbest_hyps = [Hypothesis(yseq=yseq)]
+
+            for nbest_idx, hyp in enumerate(nbest_hyps):
+                ibest_writer = None
+                if kwargs.get("output_dir") is not None:
+                    if not hasattr(self, "writer"):
+                        self.writer = DatadirWriter(kwargs.get("output_dir"))
+                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+
+                # remove sos/eos and get results
+                last_pos = -1
+                if isinstance(hyp.yseq, list):
+                    token_int = hyp.yseq[1:last_pos]
+                else:
+                    token_int = hyp.yseq[1:last_pos].tolist()
+
+                # remove blank symbol id, which is assumed to be 0
+                token_int = list(
+                    filter(
+                        lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
+                    )
+                )
+
+                # Change integer-ids to tokens
+                token = tokenizer.ids2tokens(token_int)
+                text = tokenizer.tokens2text(token)
+
+                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
+                results.append(result_i)
+
+                if ibest_writer is not None:
+                    ibest_writer["token"][key[i]] = " ".join(token)
+                    ibest_writer["text"][key[i]] = text_postprocessed
+
+        return results, meta_data
+
diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
index dc30a94..b2a442b 100644
--- a/funasr/models/sanm/encoder.py
+++ b/funasr/models/sanm/encoder.py
@@ -484,6 +484,226 @@
         return xs_pad, ilens, None
 
 
+@tables.register("encoder_classes", "SANMTPEncoder")
+class SANMTPEncoder(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
+    https://arxiv.org/abs/2006.01713
+    """
+    def __init__(
+            self,
+            input_size: int,
+            output_size: int = 256,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            tp_blocks: int = 0,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            attention_dropout_rate: float = 0.0,
+            stochastic_depth_rate: float = 0.0,
+            input_layer: Optional[str] = "conv2d",
+            pos_enc_class=SinusoidalPositionEncoder,
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            positionwise_layer_type: str = "linear",
+            positionwise_conv_kernel_size: int = 1,
+            padding_idx: int = -1,
+            kernel_size: int = 11,
+            sanm_shfit: int = 0,
+            selfattention_layer_type: str = "sanm",
+    ):
+        super().__init__()
+        self._output_size = output_size
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                eval(pos_enc_class)(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "linear_no_pos":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                eval(pos_enc_class)(output_size, positional_dropout_rate, use_pos=False),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                eval(pos_enc_class)(output_size, positional_dropout_rate),
+            )
+        elif input_layer is None:
+            if input_size == output_size:
+                self.embed = None
+            else:
+                self.embed = torch.nn.Linear(input_size, output_size)
+        elif input_layer == "pe":
+            self.embed = SinusoidalPositionEncoder()
+        elif input_layer == "pe_online":
+            self.embed = StreamSinusoidalPositionEncoder()
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+        if selfattention_layer_type == "selfattn":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+        elif selfattention_layer_type == "sanm":
+            encoder_selfattn_layer = MultiHeadedAttentionSANM
+            encoder_selfattn_layer_args0 = (
+                attention_heads,
+                input_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+        self.encoders0 = repeat(
+            1,
+            lambda lnum: EncoderLayerSANM(
+                input_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        self.encoders = repeat(
+            num_blocks - 1,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+                stochastic_depth_rate,
+            ),
+        )
+        self.tp_encoders = repeat(
+            tp_blocks,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+                stochastic_depth_rate,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+        self.tp_blocks = tp_blocks
+        if self.tp_blocks > 0:
+            self.tp_norm = LayerNorm(output_size)
+    def output_size(self) -> int:
+        return self._output_size
+    def forward(
+            self,
+            xs_pad: torch.Tensor,
+            ilens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Embed positions in tensor.
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: Not to be used now.
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        xs_pad *= self.output_size() ** 0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        elif (
+                isinstance(self.embed, Conv2dSubsampling)
+                or isinstance(self.embed, Conv2dSubsampling2)
+                or isinstance(self.embed, Conv2dSubsampling6)
+                or isinstance(self.embed, Conv2dSubsampling8)
+        ):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+        # forward encoder1
+        mask_shfit_chunk, mask_att_chunk_encoder = None, None
+        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+        # forward encoder2
+        olens = masks.squeeze(1).sum(1)
+        mask_shfit_chunk2, mask_att_chunk_encoder2 = None, None
+        for layer_idx, encoder_layer in enumerate(self.tp_encoders):
+            encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk2, mask_att_chunk_encoder2)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        if self.tp_blocks > 0:
+            xs_pad = self.tp_norm(xs_pad)
+        return xs_pad, olens
+
+
 class EncoderLayerSANMExport(nn.Module):
     def __init__(
         self,

--
Gitblit v1.9.1