jmwang66
2023-06-20 2ff405b2f4ab899eff9bece232969fbb0c8f0555
Merge pull request #653 from alibaba-damo-academy/dev_wjm_infer

Dev wjm infer
27个文件已修改
3个文件已添加
2916 ■■■■ 已修改文件
.github/workflows/UnitTest.yml 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_infer.py 598 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_launch.py 863 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/diar_infer.py 49 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/diar_inference_launch.py 67 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/lm_inference_launch.py 127 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punc_infer.py 60 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punc_inference_launch.py 105 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/sv_infer.py 28 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/sv_inference_launch.py 106 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/tp_infer.py 65 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/tp_inference_launch.py 116 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/vad_infer.py 40 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/vad_inference_launch.py 59 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_args.py 19 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_asr_model.py 30 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_diar_model.py 22 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_lm_model.py 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_model.py 5 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_model_from_file.py 193 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_streaming_iterator.py 67 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_sv_model.py 258 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_vad_model.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_asr_contextual_paraformer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_asr_mfcca.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_uni_asr.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_vad.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
tests/test_sv_inference_pipeline.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
tests/test_vad_inference_pipeline.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
.github/workflows/UnitTest.yml
@@ -8,6 +8,7 @@
    branches:
      - dev_wjm
      - dev_jy
      - dev_wjm_infer
jobs:
  build:
@@ -18,6 +19,12 @@
        python-version: ["3.7"]
    steps:
      - name: Remove unnecessary files
        run:
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /opt/ghc
          sudo rm -rf "/usr/local/share/boost"
          sudo rm -rf "$AGENT_TOOLSDIRECTORY"
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
funasr/bin/asr_infer.py
@@ -1,66 +1,46 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import sys
import time
import codecs
import copy
import logging
import os
import re
import codecs
import tempfile
import requests
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
import numpy as np
import requests
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from  funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.modules.beam_search.beam_search import BeamSearch
# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTask
from funasr.tasks.lm import LMTask
from funasr.build_utils.build_asr_model import frontend_choices
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.bin.vad_infer import Speech2VadSegment
from funasr.bin.punc_infer import Text2Punc
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.tasks.asr import frontend_choices
class Speech2Text:
    """Speech2Text class
@@ -73,36 +53,36 @@
        [(text, token, token_int, hypothesis object), ...]
    """
    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        frontend_conf: dict = None,
        **kwargs,
            self,
            asr_train_config: Union[Path, str] = None,
            asr_model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            lm_train_config: Union[Path, str] = None,
            lm_file: Union[Path, str] = None,
            token_type: str = None,
            bpemodel: str = None,
            device: str = "cpu",
            maxlenratio: float = 0.0,
            minlenratio: float = 0.0,
            batch_size: int = 1,
            dtype: str = "float32",
            beam_size: int = 20,
            ctc_weight: float = 0.5,
            lm_weight: float = 1.0,
            ngram_weight: float = 0.9,
            penalty: float = 0.0,
            nbest: int = 1,
            streaming: bool = False,
            frontend_conf: dict = None,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
@@ -110,16 +90,15 @@
            if asr_train_args.frontend == 'wav_frontend':
                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
            else:
                from funasr.tasks.asr import frontend_choices
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()
        decoder = asr_model.decoder
        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
@@ -127,24 +106,24 @@
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device
            )
            scorers["lm"] = lm.lm
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None
        from funasr.modules.beam_search.beam_search import BeamSearch
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
@@ -162,13 +141,13 @@
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )
        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
@@ -180,7 +159,7 @@
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")
        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
@@ -193,10 +172,10 @@
        self.dtype = dtype
        self.nbest = nbest
        self.frontend = frontend
    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ) -> List[
        Tuple[
            Optional[str],
@@ -214,11 +193,11 @@
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
@@ -229,48 +208,49 @@
            feats_len = speech_lengths
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}
        # a. To device
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        enc, _ = self.asr_model.encode(**batch)
        if isinstance(enc, tuple):
            enc = enc[0]
        assert len(enc) == 1, len(enc)
        # c. Pass the encoder result to the beam search
        nbest_hyps = self.beam_search(
            x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
        )
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, (Hypothesis)), type(hyp)
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))
        assert check_return_type(results)
        return results
class Speech2TextParaformer:
    """Speech2Text class
@@ -312,9 +292,8 @@
        # 1. Build ASR model
        scorers = {}
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -336,8 +315,8 @@
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device, task_name="lm"
            )
            scorers["lm"] = lm.lm
@@ -466,18 +445,21 @@
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model, NeatContextualParaformer):
        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
                                                                                   NeatContextualParaformer):
            if self.hotword_list:
                logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                     pre_token_length)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        else:
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                     pre_token_length, hw_list=self.hotword_list)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        if isinstance(self.asr_model, BiCifParaformer):
            _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
                                                                                   pre_token_length)  # test no bias cif2
                                                                                pre_token_length)  # test no bias cif2
        results = []
        b, n, d = decoder_out.size()
@@ -527,12 +509,11 @@
                    text = None
                timestamp = []
                if isinstance(self.asr_model, BiCifParaformer):
                    _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i]*3],
                                                            us_peaks[i][:enc_len[i]*3],
                                                            copy.copy(token),
                                                            vad_offset=begin_time)
                    _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i] * 3],
                                                               us_peaks[i][:enc_len[i] * 3],
                                                               copy.copy(token),
                                                               vad_offset=begin_time)
                results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
        # assert check_return_type(results)
        return results
@@ -591,6 +572,7 @@
            hotword_list = None
        return hotword_list
class Speech2TextParaformerOnline:
    """Speech2Text class
@@ -630,9 +612,8 @@
        # 1. Build ASR model
        scorers = {}
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -654,8 +635,8 @@
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device, task_name="lm"
            )
            scorers["lm"] = lm.lm
@@ -789,7 +770,7 @@
        enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
        predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache)
        pre_acoustic_embeds, pre_token_length= predictor_outs[0], predictor_outs[1]
        pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache)
@@ -839,11 +820,12 @@
                        postprocessed_result += item + " "
                    else:
                        postprocessed_result += item
                results.append(postprocessed_result)
        # assert check_return_type(results)
        return results
class Speech2TextUniASR:
    """Speech2Text class
@@ -886,9 +868,8 @@
        # 1. Build ASR model
        scorers = {}
        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device, mode="uniasr"
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -914,8 +895,8 @@
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, device, "lm"
            )
            scorers["lm"] = lm.lm
@@ -1077,7 +1058,7 @@
        assert check_return_type(results)
        return results
class Speech2TextMFCCA:
    """Speech2Text class
@@ -1090,45 +1071,44 @@
        [(text, token, token_int, hypothesis object), ...]
    """
    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        **kwargs,
            self,
            asr_train_config: Union[Path, str] = None,
            asr_model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            lm_train_config: Union[Path, str] = None,
            lm_file: Union[Path, str] = None,
            token_type: str = None,
            bpemodel: str = None,
            device: str = "cpu",
            maxlenratio: float = 0.0,
            minlenratio: float = 0.0,
            batch_size: int = 1,
            dtype: str = "float32",
            beam_size: int = 20,
            ctc_weight: float = 0.5,
            lm_weight: float = 1.0,
            ngram_weight: float = 0.9,
            penalty: float = 0.0,
            nbest: int = 1,
            streaming: bool = False,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()
        decoder = asr_model.decoder
        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
@@ -1136,11 +1116,11 @@
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device, task_name="lm"
            )
            lm.to(device)
            scorers["lm"] = lm.lm
@@ -1148,11 +1128,11 @@
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
@@ -1176,7 +1156,7 @@
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
@@ -1188,7 +1168,7 @@
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")
        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
@@ -1200,10 +1180,10 @@
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ) -> List[
        Tuple[
            Optional[str],
@@ -1231,45 +1211,45 @@
        # lengths: (1,)
        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
        batch = {"speech": speech, "speech_lengths": lengths}
        # a. To device
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        enc, _ = self.asr_model.encode(**batch)
        assert len(enc) == 1, len(enc)
        # c. Pass the encoder result to the beam search
        nbest_hyps = self.beam_search(
            x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
        )
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, (Hypothesis)), type(hyp)
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))
        assert check_return_type(results)
        return results
@@ -1298,45 +1278,44 @@
        right_context: Number of frames in right context AFTER subsampling.
        display_partial_hypotheses: Whether to display partial hypotheses.
    """
    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        beam_search_config: Dict[str, Any] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        beam_size: int = 5,
        dtype: str = "float32",
        lm_weight: float = 1.0,
        quantize_asr_model: bool = False,
        quantize_modules: List[str] = None,
        quantize_dtype: str = "qint8",
        nbest: int = 1,
        streaming: bool = False,
        simu_streaming: bool = False,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
        display_partial_hypotheses: bool = False,
            self,
            asr_train_config: Union[Path, str] = None,
            asr_model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            beam_search_config: Dict[str, Any] = None,
            lm_train_config: Union[Path, str] = None,
            lm_file: Union[Path, str] = None,
            token_type: str = None,
            bpemodel: str = None,
            device: str = "cpu",
            beam_size: int = 5,
            dtype: str = "float32",
            lm_weight: float = 1.0,
            quantize_asr_model: bool = False,
            quantize_modules: List[str] = None,
            quantize_dtype: str = "qint8",
            nbest: int = 1,
            streaming: bool = False,
            simu_streaming: bool = False,
            chunk_size: int = 16,
            left_context: int = 32,
            right_context: int = 0,
            display_partial_hypotheses: bool = False,
    ) -> None:
        """Construct a Speech2Text object."""
        super().__init__()
        assert check_argument_types()
        from funasr.tasks.asr import ASRTransducerTask
        asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
        if quantize_asr_model:
            if quantize_modules is not None:
                if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
@@ -1344,36 +1323,36 @@
                        "Only 'Linear' and 'LSTM' modules are currently supported"
                        " by PyTorch and in --quantize_modules"
                    )
                q_config = set([getattr(torch.nn, q) for q in quantize_modules])
            else:
                q_config = {torch.nn.Linear}
            if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
                raise ValueError(
                    "float16 dtype for dynamic quantization is not supported with torch"
                    " version < 1.5.0. Switching to qint8 dtype instead."
                )
            q_dtype = getattr(torch, quantize_dtype)
            asr_model = torch.quantization.quantize_dynamic(
                asr_model, q_config, dtype=q_dtype
            ).eval()
        else:
            asr_model.to(dtype=getattr(torch, dtype)).eval()
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device, task_name="lm"
            )
            lm_scorer = lm.lm
        else:
            lm_scorer = None
        # 4. Build BeamSearch object
        if beam_search_config is None:
            beam_search_config = {}
        beam_search = BeamSearchTransducer(
            asr_model.decoder,
            asr_model.joint_network,
@@ -1383,14 +1362,14 @@
            nbest=nbest,
            **beam_search_config,
        )
        token_list = asr_model.token_list
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
@@ -1402,60 +1381,60 @@
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")
        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.streaming = streaming
        self.simu_streaming = simu_streaming
        self.chunk_size = max(chunk_size, 0)
        self.left_context = left_context
        self.right_context = max(right_context, 0)
        if not streaming or chunk_size == 0:
            self.streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False
        if not simu_streaming or chunk_size == 0:
            self.simu_streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False
        self.frontend = frontend
        self.window_size = self.chunk_size + self.right_context
        if self.streaming:
            self._ctx = self.asr_model.encoder.get_encoder_input_size(
                self.window_size
            )
            self.last_chunk_length = (
                self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
                    self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
            )
            self.reset_inference_cache()
    def reset_inference_cache(self) -> None:
        """Reset Speech2Text parameters."""
        self.frontend_cache = None
        self.asr_model.encoder.reset_streaming_cache(
            self.left_context, device=self.device
        )
        self.beam_search.reset_inference_cache()
        self.num_processed_frames = torch.tensor([[0]], device=self.device)
    @torch.no_grad()
    def streaming_decode(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        is_final: bool = True,
            self,
            speech: Union[torch.Tensor, np.ndarray],
            is_final: bool = True,
    ) -> List[HypothesisTransducer]:
        """Speech2Text streaming call.
        Args:
@@ -1473,13 +1452,13 @@
                )
                speech = torch.cat([speech, pad],
                                   dim=0)  # feats, feats_length = self.apply_frontend(speech, is_final=is_final)
        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out = self.asr_model.encoder.chunk_forward(
@@ -1491,14 +1470,14 @@
            right_context=self.right_context,
        )
        nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
        self.num_processed_frames += self.chunk_size
        if is_final:
            self.reset_inference_cache()
        return nbest_hyps
    @torch.no_grad()
    def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
        """Speech2Text call.
@@ -1508,29 +1487,29 @@
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if self.frontend is not None:
            speech = torch.unsqueeze(speech, axis=0)
            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
        else:
            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context,
                                                            self.right_context)
        nbest_hyps = self.beam_search(enc_out[0])
        return nbest_hyps
    @torch.no_grad()
    def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
        """Speech2Text call.
@@ -1540,7 +1519,7 @@
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
@@ -1548,19 +1527,19 @@
            speech = torch.unsqueeze(speech, axis=0)
            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
        else:
            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
        nbest_hyps = self.beam_search(enc_out[0])
        return nbest_hyps
    def hypotheses_to_results(self, nbest_hyps: List[HypothesisTransducer]) -> List[Any]:
        """Build partial or final results from the hypotheses.
        Args:
@@ -1569,26 +1548,26 @@
            results: Results containing different representation for the hypothesis.
        """
        results = []
        for hyp in nbest_hyps:
            token_int = list(filter(lambda x: x != 0, hyp.yseq))
            token = self.converter.ids2tokens(token_int)
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))
            assert check_return_type(results)
        return results
    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        **kwargs: Optional[Any],
            model_tag: Optional[str] = None,
            **kwargs: Optional[Any],
    ) -> Speech2Text:
        """Build Speech2Text instance from the pretrained model.
        Args:
@@ -1599,7 +1578,7 @@
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader
            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
@@ -1608,7 +1587,7 @@
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))
        return Speech2TextTransducer(**kwargs)
@@ -1623,37 +1602,36 @@
        [(text, token, token_int, hypothesis object), ...]
    """
    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        frontend_conf: dict = None,
        **kwargs,
            self,
            asr_train_config: Union[Path, str] = None,
            asr_model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            lm_train_config: Union[Path, str] = None,
            lm_file: Union[Path, str] = None,
            token_type: str = None,
            bpemodel: str = None,
            device: str = "cpu",
            maxlenratio: float = 0.0,
            minlenratio: float = 0.0,
            batch_size: int = 1,
            dtype: str = "float32",
            beam_size: int = 20,
            ctc_weight: float = 0.5,
            lm_weight: float = 1.0,
            ngram_weight: float = 0.9,
            penalty: float = 0.0,
            nbest: int = 1,
            streaming: bool = False,
            frontend_conf: dict = None,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        from funasr.tasks.asr import ASRTaskSAASR
        scorers = {}
        asr_model, asr_train_args = ASRTaskSAASR.build_model_from_file(
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
@@ -1665,13 +1643,13 @@
            else:
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()
        decoder = asr_model.decoder
        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
@@ -1679,24 +1657,24 @@
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, None, device
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device, task_name="lm"
            )
            scorers["lm"] = lm.lm
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None
        from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
@@ -1714,13 +1692,13 @@
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )
        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
@@ -1732,7 +1710,7 @@
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")
        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
@@ -1745,11 +1723,11 @@
        self.dtype = dtype
        self.nbest = nbest
        self.frontend = frontend
    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
        profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
            profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
    ) -> List[
        Tuple[
            Optional[str],
@@ -1768,14 +1746,14 @@
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if isinstance(profile, np.ndarray):
            profile = torch.tensor(profile)
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
@@ -1786,10 +1764,10 @@
            feats_len = speech_lengths
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}
        # a. To device
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        asr_enc, _, spk_enc = self.asr_model.encode(**batch)
        if isinstance(asr_enc, tuple):
@@ -1798,30 +1776,30 @@
            spk_enc = spk_enc[0]
        assert len(asr_enc) == 1, len(asr_enc)
        assert len(spk_enc) == 1, len(spk_enc)
        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(
            asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
        )
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, (HypothesisSAASR)), type(hyp)
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1: last_pos]
            else:
                token_int = hyp.yseq[1: last_pos].tolist()
            spk_weigths = torch.stack(hyp.spk_weigths, dim=0)
            token_ori = self.converter.ids2tokens(token_int)
            text_ori = self.tokenizer.tokens2text(token_ori)
            text_ori_spklist = text_ori.split('$')
            cur_index = 0
            spk_choose = []
@@ -1833,32 +1811,32 @@
                spk_weights_local = spk_weights_local.mean(dim=0)
                spk_choose_local = spk_weights_local.argmax(-1)
                spk_choose.append(spk_choose_local.item() + 1)
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            text_spklist = text.split('$')
            assert len(spk_choose) == len(text_spklist)
            spk_list = []
            for i in range(len(text_spklist)):
                text_split = text_spklist[i]
                n = len(text_split)
                spk_list.append(str(spk_choose[i]) * n)
            text_id = '$'.join(spk_list)
            assert len(text) == len(text_id)
            results.append((text, text_id, token, token_int, hyp))
        assert check_return_type(results)
        return results
funasr/bin/asr_inference_launch.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -7,109 +7,77 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
#!/usr/bin/env python3
import argparse
import logging
import sys
import time
import copy
import os
import codecs
import tempfile
import requests
from pathlib import Path
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
import yaml
import numpy as np
import torch
import torchaudio
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import BeamSearch
# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr.bin.asr_infer import Speech2Text
from funasr.bin.asr_infer import Speech2TextMFCCA
from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
from funasr.bin.asr_infer import Speech2TextSAASR
from funasr.bin.asr_infer import Speech2TextTransducer
from funasr.bin.asr_infer import Speech2TextUniASR
from funasr.bin.punc_infer import Text2Punc
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.bin.vad_infer import Speech2VadSegment
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import asr_utils, postprocess_utils
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.bin.asr_infer import Speech2Text
from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
from funasr.bin.asr_infer import Speech2TextUniASR
from funasr.bin.asr_infer import Speech2TextMFCCA
from funasr.bin.vad_infer import Speech2VadSegment
from funasr.bin.punc_infer import Text2Punc
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.bin.asr_infer import Speech2TextTransducer
from funasr.bin.asr_infer import Speech2TextSAASR
def inference_asr(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    mc: bool = False,
    param_dict: dict = None,
    **kwargs,
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        streaming: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        mc: bool = False,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
@@ -120,23 +88,23 @@
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
@@ -160,7 +128,7 @@
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2Text(**speech2text_kwargs)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
@@ -173,20 +141,18 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
@@ -197,14 +163,14 @@
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
@@ -212,19 +178,19 @@
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp]] * nbest
            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
@@ -233,67 +199,67 @@
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
                logging.info("uttid: {}".format(key))
                logging.info("text predictions: {}\n".format(text))
        return asr_result_list
    return _forward
def inference_paraformer(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    output_dir: Optional[str] = None,
    timestamp_infer_config: Union[Path, str] = None,
    timestamp_model_file: Union[Path, str] = None,
    param_dict: dict = None,
    **kwargs,
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        output_dir: Optional[str] = None,
        timestamp_infer_config: Union[Path, str] = None,
        timestamp_model_file: Union[Path, str] = None,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    export_mode = False
    if param_dict is not None:
        hotword_list_or_file = param_dict.get('hotword')
        export_mode = param_dict.get("export_mode", False)
    else:
        hotword_list_or_file = None
    if kwargs.get("device", None) == "cpu":
        ngpu = 0
    if ngpu >= 1 and torch.cuda.is_available():
@@ -301,10 +267,10 @@
    else:
        device = "cpu"
        batch_size = 1
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
@@ -326,9 +292,9 @@
        nbest=nbest,
        hotword_list_or_file=hotword_list_or_file,
    )
    speech2text = Speech2TextParaformer(**speech2text_kwargs)
    if timestamp_model_file is not None:
        speechtext2timestamp = Speech2Timestamp(
            timestamp_cmvn_file=cmvn_file,
@@ -337,16 +303,16 @@
        )
    else:
        speechtext2timestamp = None
    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None,
        **kwargs,
            data_path_and_name_and_type,
            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
            output_dir_v2: Optional[str] = None,
            fs: dict = None,
            param_dict: dict = None,
            **kwargs,
    ):
        hotword_list_or_file = None
        if param_dict is not None:
            hotword_list_or_file = param_dict.get('hotword')
@@ -354,30 +320,28 @@
            hotword_list_or_file = kwargs['hotword']
        if hotword_list_or_file is not None or 'hotword' in kwargs:
            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        if param_dict is not None:
            use_timestamp = param_dict.get('use_timestamp', True)
        else:
            use_timestamp = True
        forward_time_total = 0.0
        length_total = 0.0
        finish_count = 0
@@ -390,17 +354,17 @@
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
            logging.info("decoding, utt_id: {}".format(keys))
            # N-best list of (text, token, token_int, hyp_object)
            time_beg = time.time()
            results = speech2text(**batch)
            if len(results) < 1:
@@ -416,10 +380,10 @@
                                                                                               100 * forward_time / (
                                                                                                       length * lfr_factor))
            logging.info(rtf_cur)
            for batch_id in range(_bs):
                result = [results[batch_id][:-2]]
                key = keys[batch_id]
                for n, result in zip(range(1, nbest + 1), result):
                    text, token, token_int, hyp = result[0], result[1], result[2], result[3]
@@ -438,13 +402,13 @@
                    # Create a directory: outdir/{n}best_recog
                    if writer is not None:
                        ibest_writer = writer[f"{n}best_recog"]
                        # Write the result to each file
                        ibest_writer["token"][key] = " ".join(token)
                        # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                        ibest_writer["score"][key] = str(hyp.score)
                        ibest_writer["rtf"][key] = rtf_cur
                    if text is not None:
                        if use_timestamp and timestamp is not None:
                            postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
@@ -465,7 +429,7 @@
                        # asr_utils.print_progress(finish_count / file_count)
                        if writer is not None:
                            ibest_writer["text"][key] = " ".join(word_lists)
                    logging.info("decoding, utt: {}, predictions: {}".format(key, text))
        rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total,
                                                                                                           forward_time_total,
@@ -475,74 +439,74 @@
        if writer is not None:
            ibest_writer["rtf"]["rtf_avf"] = rtf_avg
        return asr_result_list
    return _forward
def inference_paraformer_vad_punc(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    vad_infer_config: Optional[str] = None,
    vad_model_file: Optional[str] = None,
    vad_cmvn_file: Optional[str] = None,
    time_stamp_writer: bool = True,
    punc_infer_config: Optional[str] = None,
    punc_model_file: Optional[str] = None,
    outputs_dict: Optional[bool] = True,
    param_dict: dict = None,
    **kwargs,
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        vad_infer_config: Optional[str] = None,
        vad_model_file: Optional[str] = None,
        vad_cmvn_file: Optional[str] = None,
        time_stamp_writer: bool = True,
        punc_infer_config: Optional[str] = None,
        punc_model_file: Optional[str] = None,
        outputs_dict: Optional[bool] = True,
        param_dict: dict = None,
        **kwargs,
):
    # Validate the call arguments against the annotations in the signature.
    assert check_argument_types()
    # Number of CPU threads torch may use; taken from kwargs, defaulting to 1.
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    # Reject configurations this entry point does not implement.
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    # Build-time hotword source, read from param_dict['hotword'] when given.
    # NOTE(review): the nested _forward assigns its own local
    # hotword_list_or_file from its own arguments, so this value only matters
    # to code in this outer scope - confirm against the elided lines.
    if param_dict is not None:
        hotword_list_or_file = param_dict.get('hotword')
    else:
        hotword_list_or_file = None
    # Use CUDA only when a GPU was requested and one is actually available.
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2vadsegment
    speech2vadsegment_kwargs = dict(
        vad_infer_config=vad_infer_config,
@@ -553,7 +517,7 @@
    )
    # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
    speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
    # 3. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
@@ -579,12 +543,12 @@
    text2punc = None
    if punc_model_file is not None:
        text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
    if output_dir is not None:
        writer = DatadirWriter(output_dir)
        ibest_writer = writer[f"1best_recog"]
        ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
    # Per-request entry point returned by inference_paraformer_vad_punc.
    # NOTE(review): part of the parameter list is elided by the hunk header
    # below; `fs` used further down is presumably one of the hidden
    # parameters - confirm.
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
@@ -592,43 +556,41 @@
                 param_dict: dict = None,
                 **kwargs,
                 ):
        # Resolve the hotword source: kwargs['hotword'] overrides
        # param_dict['hotword'] because it is assigned last.
        hotword_list_or_file = None
        if param_dict is not None:
            hotword_list_or_file = param_dict.get('hotword')
        if 'hotword' in kwargs:
            hotword_list_or_file = kwargs['hotword']
        # Batching budget for the ASR pass; defaults to 6000 when not supplied.
        batch_size_token = kwargs.get("batch_size_token", 6000)
        print("batch_size_token: ", batch_size_token)
        # The hotword list is generated only while speech2text.hotword_list is
        # still None.
        # NOTE(review): if generate_hotwords_list returns a non-None value, a
        # different hotword passed on a later call would be ignored - confirm
        # this is intended.
        if speech2text.hotword_list is None:
            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
        # 3. Build data-iterator
        # In-memory input: tensors are converted to numpy and wrapped in the
        # [data, "speech", "waveform"] triple used in place of a file spec.
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        # NOTE(review): two iterator-construction forms appear back to back
        # (ASRTask.build_streaming_iterator and build_streaming_iterator);
        # presumably the removed and added sides of the diff - confirm which
        # arguments belong to which call.
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=1,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        if param_dict is not None:
            use_timestamp = param_dict.get('use_timestamp', True)
        else:
            use_timestamp = True
        finish_count = 0
        file_count = 1
        lfr_factor = 6
@@ -639,7 +601,7 @@
        if output_path is not None:
            writer = DatadirWriter(output_path)
            ibest_writer = writer[f"1best_recog"]
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
@@ -648,15 +610,16 @@
            # Run VAD on the whole input and time it.
            beg_vad = time.time()
            vad_results = speech2vadsegment(**batch)
            end_vad = time.time()
            print("time cost vad: ", end_vad-beg_vad)
            print("time cost vad: ", end_vad - beg_vad)
            # Only the first entry of the second VAD output is used.
            # NOTE(review): presumably the segment list of the single utterance
            # in the batch (the loader is built with batch_size=1) - confirm.
            _, vadsegments = vad_results[0], vad_results[1][0]
            speech, speech_lengths = batch["speech"], batch["speech_lengths"]
            n = len(vadsegments)
            # Pair every segment with its original position so results can be
            # put back in order after sorting.
            data_with_index = [(vadsegments[i], i) for i in range(n)]
            # Ascending by seg[1] - seg[0].
            # NOTE(review): presumably seg = [start, end] in ms, i.e. sorted by
            # duration - confirm.
            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
            results_sorted = []
            # Budget for the accumulated segment length of one ASR batch.
            batch_size_token_ms = batch_size_token*60
            # With a budget of 0 the `< batch_size_token_ms` test below cannot
            # succeed (assuming non-negative lengths), so on CPU every segment
            # is flushed on its own.
            if speech2text.device == "cpu":
                batch_size_token_ms = 0
@@ -666,7 +629,8 @@
            beg_idx = 0
            for j, _ in enumerate(range(0, n)):
                batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                # Keep accumulating while the next segment still fits in the
                # budget; the last segment (j == n - 1) always flushes.
                if j < n-1 and (batch_size_token_ms_cum + sorted_data[j+1][0][1] - sorted_data[j+1][0][0])<batch_size_token_ms:
                if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][
                    0]) < batch_size_token_ms:
                    continue
                # Flush: reset the accumulator; end_idx is one past segment j.
                batch_size_token_ms_cum = 0
                end_idx = j + 1
@@ -679,11 +643,11 @@
                results = speech2text(**batch)
                end_asr = time.time()
                print("time cost asr: ", end_asr - beg_asr)
                if len(results) < 1:
                    results = [["", [], [], [], [], [], []]]
                results_sorted.extend(results)
            restored_data = [0] * n
            for j in range(n):
                index = sorted_data[j][1]
@@ -699,12 +663,12 @@
                        t[1] += vadsegments[j][0]
                    result[4] += restored_data[j][4]
                # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
            key = keys[0]
            # result = result_segments[0]
            text, token, token_int = result[0], result[1], result[2]
            time_stamp = result[4] if len(result[4]) > 0 else None
            if use_timestamp and time_stamp is not None:
                postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
            else:
@@ -718,23 +682,23 @@
                                                                           postprocessed_result[2]
            else:
                text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
            # Default to the unpunctuated text and an empty punctuation-id list;
            # both are replaced only if a punctuation model is loaded and there
            # is at least one word.
            text_postprocessed_punc = text_postprocessed
            punc_id_list = []
            if len(word_lists) > 0 and text2punc is not None:
                beg_punc = time.time()
                # NOTE(review): the meaning of the positional argument 20 is not
                # visible here - confirm against Text2Punc.__call__.
                text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
                end_punc = time.time()
                print("time cost punc: ", end_punc-beg_punc)
                print("time cost punc: ", end_punc - beg_punc)
            # Result record for this utterance: 'value' carries the (possibly
            # punctuated) text; the optional fields are added only when non-empty.
            item = {'key': key, 'value': text_postprocessed_punc}
            if text_postprocessed != "":
                item['text_postprocessed'] = text_postprocessed
            if time_stamp_postprocessed != "":
                item['time_stamp'] = time_stamp_postprocessed
            # 'sentences' is always set, from the punctuation ids, timestamps
            # and unpunctuated text.
            item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
            asr_result_list.append(item)
            finish_count += 1
            # asr_utils.print_progress(finish_count / file_count)
@@ -747,11 +711,12 @@
                ibest_writer["text_with_punc"][key] = text_postprocessed_punc
                if time_stamp_postprocessed is not None:
                    ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
            logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
        return asr_result_list
    return _forward
def inference_paraformer_online(
        maxlenratio: float,
@@ -852,7 +817,7 @@
            data = yaml.load(f, Loader=yaml.Loader)
        return data
    def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
    def _prepare_cache(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
        if len(cache) > 0:
            return cache
        config = _read_yaml(asr_train_config)
@@ -868,14 +833,15 @@
        return cache
    def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
    def _cache_reset(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
        if len(cache) > 0:
            config = _read_yaml(asr_train_config)
            enc_output_size = config["encoder_conf"]["output_size"]
            feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
            cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
                        "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
                        "tail_chunk": False}
            cache["encoder"] = cache_en
            cache_de = {"decode_fsmn": None}
@@ -920,7 +886,7 @@
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
            sample_offset = 0
            speech_length = raw_inputs.shape[1]
            stride_size =  chunk_size[1] * 960
            stride_size = chunk_size[1] * 960
            cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
            final_result = ""
            for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
@@ -949,40 +915,40 @@
def inference_uniasr(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    ngram_file: Optional[str] = None,
    cmvn_file: Optional[str] = None,
    # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    token_num_relax: int = 1,
    decoding_ind: int = 0,
    decoding_mode: str = "model1",
    param_dict: dict = None,
    **kwargs,
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        ngram_file: Optional[str] = None,
        cmvn_file: Optional[str] = None,
        # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        streaming: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        token_num_relax: int = 1,
        decoding_ind: int = 0,
        decoding_mode: str = "model1",
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
@@ -993,17 +959,17 @@
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    if param_dict is not None and "decoding_model" in param_dict:
        if param_dict["decoding_model"] == "fast":
            decoding_ind = 0
@@ -1016,10 +982,10 @@
            decoding_mode = "model2"
        else:
            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
@@ -1046,7 +1012,7 @@
        decoding_mode=decoding_mode,
    )
    speech2text = Speech2TextUniASR(**speech2text_kwargs)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
@@ -1059,19 +1025,17 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
@@ -1082,14 +1046,14 @@
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
@@ -1097,7 +1061,7 @@
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp]] * nbest
            # Only supporting batch_size==1
            key = keys[0]
            logging.info(f"Utterance: {key}")
@@ -1105,12 +1069,12 @@
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
@@ -1120,40 +1084,40 @@
                    if writer is not None:
                        ibest_writer["text"][key] = " ".join(word_lists)
        return asr_result_list
    return _forward
def inference_mfcca(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    param_dict: dict = None,
    **kwargs,
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        streaming: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
@@ -1164,20 +1128,20 @@
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
@@ -1201,7 +1165,7 @@
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2TextMFCCA(**speech2text_kwargs)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
@@ -1214,20 +1178,18 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            fs=fs,
            mc=True,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
@@ -1238,14 +1200,14 @@
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
@@ -1253,19 +1215,19 @@
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["<space>"], [2], hyp]] * nbest
            # Only supporting batch_size==1
            # Only the first key is used, so every hypothesis is filed under it.
            key = keys[0]
            # zip() stops at the shorter input, so at most nbest results are
            # written; n is the 1-based rank.
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    # NOTE(review): the return value of sentence_postprocess is
                    # stored whole as 'value' here, while the UniASR path in
                    # this file unpacks it into (text_postprocessed, word_lists)
                    # - confirm 'value' is meant to hold the full return value.
                    text_postprocessed = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
@@ -1275,42 +1237,43 @@
                    if writer is not None:
                        ibest_writer["text"][key] = text
        return asr_result_list
    return _forward
def inference_transducer(
    output_dir: str,
    batch_size: int,
    dtype: str,
    beam_size: int,
    ngpu: int,
    seed: int,
    lm_weight: float,
    nbest: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str],
    beam_search_config: Optional[dict],
    lm_train_config: Optional[str],
    lm_file: Optional[str],
    model_tag: Optional[str],
    token_type: Optional[str],
    bpemodel: Optional[str],
    key_file: Optional[str],
    allow_variable_data_keys: bool,
    quantize_asr_model: Optional[bool],
    quantize_modules: Optional[List[str]],
    quantize_dtype: Optional[str],
    streaming: Optional[bool],
    simu_streaming: Optional[bool],
    chunk_size: Optional[int],
    left_context: Optional[int],
    right_context: Optional[int],
    display_partial_hypotheses: bool,
    **kwargs,
        output_dir: str,
        batch_size: int,
        dtype: str,
        beam_size: int,
        ngpu: int,
        seed: int,
        lm_weight: float,
        nbest: int,
        num_workers: int,
        log_level: Union[int, str],
        data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str],
        beam_search_config: Optional[dict],
        lm_train_config: Optional[str],
        lm_file: Optional[str],
        model_tag: Optional[str],
        token_type: Optional[str],
        bpemodel: Optional[str],
        key_file: Optional[str],
        allow_variable_data_keys: bool,
        quantize_asr_model: Optional[bool],
        quantize_modules: Optional[List[str]],
        quantize_dtype: Optional[str],
        streaming: Optional[bool],
        simu_streaming: Optional[bool],
        chunk_size: Optional[int],
        left_context: Optional[int],
        right_context: Optional[int],
        display_partial_hypotheses: bool,
        **kwargs,
) -> None:
    """Transducer model inference.
    Args:
@@ -1391,7 +1354,7 @@
        model_tag=model_tag,
        **speech2text_kwargs,
    )
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
@@ -1400,106 +1363,99 @@
                 **kwargs,
                 ):
        # 3. Build data-iterator
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(
                speech2text.asr_train_args, False
            ),
            collate_fn=ASRTask.build_collate_fn(
                speech2text.asr_train_args, False
            ),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        # 4 .Start for-loop
        # The writer is used as a context manager for the whole decoding run.
        with DatadirWriter(output_dir) as writer:
            for keys, batch in loader:
                # Sanity checks: one string key per batch entry.
                assert isinstance(batch, dict), type(batch)
                assert all(isinstance(s, str) for s in keys), keys
                _bs = len(next(iter(batch.values())))
                assert len(keys) == _bs, f"{len(keys)} != {_bs}"
                # Drop the "*_lengths" entries and keep only element 0 of each
                # remaining entry; exactly one entry must remain.
                batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
                assert len(batch.keys()) == 1
                try:
                    if speech2text.streaming:
                        # Feed the input in windows of speech2text._ctx items
                        # with is_final=False, then the remainder with
                        # is_final=True. When len(speech) is an exact multiple
                        # of _ctx the final slice is empty.
                        # NOTE(review): the slicing lines below appear twice
                        # with different spacing; presumably the removed and
                        # added sides of the diff - confirm.
                        speech = batch["speech"]
                        _steps = len(speech) // speech2text._ctx
                        _end = 0
                        for i in range(_steps):
                            _end = (i + 1) * speech2text._ctx
                            speech2text.streaming_decode(
                                speech[i * speech2text._ctx : _end], is_final=False
                                speech[i * speech2text._ctx: _end], is_final=False
                            )
                        final_hyps = speech2text.streaming_decode(
                            speech[_end : len(speech)], is_final=True
                            speech[_end: len(speech)], is_final=True
                        )
                    elif speech2text.simu_streaming:
                        final_hyps = speech2text.simu_streaming_decode(**batch)
                    else:
                        final_hyps = speech2text(**batch)
                    results = speech2text.hypotheses_to_results(final_hyps)
                except TooShortUttError as e:
                    # Too-short input: log it and substitute a placeholder
                    # result. `* nbest` repeats the same inner list object.
                    logging.warning(f"Utterance {keys} {e}")
                    hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
                    results = [[" ", ["<space>"], [2], hyp]] * nbest
                # Only the first key is used for writing.
                key = keys[0]
                # zip() stops at the shorter input, so at most nbest results
                # are written; n is the 1-based rank.
                for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                    ibest_writer = writer[f"{n}best_recog"]
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                    if text is not None:
                        ibest_writer["text"][key] = text
    return _forward
def inference_sa_asr(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    mc: bool = False,
    param_dict: dict = None,
    **kwargs,
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        streaming: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        mc: bool = False,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    if batch_size > 1:
@@ -1508,23 +1464,23 @@
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
@@ -1548,7 +1504,7 @@
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2TextSAASR(**speech2text_kwargs)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
@@ -1561,20 +1517,18 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
        file_count = 1
        # 7. Start for-loop
@@ -1585,7 +1539,7 @@
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
@@ -1599,20 +1553,20 @@
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp]] * nbest
            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                    ibest_writer["text_id"][key] = text_id
                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
@@ -1621,12 +1575,12 @@
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
                logging.info("uttid: {}".format(key))
                logging.info("text predictions: {}".format(text))
                logging.info("text_id predictions: {}\n".format(text_id))
        return asr_result_list
    return _forward
@@ -1664,7 +1618,7 @@
        description="ASR Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
@@ -1674,7 +1628,7 @@
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
@@ -1707,7 +1661,7 @@
        default=1,
        help="The number of workers used for DataLoader",
    )
    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
@@ -1729,7 +1683,7 @@
        default=False,
        help="MultiChannel input",
    )
    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--vad_infer_config",
@@ -1792,7 +1746,7 @@
        default={},
        help="The keyword arguments for transducer beam search.",
    )
    group = parser.add_argument_group("Beam-search related")
    group.add_argument(
        "--batch_size",
@@ -1839,7 +1793,7 @@
        default=False,
        help="Whether to display partial hypotheses during chunk-by-chunk inference.",
    )
    group = parser.add_argument_group("Dynamic quantization related")
    group.add_argument(
        "--quantize_asr_model",
@@ -1864,7 +1818,7 @@
        choices=["float16", "qint8"],
        help="Dtype for dynamic quantization.",
    )
    group = parser.add_argument_group("Text converter related")
    group.add_argument(
        "--token_type",
@@ -1922,7 +1876,6 @@
    inference_pipeline = inference_launch(**kwargs)
    return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None))
if __name__ == "__main__":
funasr/bin/diar_infer.py
@@ -1,41 +1,28 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from collections import OrderedDict
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from collections import OrderedDict
import numpy as np
import soundfile
import torch
from scipy.ndimage import median_filter
from torch.nn import functional as F
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.diar import DiarTask
from funasr.tasks.diar import EENDOLADiarTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from scipy.ndimage import median_filter
from funasr.utils.misc import statistic_model_parameters
from funasr.datasets.iterable_dataset import load_bytes
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.tasks.diar import DiarTask
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.torch_utils.device_funcs import to_device
from funasr.utils.misc import statistic_model_parameters
class Speech2DiarizationEEND:
    """Speech2Diarlization class
@@ -61,10 +48,12 @@
        assert check_argument_types()
        # 1. Build Diarization model
        diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file(
        diar_model, diar_train_args = build_model_from_file(
            config_file=diar_train_config,
            model_file=diar_model_file,
            device=device
            device=device,
            task_name="diar",
            mode="eend-ola",
        )
        frontend = None
        if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
@@ -177,10 +166,12 @@
        assert check_argument_types()
        # TODO: 1. Build Diarization model
        diar_model, diar_train_args = DiarTask.build_model_from_file(
        diar_model, diar_train_args = build_model_from_file(
            config_file=diar_train_config,
            model_file=diar_model_file,
            device=device
            device=device,
            task_name="diar",
            mode="sond",
        )
        logging.info("diar_model: {}".format(diar_model))
        logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
@@ -248,7 +239,7 @@
        ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
        logits_idx = F.upsample(
            logits_idx.unsqueeze(1).float(),
            size=(ut, ),
            size=(ut,),
            mode="nearest",
        ).squeeze(1).long()
        logits_idx = logits_idx[0].tolist()
@@ -268,7 +259,7 @@
            if spk not in results:
                results[spk] = []
            if dur > self.dur_threshold:
                results[spk].append((st, st+dur))
                results[spk].append((st, st + dur))
        # sort segments in start time ascending
        for spk in results:
@@ -344,7 +335,3 @@
            kwargs.update(**d.download_and_unpack(model_tag))
        return Speech2DiarizationSOND(**kwargs)
funasr/bin/diar_inference_launch.py
@@ -1,5 +1,5 @@
# !/usr/bin/env python3
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -8,47 +8,28 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from collections import OrderedDict
import numpy as np
import soundfile
import torch
from torch.nn import functional as F
from typeguard import check_argument_types
from typeguard import check_return_type
from scipy.signal import medfilt
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.diar import DiarTask
from funasr.tasks.diar import EENDOLADiarTask
from funasr.torch_utils.device_funcs import to_device
from typeguard import check_argument_types
from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
from funasr.datasets.iterable_dataset import load_bytes
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from scipy.ndimage import median_filter
from funasr.utils.misc import statistic_model_parameters
from funasr.datasets.iterable_dataset import load_bytes
from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
def inference_sond(
        diar_train_config: str,
@@ -94,7 +75,8 @@
    set_all_random_seed(seed)
    # 2a. Build speech2xvec [Optional]
    if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
    if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict[
        "extract_profile"]:
        assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict."
        assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
        sv_train_config = param_dict["sv_train_config"]
@@ -139,7 +121,7 @@
        rst = []
        mid = uttid.rsplit("-", 1)[0]
        for key in results:
            results[key] = [(x[0]/100, x[1]/100) for x in results[key]]
            results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]]
        if out_format == "vad":
            for spk, segs in results.items():
                rst.append("{} {}".format(spk, segs))
@@ -176,7 +158,7 @@
                        example = [x.numpy() if isinstance(example[0], torch.Tensor) else x
                                   for x in example]
                        speech = example[0]
                        logging.info("Extracting profiles for {} waveforms".format(len(example)-1))
                        logging.info("Extracting profiles for {} waveforms".format(len(example) - 1))
                        profile = [speech2xvector.calculate_embedding(x) for x in example[1:]]
                        profile = torch.cat(profile, dim=0)
                        yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]}
@@ -186,16 +168,15 @@
                raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
        else:
            # 3. Build data-iterator
            loader = DiarTask.build_streaming_iterator(
                data_path_and_name_and_type,
            loader = build_streaming_iterator(
                task_name="diar",
                preprocess_args=None,
                data_path_and_name_and_type=data_path_and_name_and_type,
                dtype=dtype,
                batch_size=batch_size,
                key_file=key_file,
                num_workers=num_workers,
                preprocess_fn=None,
                collate_fn=None,
                allow_variable_data_keys=allow_variable_data_keys,
                inference=True,
                use_collate_fn=False,
            )
        # 7. Start for-loop
@@ -234,6 +215,7 @@
        return result_list
    return _forward
def inference_eend(
        diar_train_config: str,
@@ -306,16 +288,14 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
        loader = EENDOLADiarTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="diar",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False),
            collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        # 3. Start for-loop
@@ -362,8 +342,6 @@
    return _forward
def inference_launch(mode, **kwargs):
    if mode == "sond":
        return inference_sond(mode=mode, **kwargs)
@@ -386,6 +364,7 @@
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="Speaker Verification",
funasr/bin/lm_inference_launch.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -7,40 +7,25 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.types import float_or_none
import argparse
import logging
from pathlib import Path
import sys
import os
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import torch
from torch.nn.parallel import data_parallel
from typeguard import check_argument_types
from funasr.tasks.lm import LMTask
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.datasets.preprocessor import LMPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import float_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
@@ -48,42 +33,42 @@
def inference_lm(
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    log_base: Optional[float] = 10,
    allow_variable_data_keys: bool = False,
    split_with_space: Optional[bool] = False,
    seg_dict_file: Optional[str] = None,
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
        batch_size: int,
        dtype: str,
        ngpu: int,
        seed: int,
        num_workers: int,
        log_level: Union[int, str],
        key_file: Optional[str],
        train_config: Optional[str],
        model_file: Optional[str],
        log_base: Optional[float] = 10,
        allow_variable_data_keys: bool = False,
        split_with_space: Optional[bool] = False,
        seg_dict_file: Optional[str] = None,
        output_dir: Optional[str] = None,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build Model
    model, train_args = LMTask.build_model_from_file(
        train_config, model_file, device)
    model, train_args = build_model_from_file(
        train_config, model_file, None, device, "lm")
    wrapped_model = ForwardAdaptor(model, "nll")
    wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
    logging.info(f"Model:\n{model}")
    preprocessor = LMPreprocessor(
        train=False,
        token_type=train_args.token_type,
@@ -96,12 +81,12 @@
        split_with_space=split_with_space,
        seg_dict_file=seg_dict_file
    )
    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[List[Any], bytes, str] = None,
        output_dir_v2: Optional[str] = None,
        param_dict: dict = None,
            data_path_and_name_and_type,
            raw_inputs: Union[List[Any], bytes, str] = None,
            output_dir_v2: Optional[str] = None,
            param_dict: dict = None,
    ):
        results = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
@@ -109,7 +94,7 @@
            writer = DatadirWriter(output_path)
        else:
            writer = None
        if raw_inputs != None:
            line = raw_inputs.strip()
            key = "lm demo"
@@ -121,7 +106,7 @@
            batch['text'] = line
            if preprocessor != None:
                batch = preprocessor(key, batch)
            #  Force data-precision
            for name in batch:
                value = batch[name]
@@ -138,11 +123,11 @@
                else:
                    raise NotImplementedError(f"Not supported dtype: {value.dtype}")
                batch[name] = value
            batch["text_lengths"] = torch.from_numpy(
                np.array([len(batch["text"])], dtype='int32'))
            batch["text"] = np.expand_dims(batch["text"], axis=0)
            with torch.no_grad():
                batch = to_device(batch, device)
                if ngpu <= 1:
@@ -173,7 +158,7 @@
                            word_nll=round(word_nll.item(), 8)
                        )
                        pre_word = cur_word
                    sent_nll_mean = sent_nll.mean().cpu().numpy()
                    sent_nll_sum = sent_nll.sum().cpu().numpy()
                    if log_base is None:
@@ -189,22 +174,20 @@
                    if writer is not None:
                        writer["ppl"][key + ":\n"] = ppl_out
                    results.append(item)
            return results
        # 3. Build data-iterator
        loader = LMTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="lm",
            preprocess_args=train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=preprocessor,
            collate_fn=LMTask.build_collate_fn(train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        # 4. Start for-loop
        total_nll = 0.0
        total_ntokens = 0
@@ -214,7 +197,7 @@
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            ppl_out_batch = ""
            with torch.no_grad():
                batch = to_device(batch, device)
@@ -247,7 +230,7 @@
                            word_nll=round(word_nll.item(), 8)
                        )
                        pre_word = cur_word
                    sent_nll_mean = sent_nll.mean().cpu().numpy()
                    sent_nll_sum = sent_nll.sum().cpu().numpy()
                    if log_base is None:
@@ -265,9 +248,9 @@
                        writer["ppl"][key + ":\n"] = ppl_out
                        writer["utt2nll"][key] = str(utt2nll)
                    results.append(item)
            ppl_out_all += ppl_out_batch
            assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
            # nll: (B, L) -> (B,)
            nll = nll.detach().cpu().numpy().sum(1)
@@ -275,12 +258,12 @@
            lengths = lengths.detach().cpu().numpy()
            total_nll += nll.sum()
            total_ntokens += lengths.sum()
        if log_base is None:
            ppl = np.exp(total_nll / total_ntokens)
        else:
            ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
        avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
            total_nll=round(-total_nll.item(), 4),
            total_ppl=round(ppl.item(), 4)
@@ -290,9 +273,9 @@
        if writer is not None:
            writer["ppl"]["AVG PPL : "] = avg_ppl
        results.append(item)
        return results
    return _forward
@@ -302,7 +285,8 @@
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="Calc perplexity",
@@ -407,4 +391,3 @@
if __name__ == "__main__":
    main()
funasr/bin/punc_infer.py
@@ -1,46 +1,32 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Any
from typing import List
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.punctuation import PunctuationTask
from funasr.datasets.preprocessor import split_to_mini_sentence
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
    def __init__(
        self,
        train_config: Optional[str],
        model_file: Optional[str],
        device: str = "cpu",
        dtype: str = "float32",
            self,
            train_config: Optional[str],
            model_file: Optional[str],
            device: str = "cpu",
            dtype: str = "float32",
    ):
        #  Build Model
        model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
        model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
        self.device = device
        # Wrap model to make model.inference() data-parallel
        self.wrapped_model = ForwardAdaptor(model, "inference")
@@ -144,16 +130,16 @@
class Text2PuncVADRealtime:
    def __init__(
        self,
        train_config: Optional[str],
        model_file: Optional[str],
        device: str = "cpu",
        dtype: str = "float32",
            self,
            train_config: Optional[str],
            model_file: Optional[str],
            device: str = "cpu",
            dtype: str = "float32",
    ):
        #  Build Model
        model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
        model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
        self.device = device
        # Wrap model to make model.inference() data-parallel
        self.wrapped_model = ForwardAdaptor(model, "inference")
@@ -178,7 +164,7 @@
            text_name="text",
            non_linguistic_symbols=train_args.non_linguistic_symbols,
        )
    @torch.no_grad()
    def __call__(self, text: Union[list, str], cache: list, split_size=20):
        if cache is not None and len(cache) > 0:
@@ -215,7 +201,7 @@
            if indices.size()[0] != 1:
                punctuations = torch.squeeze(indices)
            assert punctuations.size()[0] == len(mini_sentence)
            # Search for the last Period/QuestionMark as cache
            if mini_sentence_i < len(mini_sentences) - 1:
                sentenceEnd = -1
@@ -226,7 +212,7 @@
                        break
                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
                        last_comma_index = i
                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                    # The sentence is too long, cut off at a comma.
                    sentenceEnd = last_comma_index
@@ -235,11 +221,11 @@
                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
                mini_sentence = mini_sentence[0:sentenceEnd + 1]
                punctuations = punctuations[0:sentenceEnd + 1]
            punctuations_np = punctuations.cpu().numpy()
            sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
            sentence_words_list += mini_sentence
        assert len(sentence_punc_list) == len(sentence_words_list)
        words_with_punc = []
        sentence_punc_list_out = []
@@ -256,7 +242,7 @@
                if sentence_punc_list[i] != "_":
                    words_with_punc.append(sentence_punc_list[i])
        sentence_out = "".join(words_with_punc)
        sentenceEnd = -1
        for i in range(len(sentence_punc_list) - 2, 1, -1):
            if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "?":
@@ -267,5 +253,3 @@
            sentence_out = sentence_out[:-1]
            sentence_punc_list_out[-1] = "_"
        return sentence_out, sentence_punc_list_out, cache_out
funasr/bin/punc_inference_launch.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -7,55 +7,36 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.types import float_or_none
import argparse
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Any
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.punctuation import PunctuationTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.datasets.preprocessor import split_to_mini_sentence
from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
def inference_punc(
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
        batch_size: int,
        dtype: str,
        ngpu: int,
        seed: int,
        num_workers: int,
        log_level: Union[int, str],
        key_file: Optional[str],
        train_config: Optional[str],
        model_file: Optional[str],
        output_dir: Optional[str] = None,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    logging.basicConfig(
@@ -73,11 +54,11 @@
    text2punc = Text2Punc(train_config, model_file, device)
    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[List[Any], bytes, str] = None,
        output_dir_v2: Optional[str] = None,
        cache: List[Any] = None,
        param_dict: dict = None,
            data_path_and_name_and_type,
            raw_inputs: Union[List[Any], bytes, str] = None,
            output_dir_v2: Optional[str] = None,
            cache: List[Any] = None,
            param_dict: dict = None,
    ):
        results = []
        split_size = 20
@@ -121,20 +102,21 @@
    return _forward
def inference_punc_vad_realtime(
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    #cache: list,
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
        batch_size: int,
        dtype: str,
        ngpu: int,
        seed: int,
        num_workers: int,
        log_level: Union[int, str],
        # cache: list,
        key_file: Optional[str],
        train_config: Optional[str],
        model_file: Optional[str],
        output_dir: Optional[str] = None,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
@@ -150,11 +132,11 @@
    text2punc = Text2PuncVADRealtime(train_config, model_file, device)
    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[List[Any], bytes, str] = None,
        output_dir_v2: Optional[str] = None,
        cache: List[Any] = None,
        param_dict: dict = None,
            data_path_and_name_and_type,
            raw_inputs: Union[List[Any], bytes, str] = None,
            output_dir_v2: Optional[str] = None,
            cache: List[Any] = None,
            param_dict: dict = None,
    ):
        results = []
        split_size = 10
@@ -177,7 +159,6 @@
    return _forward
def inference_launch(mode, **kwargs):
    if mode == "punc":
        return inference_punc(**kwargs)
@@ -186,6 +167,7 @@
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
    parser = config_argparse.ArgumentParser(
@@ -267,7 +249,6 @@
    kwargs.pop("njob", None)
    inference_pipeline = inference_launch(**kwargs)
    return inference_pipeline(kwargs["data_path_and_name_and_type"])
if __name__ == "__main__":
funasr/bin/sv_infer.py
@@ -1,35 +1,24 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import numpy as np
import torch
from kaldiio import WriteHelper
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.sv import SVTask
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.misc import statistic_model_parameters
class Speech2Xvector:
    """Speech2Xvector class
@@ -56,10 +45,13 @@
        assert check_argument_types()
        # TODO: 1. Build SV model
        sv_model, sv_train_args = SVTask.build_model_from_file(
        sv_model, sv_train_args = build_model_from_file(
            config_file=sv_train_config,
            model_file=sv_model_file,
            device=device
            cmvn_file=None,
            device=device,
            task_name="sv",
            mode="sv",
        )
        logging.info("sv_model: {}".format(sv_model))
        logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
@@ -157,7 +149,3 @@
            kwargs.update(**d.download_and_unpack(model_tag))
        return Speech2Xvector(**kwargs)
funasr/bin/sv_inference_launch.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -7,20 +7,6 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
@@ -30,61 +16,59 @@
import torch
from kaldiio import WriteHelper
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.sv import SVTask
from funasr.torch_utils.device_funcs import to_device
from funasr.bin.sv_infer import Speech2Xvector
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.misc import statistic_model_parameters
from funasr.bin.sv_infer import Speech2Xvector
def inference_sv(
    output_dir: Optional[str] = None,
    batch_size: int = 1,
    dtype: str = "float32",
    ngpu: int = 1,
    seed: int = 0,
    num_workers: int = 0,
    log_level: Union[int, str] = "INFO",
    key_file: Optional[str] = None,
    sv_train_config: Optional[str] = "sv.yaml",
    sv_model_file: Optional[str] = "sv.pb",
    model_tag: Optional[str] = None,
    allow_variable_data_keys: bool = True,
    streaming: bool = False,
    embedding_node: str = "resnet1_dense",
    sv_threshold: float = 0.9465,
    param_dict: Optional[dict] = None,
    **kwargs,
        output_dir: Optional[str] = None,
        batch_size: int = 1,
        dtype: str = "float32",
        ngpu: int = 1,
        seed: int = 0,
        num_workers: int = 0,
        log_level: Union[int, str] = "INFO",
        key_file: Optional[str] = None,
        sv_train_config: Optional[str] = "sv.yaml",
        sv_model_file: Optional[str] = "sv.pb",
        model_tag: Optional[str] = None,
        allow_variable_data_keys: bool = True,
        streaming: bool = False,
        embedding_node: str = "resnet1_dense",
        sv_threshold: float = 0.9465,
        param_dict: Optional[dict] = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("param_dict: {}".format(param_dict))
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2xvector
    speech2xvector_kwargs = dict(
        sv_train_config=sv_train_config,
@@ -100,32 +84,31 @@
        **speech2xvector_kwargs,
    )
    speech2xvector.sv_model.eval()
    def _forward(
        data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        param_dict: Optional[dict] = None,
            data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
            output_dir_v2: Optional[str] = None,
            param_dict: Optional[dict] = None,
    ):
        logging.info("param_dict: {}".format(param_dict))
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        # 3. Build data-iterator
        loader = SVTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="sv",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=None,
            collate_fn=None,
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
            use_collate_fn=False,
        )
        # 7. Start for-loop
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        embd_writer, ref_embd_writer, score_writer = None, None, None
@@ -139,7 +122,7 @@
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            embedding, ref_embedding, score = speech2xvector(**batch)
            # Only supporting batch_size==1
            key = keys[0]
@@ -161,18 +144,16 @@
                        score_writer = open(os.path.join(output_path, "score.txt"), "w")
                    ref_embd_writer(key, ref_embedding[0].cpu().numpy())
                    score_writer.write("{} {:.6f}\n".format(key, normalized_score))
        if output_path is not None:
            embd_writer.close()
            if ref_embd_writer is not None:
                ref_embd_writer.close()
                score_writer.close()
        return sv_result_list
    return _forward
def inference_launch(mode, **kwargs):
@@ -182,6 +163,7 @@
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="Speaker Verification",
funasr/bin/tp_infer.py
@@ -1,57 +1,35 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
from optparse import Option
import sys
import json
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.datasets.preprocessor import LMPreprocessor
from funasr.tasks.asr import ASRTaskAligner as ASRTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.text.token_id_converter import TokenIDConverter
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.torch_utils.device_funcs import to_device
class Speech2Timestamp:
    def __init__(
        self,
        timestamp_infer_config: Union[Path, str] = None,
        timestamp_model_file: Union[Path, str] = None,
        timestamp_cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
        dtype: str = "float32",
        **kwargs,
            self,
            timestamp_infer_config: Union[Path, str] = None,
            timestamp_model_file: Union[Path, str] = None,
            timestamp_cmvn_file: Union[Path, str] = None,
            device: str = "cpu",
            dtype: str = "float32",
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        tp_model, tp_train_args = ASRTask.build_model_from_file(
            timestamp_infer_config, timestamp_model_file, device=device
        tp_model, tp_train_args = build_model_from_file(
            timestamp_infer_config, timestamp_model_file, cmvn_file=None, device=device, task_name="asr", mode="tp"
        )
        if 'cuda' in device:
            tp_model = tp_model.cuda()  # force model to cuda
@@ -59,13 +37,12 @@
        frontend = None
        if tp_train_args.frontend is not None:
            frontend = WavFrontend(cmvn_file=timestamp_cmvn_file, **tp_train_args.frontend_conf)
        logging.info("tp_model: {}".format(tp_model))
        logging.info("tp_train_args: {}".format(tp_train_args))
        tp_model.to(dtype=getattr(torch, dtype)).eval()
        logging.info(f"Decoding device={device}, dtype={dtype}")
        self.tp_model = tp_model
        self.tp_train_args = tp_train_args
@@ -79,13 +56,13 @@
        self.encoder_downsampling_factor = 1
        if tp_train_args.encoder_conf["input_layer"] == "conv2d":
            self.encoder_downsampling_factor = 4
    @torch.no_grad()
    def __call__(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        speech_lengths: Union[torch.Tensor, np.ndarray] = None,
        text_lengths: Union[torch.Tensor, np.ndarray] = None
            self,
            speech: Union[torch.Tensor, np.ndarray],
            speech_lengths: Union[torch.Tensor, np.ndarray] = None,
            text_lengths: Union[torch.Tensor, np.ndarray] = None
    ):
        assert check_argument_types()
@@ -113,8 +90,6 @@
            enc = enc[0]
        # c. Forward Predictor
        _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1)
        _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len,
                                                                           text_lengths.to(self.device) + 1)
        return us_alphas, us_peaks
funasr/bin/tp_inference_launch.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -8,87 +8,66 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
import argparse
import logging
from optparse import Option
import sys
import json
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.datasets.preprocessor import LMPreprocessor
from funasr.tasks.asr import ASRTaskAligner as ASRTask
from funasr.torch_utils.device_funcs import to_device
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.text.token_id_converter import TokenIDConverter
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.bin.tp_infer import Speech2Timestamp
def inference_tp(
    batch_size: int,
    ngpu: int,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    timestamp_infer_config: Optional[str],
    timestamp_model_file: Optional[str],
    timestamp_cmvn_file: Optional[str] = None,
    # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
    key_file: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    num_workers: int = 1,
    split_with_space: bool = True,
    seg_dict_file: Optional[str] = None,
    **kwargs,
        batch_size: int,
        ngpu: int,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        timestamp_infer_config: Optional[str],
        timestamp_model_file: Optional[str],
        timestamp_cmvn_file: Optional[str] = None,
        # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        key_file: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        num_workers: int = 1,
        split_with_space: bool = True,
        seg_dict_file: Optional[str] = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speechtext2timestamp
    speechtext2timestamp_kwargs = dict(
        timestamp_infer_config=timestamp_infer_config,
@@ -99,7 +78,7 @@
    )
    logging.info("speechtext2timestamp_kwargs: {}".format(speechtext2timestamp_kwargs))
    speechtext2timestamp = Speech2Timestamp(**speechtext2timestamp_kwargs)
    preprocessor = LMPreprocessor(
        train=False,
        token_type=speechtext2timestamp.tp_train_args.token_type,
@@ -112,21 +91,21 @@
        split_with_space=split_with_space,
        seg_dict_file=seg_dict_file,
    )
    if output_dir is not None:
        writer = DatadirWriter(output_dir)
        tp_writer = writer[f"timestamp_prediction"]
        # ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
    else:
        tp_writer = None
    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None,
        **kwargs
            data_path_and_name_and_type,
            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
            output_dir_v2: Optional[str] = None,
            fs: dict = None,
            param_dict: dict = None,
            **kwargs
    ):
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        writer = None
@@ -140,32 +119,31 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speechtext2timestamp.tp_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=preprocessor,
            collate_fn=ASRTask.build_collate_fn(speechtext2timestamp.tp_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        tp_result_list = []
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            logging.info("timestamp predicting, utt_id: {}".format(keys))
            _batch = {'speech': batch['speech'],
                      'speech_lengths': batch['speech_lengths'],
                      'text_lengths': batch['text_lengths']}
            us_alphas, us_cif_peak = speechtext2timestamp(**_batch)
            for batch_id in range(_bs):
                key = keys[batch_id]
                token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id])
@@ -178,10 +156,8 @@
                    tp_writer["tp_time"][key + '#'] = str(ts_list)
                tp_result_list.append(item)
        return tp_result_list
    return _forward
def inference_launch(mode, **kwargs):
@@ -190,6 +166,7 @@
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
    parser = config_argparse.ArgumentParser(
@@ -306,7 +283,6 @@
    inference_pipeline = inference_launch(**kwargs)
    return inference_pipeline(kwargs["data_path_and_name_and_type"])
if __name__ == "__main__":
funasr/bin/train.py
@@ -1,4 +1,6 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
funasr/bin/vad_infer.py
@@ -1,42 +1,23 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
import json
import math
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import math
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.vad import VADTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.torch_utils.device_funcs import to_device
class Speech2VadSegment:
@@ -64,8 +45,8 @@
        assert check_argument_types()
        # 1. Build vad model
        vad_model, vad_infer_args = VADTask.build_model_from_file(
            vad_infer_config, vad_model_file, device
        vad_model, vad_infer_args = build_model_from_file(
            vad_infer_config, vad_model_file, None, device, task_name="vad"
        )
        frontend = None
        if vad_infer_args.frontend is not None:
@@ -128,12 +109,13 @@
                "in_cache": in_cache
            }
            # a. To device
            #batch = to_device(batch, device=self.device)
            # batch = to_device(batch, device=self.device)
            segments_part, in_cache = self.vad_model(**batch)
            if segments_part:
                for batch_num in range(0, self.batch_size):
                    segments[batch_num] += segments_part[batch_num]
        return fbanks, segments
class Speech2VadSegmentOnline(Speech2VadSegment):
    """Speech2VadSegmentOnline class
@@ -146,13 +128,13 @@
        [[10, 230], [245, 450], ...]
    """
    def __init__(self, **kwargs):
        super(Speech2VadSegmentOnline, self).__init__(**kwargs)
        vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
        self.frontend = None
        if self.vad_infer_args.frontend is not None:
            self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)
    @torch.no_grad()
    def __call__(
@@ -198,5 +180,3 @@
            # in_cache.update(batch['in_cache'])
            # in_cache = {key: value for key, value in batch['in_cache'].items()}
        return fbanks, segments, in_cache
funasr/bin/vad_inference_launch.py
@@ -1,58 +1,34 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
import argparse
import logging
import os
import sys
import json
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import math
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.vad import VADTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.bin.vad_infer import Speech2VadSegment, Speech2VadSegmentOnline
def inference_vad(
        batch_size: int,
@@ -74,7 +50,6 @@
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    logging.basicConfig(
        level=log_level,
@@ -112,16 +87,14 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = VADTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="vad",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
@@ -157,6 +130,7 @@
    return _forward
def inference_vad_online(
        batch_size: int,
        ngpu: int,
@@ -175,7 +149,6 @@
        **kwargs,
):
    assert check_argument_types()
    logging.basicConfig(
        level=log_level,
@@ -214,16 +187,14 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = VADTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="vad",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
@@ -273,8 +244,6 @@
    return _forward
def inference_launch(mode, **kwargs):
    if mode == "offline":
        return inference_vad(**kwargs)
@@ -283,6 +252,7 @@
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
    parser = config_argparse.ArgumentParser(
@@ -405,5 +375,6 @@
    inference_pipeline = inference_launch(**kwargs)
    return inference_pipeline(kwargs["data_path_and_name_and_type"])
if __name__ == "__main__":
    main()
funasr/build_utils/build_args.py
@@ -41,7 +41,7 @@
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of noise scp file.",
            help="The path of cmvn file.",
        )
    elif args.task_name == "pretrain":
@@ -75,12 +75,29 @@
            default=None,
            help="The number of input dimension of the feature",
        )
        task_parser.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The path of cmvn file.",
        )
    elif args.task_name == "diar":
        from funasr.build_utils.build_diar_model import class_choices_list
        for class_choices in class_choices_list:
            class_choices.add_arguments(task_parser)
    elif args.task_name == "sv":
        from funasr.build_utils.build_sv_model import class_choices_list
        for class_choices in class_choices_list:
            class_choices.add_arguments(task_parser)
        task_parser.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )
    else:
        raise NotImplementedError("Not supported task: {}".format(args.task_name))
funasr/build_utils/build_asr_model.py
@@ -20,15 +20,18 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
@@ -42,6 +45,7 @@
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
@@ -89,6 +93,7 @@
        paraformer_bert=ParaformerBert,
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        neatcontextual_paraformer=NeatContextualParaformer,
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
        rnnt=TransducerModel,
@@ -258,17 +263,22 @@
def build_asr_model(args):
    # token_list
    if args.token_list is not None:
        with open(args.token_list) as f:
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        args.token_list = list(token_list)
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
    else:
        token_list = None
        vocab_size = None
    # frontend
    if args.input_size is None:
    if hasattr(args, "input_size") and args.input_size is None:
        frontend_class = frontend_choices.get_class(args.frontend)
        if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
            frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
@@ -279,7 +289,7 @@
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size
        input_size = args.input_size if hasattr(args, "input_size") else None
    # data augmentation for spectrogram
    if args.specaug is not None:
@@ -291,7 +301,10 @@
    # normalization layer
    if args.normalize is not None:
        normalize_class = normalize_choices.get_class(args.normalize)
        normalize = normalize_class(**args.normalize_conf)
        if args.model == "mfcca":
            normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
        else:
            normalize = normalize_class(**args.normalize_conf)
    else:
        normalize = None
@@ -325,7 +338,8 @@
            token_list=token_list,
            **args.model_conf,
        )
    elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
    elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer",
                        "contextual_paraformer", "neatcontextual_paraformer"]:
        # predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)
funasr/build_utils/build_diar_model.py
@@ -178,14 +178,18 @@
def build_diar_model(args):
    # token_list
    if args.token_list is not None:
        with open(args.token_list) as f:
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        # Overwriting token_list to keep it as "portable".
        args.token_list = list(token_list)
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
    else:
        vocab_size = None
        raise RuntimeError("token_list must be str or list")
    vocab_size = len(token_list)
    logging.info(f"Vocabulary size: {vocab_size}")
    # frontend
    if args.input_size is None:
@@ -205,7 +209,7 @@
    encoder_class = encoder_choices.get_class(args.encoder)
    encoder = encoder_class(input_size=input_size, **args.encoder_conf)
    if args.model_name == "sond":
    if args.model == "sond":
        # data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
@@ -243,11 +247,7 @@
        # decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder.output_size(),
            **args.decoder_conf,
        )
        decoder = decoder_class(**args.decoder_conf)
        # logger aggregator
        if getattr(args, "label_aggregator", None) is not None:
funasr/build_utils/build_lm_model.py
@@ -34,10 +34,14 @@
def build_lm_model(args):
    # token_list
    if args.token_list is not None:
        with open(args.token_list) as f:
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        args.token_list = list(token_list)
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
    else:
@@ -47,6 +51,7 @@
    lm_class = lm_choices.get_class(args.lm)
    lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
    args.model = args.model if hasattr(args, "model") else "lm"
    model_class = model_choices.get_class(args.model)
    model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf)
funasr/build_utils/build_model.py
@@ -1,9 +1,10 @@
from funasr.build_utils.build_asr_model import build_asr_model
from funasr.build_utils.build_diar_model import build_diar_model
from funasr.build_utils.build_lm_model import build_lm_model
from funasr.build_utils.build_pretrain_model import build_pretrain_model
from funasr.build_utils.build_punc_model import build_punc_model
from funasr.build_utils.build_sv_model import build_sv_model
from funasr.build_utils.build_vad_model import build_vad_model
from funasr.build_utils.build_diar_model import build_diar_model
def build_model(args):
@@ -19,6 +20,8 @@
        model = build_vad_model(args)
    elif args.task_name == "diar":
        model = build_diar_model(args)
    elif args.task_name == "sv":
        model = build_sv_model(args)
    else:
        raise NotImplementedError("Not supported task: {}".format(args.task_name))
funasr/build_utils/build_model_from_file.py
New file
@@ -0,0 +1,193 @@
import argparse
import logging
import os
from pathlib import Path
from typing import Union
import torch
import yaml
from typeguard import check_argument_types
from funasr.build_utils.build_model import build_model
from funasr.models.base_model import FunASRModel
def build_model_from_file(
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
        task_name: str = "asr",
        mode: str = "paraformer",
):
    """Build a model from its training config and optionally load its weights.

    This method is used for inference or fine-tuning.

    Args:
        config_file: The yaml file saved when training. When omitted,
            ``config.yaml`` located next to ``model_file`` is used.
        model_file: The model file saved when training. A file whose name
            contains "model.ckpt-" or ".bin" is handled as a TensorFlow
            checkpoint: a sibling ".pb" file is loaded if it exists, otherwise
            the checkpoint is converted with ``convert_tf2torch`` and the
            result is saved to that ".pb" path. Any other file is read with
            ``torch.load``. When ``None``, the freshly built model is returned
            without loading any weights.
        cmvn_file: If given, overrides the "cmvn_file" entry of the config.
        device: Device type, "cpu", "cuda", or "cuda:N".
        task_name: Stored as ``args.task_name`` before calling ``build_model``.
            "vad" loads the weights into ``model.encoder`` only; "diar"
            together with ``mode == "sond"`` filters the checkpoint keys.
        mode: Forwarded to ``convert_tf2torch`` for TensorFlow checkpoints.

    Returns:
        A ``(model, args)`` tuple, where ``args`` is the config as an
        ``argparse.Namespace``.

    Raises:
        RuntimeError: If the built model does not inherit ``FunASRModel``.
    """
    assert check_argument_types()
    if config_file is None:
        assert model_file is not None, (
            "The argument 'model_file' must be provided "
            "if the argument 'config_file' is not specified."
        )
        config_file = Path(model_file).parent / "config.yaml"
    else:
        config_file = Path(config_file)

    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
    if cmvn_file is not None:
        args["cmvn_file"] = cmvn_file
    args = argparse.Namespace(**args)
    args.task_name = task_name
    model = build_model(args)
    if not isinstance(model, FunASRModel):
        raise RuntimeError(
            f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
        )
    model.to(device)

    if model_file is None:
        # No checkpoint was requested. Loading an empty state dict would fail
        # on the missing keys, so hand back the initialised model as is.
        return model, args

    logging.info("model_file is {}".format(model_file))
    if device == "cuda":
        device = f"cuda:{torch.cuda.current_device()}"
    model_dir = os.path.dirname(model_file)
    model_name = os.path.basename(model_file)

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    model_name_pth = None
    if "model.ckpt-" in model_name or ".bin" in model_name:
        # TensorFlow checkpoint: prefer the already converted ".pb" cache.
        if ".bin" in model_name:
            model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
        else:
            model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
        if os.path.exists(model_name_pth):
            logging.info("model_file is load from pth: {}".format(model_name_pth))
            model_dict = torch.load(model_name_pth, map_location=device)
        else:
            model_dict = convert_tf2torch(model, model_file, mode)
        # The weights are loaded once, below, after the task-specific
        # handling; loading them here as well would bypass the key filtering
        # for "sond" and the encoder-only loading for "vad".
    else:
        model_dict = torch.load(model_file, map_location=device)

    if task_name == "diar" and mode == "sond":
        model_dict = fileter_model_dict(model_dict, model.state_dict())
    if task_name == "vad":
        model.encoder.load_state_dict(model_dict)
    else:
        model.load_state_dict(model_dict)

    # Cache the converted TensorFlow weights so later calls skip the conversion.
    if model_name_pth is not None and not os.path.exists(model_name_pth):
        torch.save(model_dict, model_name_pth)
        logging.info("model_file is saved to pth: {}".format(model_name_pth))

    return model, args
def convert_tf2torch(
        model,
        ckpt,
        mode,
):
    """Convert a TensorFlow checkpoint into torch parameters for ``model``.

    The checkpoint is read with ``load_tf_dict`` and each relevant part of
    ``model`` converts its own variables through its ``convert_tf2torch``
    method; the per-part results are merged into one dict.

    Args:
        model: The torch model whose ``state_dict()`` is the conversion target.
        ckpt: The TensorFlow checkpoint passed to ``load_tf_dict``.
        mode: One of "paraformer", "uniasr", "sond", "sv" or "tp"; selects
            which parts of ``model`` are converted.

    Returns:
        A dict merged from the results returned by each part's converter.

    Raises:
        AssertionError: If ``mode`` is not one of the supported values.
    """
    assert mode in ("paraformer", "uniasr", "sond", "sv", "tp")
    logging.info("start convert tf model to torch model")
    from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
    var_dict_tf = load_tf_dict(ckpt)
    var_dict_torch = model.state_dict()
    var_dict_torch_update = dict()

    def _convert(*attr_names, skip_none=False):
        # Convert the named attributes of ``model`` in order and merge results.
        for attr_name in attr_names:
            part = getattr(model, attr_name)
            if skip_none and part is None:
                continue
            var_dict_torch_update.update(
                part.convert_tf2torch(var_dict_tf, var_dict_torch)
            )

    if mode == "uniasr":
        _convert(
            "encoder", "predictor", "decoder",
            "encoder2", "predictor2", "decoder2",
            "stride_conv",
        )
    elif mode == "sond":
        # Each part is converted only when it is not None.
        _convert(
            "encoder", "speaker_encoder", "cd_scorer", "ci_scorer", "decoder",
            skip_none=True,
        )
    elif mode == "sv":
        _convert("encoder", "pooling_layer", "decoder")
    else:
        # "paraformer" and "tp": encoder, predictor, decoder, then bias_encoder.
        _convert("encoder", "predictor", "decoder")
        var_dict_torch_update.update(
            model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
        )

    return var_dict_torch_update
def fileter_model_dict(src_dict: dict, dest_dict: dict):
    """Keep only the entries of ``src_dict`` whose keys exist in ``dest_dict``.

    Keys dropped from ``src_dict`` are reported at INFO level; keys of
    ``dest_dict`` that end up without a value are reported at WARNING level.

    Args:
        src_dict: Source mapping (e.g. a loaded checkpoint).
        dest_dict: Mapping whose keys decide what is kept.

    Returns:
        An ``OrderedDict`` with the kept items, in ``src_dict`` order.
    """
    from collections import OrderedDict

    kept = OrderedDict()
    for name in src_dict:
        if name not in dest_dict:
            logging.info("{} is no longer needed in this model.".format(name))
            continue
        kept[name] = src_dict[name]

    for name in dest_dict:
        if name not in kept:
            logging.warning("{} is missed in checkpoint.".format(name))

    return kept
funasr/build_utils/build_streaming_iterator.py
New file
@@ -0,0 +1,67 @@
import numpy as np
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from funasr.datasets.iterable_dataset import IterableESPnetDataset
from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
from funasr.datasets.small_datasets.preprocessor import build_preprocess
def build_streaming_iterator(
        task_name,
        preprocess_args,
        data_path_and_name_and_type,
        key_file: str = None,
        batch_size: int = 1,
        fs: dict = None,
        mc: bool = False,
        dtype: str = np.float32,
        num_workers: int = 1,
        use_collate_fn: bool = True,
        preprocess_fn=None,
        ngpu: int = 0,
        train: bool = False,
) -> DataLoader:
    """Build DataLoader using iterable dataset

    An explicit ``preprocess_fn`` takes precedence; otherwise one is built
    from ``preprocess_args`` (after setting its ``task_name``) when those are
    given. Unless ``use_collate_fn`` is False, a ``CommonCollateFn`` is used,
    with integer padding 0 for the "punc" and "lm" tasks and float padding 0.0
    / integer padding -1 for every other task. The batch size is forced to 1
    when the dataset reports ``apply_utt2category``.

    NOTE(review): ``dtype`` is annotated ``str`` while its default is
    ``np.float32`` -- confirm this passes ``check_argument_types`` when the
    default is used.
    """
    assert check_argument_types()

    # preprocess: only build one when the caller did not supply it
    if preprocess_fn is None and preprocess_args is not None:
        preprocess_args.task_name = task_name
        preprocess_fn = build_preprocess(preprocess_args, train)

    # collate
    loader_kwargs = {}
    if use_collate_fn:
        if task_name in ["punc", "lm"]:
            loader_kwargs["collate_fn"] = CommonCollateFn(int_pad_value=0)
        else:
            loader_kwargs["collate_fn"] = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    dataset = IterableESPnetDataset(
        data_path_and_name_and_type,
        float_dtype=dtype,
        fs=fs,
        mc=mc,
        preprocess=preprocess_fn,
        key_file=key_file,
    )

    loader_kwargs.update(batch_size=1 if dataset.apply_utt2category else batch_size)

    return DataLoader(
        dataset=dataset,
        pin_memory=ngpu > 0,
        num_workers=num_workers,
        **loader_kwargs,
    )
funasr/build_utils/build_sv_model.py
New file
@@ -0,0 +1,258 @@
import logging
import torch
from typeguard import check_return_type
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.base_model import FunASRModel
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.sv_decoder import DenseDecoder
from funasr.models.e2e_sv import ESPnetSVModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.pooling.statistic_pooling import StatisticPooling
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
# Registries mapping a configuration name to the class that implements it.
# build_sv_model() resolves each component through <registry>.get_class(name).

# Selectable frontend classes.
frontend_choices = ClassChoices(
    name="frontend",
    classes=dict(
        default=DefaultFrontend,
        sliding_window=SlidingWindow,
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
    ),
    type_check=AbsFrontend,
    default="default",
)
# Selectable spectrogram augmentation classes (declared optional, no default).
specaug_choices = ClassChoices(
    name="specaug",
    classes=dict(
        specaug=SpecAug,
    ),
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)
# Selectable normalization layers (declared optional, no default).
normalize_choices = ClassChoices(
    "normalize",
    classes=dict(
        global_mvn=GlobalMVN,
        utterance_mvn=UtteranceMVN,
    ),
    type_check=AbsNormalize,
    default=None,
    optional=True,
)
# Selectable top-level model classes.
model_choices = ClassChoices(
    "model",
    classes=dict(
        espnet=ESPnetSVModel,
    ),
    type_check=FunASRModel,
    default="espnet",
)
# Selectable pre-encoder classes (declared optional, no default).
preencoder_choices = ClassChoices(
    name="preencoder",
    classes=dict(
        sinc=LightweightSincConvs,
        linear=LinearProjection,
    ),
    type_check=AbsPreEncoder,
    default=None,
    optional=True,
)
# Selectable encoder classes.
encoder_choices = ClassChoices(
    "encoder",
    classes=dict(
        resnet34=ResNet34,
        resnet34_sp_l2reg=ResNet34_SP_L2Reg,
        rnn=RNNEncoder,
    ),
    type_check=AbsEncoder,
    default="resnet34",
)
# Selectable post-encoder classes (declared optional, no default).
postencoder_choices = ClassChoices(
    name="postencoder",
    classes=dict(
        hugging_face_transformers=HuggingFaceTransformersPostEncoder,
    ),
    type_check=AbsPostEncoder,
    default=None,
    optional=True,
)
# Selectable pooling layers; note the registry name is "pooling_type".
pooling_choices = ClassChoices(
    name="pooling_type",
    classes=dict(
        statistic=StatisticPooling,
    ),
    type_check=torch.nn.Module,
    default="statistic",
)
# Selectable decoder classes.
decoder_choices = ClassChoices(
    "decoder",
    classes=dict(
        dense=DenseDecoder,
    ),
    type_check=AbsDecoder,
    default="dense",
)
# All registries of this task, gathered in a single list.
class_choices_list = [
    # --frontend and --frontend_conf
    frontend_choices,
    # --specaug and --specaug_conf
    specaug_choices,
    # --normalize and --normalize_conf
    normalize_choices,
    # --model and --model_conf
    model_choices,
    # --preencoder and --preencoder_conf
    preencoder_choices,
    # --encoder and --encoder_conf
    encoder_choices,
    # --postencoder and --postencoder_conf
    postencoder_choices,
    # --pooling_type and --pooling_type_conf
    pooling_choices,
    # --decoder and --decoder_conf
    decoder_choices,
]
def build_sv_model(args):
    """Assemble the speaker-verification model described by ``args``.

    ``args`` is the parsed configuration namespace. ``args.token_list`` may be
    a path to a text file with one entry per line, or a list/tuple of entries;
    its length is logged as the speaker number and passed on as ``vocab_size``.
    The components are created in order: frontend, specaug, normalize,
    pre-encoder, encoder, post-encoder, pooling layer, decoder, and finally
    the model itself, which is initialised when ``args.init`` is set.

    Raises:
        RuntimeError: If ``args.token_list`` is neither a str nor a list/tuple.
    """
    # Token list: read it from a file or copy the given sequence.
    raw_tokens = args.token_list
    if isinstance(raw_tokens, str):
        with open(raw_tokens, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        # Overwriting token_list to keep it as "portable".
        args.token_list = list(token_list)
    elif isinstance(raw_tokens, (tuple, list)):
        token_list = list(raw_tokens)
    else:
        raise RuntimeError("token_list must be str or list")
    vocab_size = len(token_list)
    logging.info(f"Speaker number: {vocab_size}")

    # 1. Frontend
    if args.input_size is not None:
        # Features come from the data-loader, so no frontend is built.
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size
    else:
        # Features are extracted inside the model.
        frontend = frontend_choices.get_class(args.frontend)(**args.frontend_conf)
        input_size = frontend.output_size()

    # 2. Data augmentation for spectrogram
    specaug = None
    if args.specaug is not None:
        specaug = specaug_choices.get_class(args.specaug)(**args.specaug_conf)

    # 3. Normalization layer
    normalize = None
    if args.normalize is not None:
        normalize = normalize_choices.get_class(args.normalize)(**args.normalize_conf)

    # 4. Pre-encoder input block
    # NOTE(kan-bayashi): Use getattr to keep the compatibility
    preencoder = None
    if getattr(args, "preencoder", None) is not None:
        preencoder = preencoder_choices.get_class(args.preencoder)(**args.preencoder_conf)
        input_size = preencoder.output_size()

    # 5. Encoder
    encoder = encoder_choices.get_class(args.encoder)(
        input_size=input_size, **args.encoder_conf
    )

    # 6. Post-encoder block
    # NOTE(kan-bayashi): Use getattr to keep the compatibility
    encoder_output_size = encoder.output_size()
    postencoder = None
    if getattr(args, "postencoder", None) is not None:
        postencoder = postencoder_choices.get_class(args.postencoder)(
            input_size=encoder_output_size, **args.postencoder_conf
        )
        encoder_output_size = postencoder.output_size()

    # 7. Pooling layer: defaults, overridden by pooling_type_conf if present
    pooling_class = pooling_choices.get_class(args.pooling_type)
    pooling_kwargs = dict(pooling_dim=(2, 3), eps=1e-12)
    if hasattr(args, "pooling_type_conf"):
        for option in ("pooling_dim", "eps"):
            if option in args.pooling_type_conf:
                pooling_kwargs[option] = args.pooling_type_conf[option]
    pooling_layer = pooling_class(**pooling_kwargs)
    if args.pooling_type == "statistic":
        # The decoder input size is doubled for the "statistic" pooling type.
        encoder_output_size *= 2

    # 8. Decoder
    decoder = decoder_choices.get_class(args.decoder)(
        vocab_size=vocab_size,
        encoder_output_size=encoder_output_size,
        **args.decoder_conf,
    )

    # 9. Build model, falling back to "espnet" on AttributeError
    try:
        model_class = model_choices.get_class(args.model)
    except AttributeError:
        model_class = model_choices.get_class("espnet")
    model = model_class(
        vocab_size=vocab_size,
        token_list=token_list,
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        preencoder=preencoder,
        encoder=encoder,
        postencoder=postencoder,
        pooling_layer=pooling_layer,
        decoder=decoder,
        **args.model_conf,
    )

    # FIXME(kamo): Should be done in model?
    # 10. Initialize
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model
funasr/build_utils/build_vad_model.py
@@ -50,6 +50,10 @@
def build_vad_model(args):
    # frontend
    if not hasattr(args, "cmvn_file"):
        args.cmvn_file = None
    if not hasattr(args, "init"):
        args.init = None
    if args.input_size is None:
        frontend_class = frontend_choices.get_class(args.frontend)
        if args.frontend == 'wav_frontend':
funasr/models/e2e_asr_contextual_paraformer.py
@@ -43,9 +43,7 @@
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
@@ -72,6 +70,8 @@
        crit_attn_weight: float = 0.0,
        crit_attn_smooth: float = 0.0,
        bias_encoder_dropout_rate: float = 0.0,
        preencoder: Optional[AbsPreEncoder] = None,
        postencoder: Optional[AbsPostEncoder] = None,
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
funasr/models/e2e_asr_mfcca.py
@@ -53,7 +53,7 @@
            encoder: AbsEncoder,
            decoder: AbsDecoder,
            ctc: CTC,
            rnnt_decoder: None,
            rnnt_decoder: None = None,
            ctc_weight: float = 0.5,
            ignore_id: int = -1,
            lsm_weight: float = 0.0,
funasr/models/e2e_uni_asr.py
@@ -50,9 +50,7 @@
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
@@ -80,6 +78,8 @@
        loss_weight_model1: float = 0.5,
        enable_maas_finetune: bool = False,
        freeze_encoder2: bool = False,
        preencoder: Optional[AbsPreEncoder] = None,
        postencoder: Optional[AbsPostEncoder] = None,
        encoder1_encoder2_joint_training: bool = True,
    ):
        assert check_argument_types()
funasr/models/e2e_vad.py
@@ -5,6 +5,7 @@
from torch import nn
import math
from funasr.models.encoder.fsmn_encoder import FSMN
from funasr.models.base_model import FunASRModel
class VadStateMachine(Enum):
@@ -211,7 +212,7 @@
        return int(self.frame_size_ms)
class E2EVadModel(nn.Module):
class E2EVadModel(FunASRModel):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
tests/test_sv_inference_pipeline.py
@@ -35,4 +35,4 @@
        logger.info(f"Similarity {rec_result['scores']}")
if __name__ == '__main__':
    unittest.main()
    unittest.main()
tests/test_vad_inference_pipeline.py
@@ -37,7 +37,7 @@
        rec_result = inference_pipeline(
            audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav')
        logger.info("vad inference result: {0}".format(rec_result))
        assert rec_result["text"] == [[80, 2340], [2620, 6200], [6480, 23670], [23950, 26250], [26780, 28990],
        assert rec_result["text"] == [[70, 2340], [2620, 6200], [6480, 23670], [23950, 26250], [26780, 28990],
                                      [29950, 31430], [31750, 37600], [38210, 46900], [47310, 49630], [49910, 56460],
                                      [56740, 59540], [59820, 70450]]