jmwang66
2023-06-20 2ff405b2f4ab899eff9bece232969fbb0c8f0555
Merge pull request #653 from alibaba-damo-academy/dev_wjm_infer

Dev wjm infer
27个文件已修改
3个文件已添加
1360 ■■■■■ 已修改文件
.github/workflows/UnitTest.yml 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_infer.py 116 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_launch.py 155 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/diar_infer.py 45 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/diar_inference_launch.py 63 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/lm_inference_launch.py 43 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punc_infer.py 26 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punc_inference_launch.py 35 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/sv_infer.py 28 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/sv_inference_launch.py 40 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/tp_infer.py 39 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/tp_inference_launch.py 46 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/vad_infer.py 38 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/vad_inference_launch.py 59 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_args.py 19 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_asr_model.py 28 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_diar_model.py 20 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_lm_model.py 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_model.py 5 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_model_from_file.py 193 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_streaming_iterator.py 67 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_sv_model.py 258 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_vad_model.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_asr_contextual_paraformer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_asr_mfcca.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_uni_asr.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_vad.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
tests/test_sv_inference_pipeline.py 补丁 | 查看 | 原始文档 | blame | 历史
tests/test_vad_inference_pipeline.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
.github/workflows/UnitTest.yml
@@ -8,6 +8,7 @@
    branches:
      - dev_wjm
      - dev_jy
      - dev_wjm_infer
jobs:
  build:
@@ -18,6 +19,12 @@
        python-version: ["3.7"]
    steps:
      - name: Remove unnecessary files
        run:
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /opt/ghc
          sudo rm -rf "/usr/local/share/boost"
          sudo rm -rf "$AGENT_TOOLSDIRECTORY"
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
funasr/bin/asr_infer.py
@@ -1,66 +1,46 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import sys
import time
import codecs
import copy
import logging
import os
import re
import codecs
import tempfile
import requests
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
import numpy as np
import requests
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from  funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.modules.beam_search.beam_search import BeamSearch
# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTask
from funasr.tasks.lm import LMTask
from funasr.build_utils.build_asr_model import frontend_choices
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.bin.vad_infer import Speech2VadSegment
from funasr.bin.punc_infer import Text2Punc
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.tasks.asr import frontend_choices
class Speech2Text:
    """Speech2Text class
@@ -102,7 +82,7 @@
        
        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
@@ -110,7 +90,6 @@
            if asr_train_args.frontend == 'wav_frontend':
                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
            else:
                from funasr.tasks.asr import frontend_choices
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
        
@@ -130,7 +109,7 @@
        
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device
            )
            scorers["lm"] = lm.lm
@@ -272,6 +251,7 @@
        assert check_return_type(results)
        return results
class Speech2TextParaformer:
    """Speech2Text class
@@ -312,9 +292,8 @@
        # 1. Build ASR model
        scorers = {}
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -336,8 +315,8 @@
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device, task_name="lm"
            )
            scorers["lm"] = lm.lm
@@ -466,13 +445,16 @@
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model, NeatContextualParaformer):
        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
                                                                                   NeatContextualParaformer):
            if self.hotword_list:
                logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                     pre_token_length)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        else:
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                     pre_token_length, hw_list=self.hotword_list)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        if isinstance(self.asr_model, BiCifParaformer):
@@ -533,7 +515,6 @@
                                                            vad_offset=begin_time)
                results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
        # assert check_return_type(results)
        return results
@@ -591,6 +572,7 @@
            hotword_list = None
        return hotword_list
class Speech2TextParaformerOnline:
    """Speech2Text class
@@ -630,9 +612,8 @@
        # 1. Build ASR model
        scorers = {}
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -654,8 +635,8 @@
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device, task_name="lm"
            )
            scorers["lm"] = lm.lm
@@ -845,6 +826,7 @@
        # assert check_return_type(results)
        return results
class Speech2TextUniASR:
    """Speech2Text class
@@ -886,9 +868,8 @@
        # 1. Build ASR model
        scorers = {}
        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device, mode="uniasr"
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -914,8 +895,8 @@
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, device, "lm"
            )
            scorers["lm"] = lm.lm
@@ -1117,9 +1098,8 @@
        assert check_argument_types()
        
        # 1. Build ASR model
        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        
@@ -1139,8 +1119,8 @@
        
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device, task_name="lm"
            )
            lm.to(device)
            scorers["lm"] = lm.lm
@@ -1328,8 +1308,7 @@
        super().__init__()
        
        assert check_argument_types()
        from funasr.tasks.asr import ASRTransducerTask
        asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        
@@ -1363,8 +1342,8 @@
            asr_model.to(dtype=getattr(torch, dtype)).eval()
        
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device, task_name="lm"
            )
            lm_scorer = lm.lm
        else:
@@ -1651,9 +1630,8 @@
        assert check_argument_types()
        
        # 1. Build ASR model
        from funasr.tasks.asr import ASRTaskSAASR
        scorers = {}
        asr_model, asr_train_args = ASRTaskSAASR.build_model_from_file(
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
@@ -1682,8 +1660,8 @@
        
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, None, device
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device, task_name="lm"
            )
            scorers["lm"] = lm.lm
        
funasr/bin/asr_inference_launch.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -7,77 +7,45 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
#!/usr/bin/env python3
import argparse
import logging
import sys
import time
import copy
import os
import codecs
import tempfile
import requests
from pathlib import Path
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
import yaml
import numpy as np
import torch
import torchaudio
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import BeamSearch
# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr.bin.asr_infer import Speech2Text
from funasr.bin.asr_infer import Speech2TextMFCCA
from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
from funasr.bin.asr_infer import Speech2TextSAASR
from funasr.bin.asr_infer import Speech2TextTransducer
from funasr.bin.asr_infer import Speech2TextUniASR
from funasr.bin.punc_infer import Text2Punc
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.bin.vad_infer import Speech2VadSegment
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import asr_utils, postprocess_utils
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.bin.asr_infer import Speech2Text
from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
from funasr.bin.asr_infer import Speech2TextUniASR
from funasr.bin.asr_infer import Speech2TextMFCCA
from funasr.bin.vad_infer import Speech2VadSegment
from funasr.bin.punc_infer import Text2Punc
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.bin.asr_infer import Speech2TextTransducer
from funasr.bin.asr_infer import Speech2TextSAASR
def inference_asr(
    maxlenratio: float,
@@ -173,18 +141,16 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        
        finish_count = 0
@@ -360,17 +326,15 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        
        if param_dict is not None:
@@ -611,17 +575,15 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=1,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        
        if param_dict is not None:
@@ -657,6 +619,7 @@
            data_with_index = [(vadsegments[i], i) for i in range(n)]
            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
            results_sorted = []
            batch_size_token_ms = batch_size_token*60
            if speech2text.device == "cpu":
                batch_size_token_ms = 0
@@ -666,7 +629,8 @@
            beg_idx = 0
            for j, _ in enumerate(range(0, n)):
                batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                if j < n-1 and (batch_size_token_ms_cum + sorted_data[j+1][0][1] - sorted_data[j+1][0][0])<batch_size_token_ms:
                if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][
                    0]) < batch_size_token_ms:
                    continue
                batch_size_token_ms_cum = 0
                end_idx = j + 1
@@ -752,6 +716,7 @@
        return asr_result_list
    
    return _forward
def inference_paraformer_online(
        maxlenratio: float,
@@ -875,7 +840,8 @@
            feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
            cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
                        "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
                        "tail_chunk": False}
            cache["encoder"] = cache_en
            cache_de = {"decode_fsmn": None}
@@ -1059,17 +1025,15 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        
        finish_count = 0
@@ -1214,18 +1178,16 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            fs=fs,
            mc=True,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        
        finish_count = 0
@@ -1277,6 +1239,7 @@
        return asr_result_list
    
    return _forward
def inference_transducer(
    output_dir: str,
@@ -1400,20 +1363,14 @@
                 **kwargs,
                 ):
        # 3. Build data-iterator
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(
                speech2text.asr_train_args, False
            ),
            collate_fn=ASRTask.build_collate_fn(
                speech2text.asr_train_args, False
            ),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
    
        # 4 .Start for-loop
@@ -1464,7 +1421,6 @@
    
                    if text is not None:
                        ibest_writer["text"][key] = text
    return _forward
@@ -1561,18 +1517,16 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        
        finish_count = 0
@@ -1922,7 +1876,6 @@
    inference_pipeline = inference_launch(**kwargs)
    return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None))
if __name__ == "__main__":
funasr/bin/diar_infer.py
@@ -1,41 +1,28 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from collections import OrderedDict
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from collections import OrderedDict
import numpy as np
import soundfile
import torch
from scipy.ndimage import median_filter
from torch.nn import functional as F
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.diar import DiarTask
from funasr.tasks.diar import EENDOLADiarTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from scipy.ndimage import median_filter
from funasr.utils.misc import statistic_model_parameters
from funasr.datasets.iterable_dataset import load_bytes
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.tasks.diar import DiarTask
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.torch_utils.device_funcs import to_device
from funasr.utils.misc import statistic_model_parameters
class Speech2DiarizationEEND:
    """Speech2Diarlization class
@@ -61,10 +48,12 @@
        assert check_argument_types()
        # 1. Build Diarization model
        diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file(
        diar_model, diar_train_args = build_model_from_file(
            config_file=diar_train_config,
            model_file=diar_model_file,
            device=device
            device=device,
            task_name="diar",
            mode="eend-ola",
        )
        frontend = None
        if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
@@ -177,10 +166,12 @@
        assert check_argument_types()
        # TODO: 1. Build Diarization model
        diar_model, diar_train_args = DiarTask.build_model_from_file(
        diar_model, diar_train_args = build_model_from_file(
            config_file=diar_train_config,
            model_file=diar_model_file,
            device=device
            device=device,
            task_name="diar",
            mode="sond",
        )
        logging.info("diar_model: {}".format(diar_model))
        logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
@@ -344,7 +335,3 @@
            kwargs.update(**d.download_and_unpack(model_tag))
        return Speech2DiarizationSOND(**kwargs)
funasr/bin/diar_inference_launch.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -8,47 +8,28 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from collections import OrderedDict
import numpy as np
import soundfile
import torch
from torch.nn import functional as F
from typeguard import check_argument_types
from typeguard import check_return_type
from scipy.signal import medfilt
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.diar import DiarTask
from funasr.tasks.diar import EENDOLADiarTask
from funasr.torch_utils.device_funcs import to_device
from typeguard import check_argument_types
from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
from funasr.datasets.iterable_dataset import load_bytes
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from scipy.ndimage import median_filter
from funasr.utils.misc import statistic_model_parameters
from funasr.datasets.iterable_dataset import load_bytes
from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
def inference_sond(
        diar_train_config: str,
@@ -94,7 +75,8 @@
    set_all_random_seed(seed)
    # 2a. Build speech2xvec [Optional]
    if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
    if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict[
        "extract_profile"]:
        assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict."
        assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
        sv_train_config = param_dict["sv_train_config"]
@@ -186,16 +168,15 @@
                raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
        else:
            # 3. Build data-iterator
            loader = DiarTask.build_streaming_iterator(
                data_path_and_name_and_type,
            loader = build_streaming_iterator(
                task_name="diar",
                preprocess_args=None,
                data_path_and_name_and_type=data_path_and_name_and_type,
                dtype=dtype,
                batch_size=batch_size,
                key_file=key_file,
                num_workers=num_workers,
                preprocess_fn=None,
                collate_fn=None,
                allow_variable_data_keys=allow_variable_data_keys,
                inference=True,
                use_collate_fn=False,
            )
        # 7. Start for-loop
@@ -234,6 +215,7 @@
        return result_list
    return _forward
def inference_eend(
        diar_train_config: str,
@@ -306,16 +288,14 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
        loader = EENDOLADiarTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="diar",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False),
            collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        # 3. Start for-loop
@@ -362,8 +342,6 @@
    return _forward
def inference_launch(mode, **kwargs):
    if mode == "sond":
        return inference_sond(mode=mode, **kwargs)
@@ -386,6 +364,7 @@
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="Speaker Verification",
funasr/bin/lm_inference_launch.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -7,40 +7,25 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.types import float_or_none
import argparse
import logging
from pathlib import Path
import sys
import os
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import torch
from torch.nn.parallel import data_parallel
from typeguard import check_argument_types
from funasr.tasks.lm import LMTask
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.datasets.preprocessor import LMPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import float_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
@@ -78,8 +63,8 @@
    set_all_random_seed(seed)
    
    # 2. Build Model
    model, train_args = LMTask.build_model_from_file(
        train_config, model_file, device)
    model, train_args = build_model_from_file(
        train_config, model_file, None, device, "lm")
    wrapped_model = ForwardAdaptor(model, "nll")
    wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
    logging.info(f"Model:\n{model}")
@@ -193,16 +178,14 @@
            return results
        
        # 3. Build data-iterator
        loader = LMTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="lm",
            preprocess_args=train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=preprocessor,
            collate_fn=LMTask.build_collate_fn(train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        
        # 4. Start for-loop
@@ -302,6 +285,7 @@
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    
def get_parser():
    parser = config_argparse.ArgumentParser(
@@ -407,4 +391,3 @@
if __name__ == "__main__":
    main()
funasr/bin/punc_infer.py
@@ -1,33 +1,19 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Any
from typing import List
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.punctuation import PunctuationTask
from funasr.datasets.preprocessor import split_to_mini_sentence
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
@@ -40,7 +26,7 @@
        dtype: str = "float32",
    ):
        #  Build Model
        model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
        model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
        self.device = device
        # Wrape model to make model.nll() data-parallel
        self.wrapped_model = ForwardAdaptor(model, "inference")
@@ -153,7 +139,7 @@
        dtype: str = "float32",
    ):
        #  Build Model
        model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
        model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
        self.device = device
        # Wrape model to make model.nll() data-parallel
        self.wrapped_model = ForwardAdaptor(model, "inference")
@@ -267,5 +253,3 @@
            sentence_out = sentence_out[:-1]
            sentence_punc_list_out[-1] = "_"
        return sentence_out, sentence_punc_list_out, cache_out
funasr/bin/punc_inference_launch.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -7,41 +7,22 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.types import float_or_none
import argparse
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Any
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.punctuation import PunctuationTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.datasets.preprocessor import split_to_mini_sentence
from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
def inference_punc(
    batch_size: int,
@@ -121,6 +102,7 @@
    return _forward
def inference_punc_vad_realtime(
    batch_size: int,
    dtype: str,
@@ -177,7 +159,6 @@
    return _forward
def inference_launch(mode, **kwargs):
    if mode == "punc":
        return inference_punc(**kwargs)
@@ -186,6 +167,7 @@
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
    parser = config_argparse.ArgumentParser(
@@ -267,7 +249,6 @@
    kwargs.pop("njob", None)
    inference_pipeline = inference_launch(**kwargs)
    return inference_pipeline(kwargs["data_path_and_name_and_type"])
if __name__ == "__main__":
funasr/bin/sv_infer.py
@@ -1,35 +1,24 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import numpy as np
import torch
from kaldiio import WriteHelper
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.sv import SVTask
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.misc import statistic_model_parameters
class Speech2Xvector:
    """Speech2Xvector class
@@ -56,10 +45,13 @@
        assert check_argument_types()
        # TODO: 1. Build SV model
        sv_model, sv_train_args = SVTask.build_model_from_file(
        sv_model, sv_train_args = build_model_from_file(
            config_file=sv_train_config,
            model_file=sv_model_file,
            device=device
            cmvn_file=None,
            device=device,
            task_name="sv",
            mode="sv",
        )
        logging.info("sv_model: {}".format(sv_model))
        logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
@@ -157,7 +149,3 @@
            kwargs.update(**d.download_and_unpack(model_tag))
        return Speech2Xvector(**kwargs)
funasr/bin/sv_inference_launch.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -7,20 +7,6 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
@@ -30,18 +16,16 @@
import torch
from kaldiio import WriteHelper
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.sv import SVTask
from funasr.torch_utils.device_funcs import to_device
from funasr.bin.sv_infer import Speech2Xvector
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.misc import statistic_model_parameters
from funasr.bin.sv_infer import Speech2Xvector
def inference_sv(
    output_dir: Optional[str] = None,
@@ -114,16 +98,15 @@
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        
        # 3. Build data-iterator
        loader = SVTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="sv",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=None,
            collate_fn=None,
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
            use_collate_fn=False,
        )
        
        # 7 .Start for-loop
@@ -173,8 +156,6 @@
    return _forward
def inference_launch(mode, **kwargs):
    if mode == "sv":
        return inference_sv(**kwargs)
@@ -182,6 +163,7 @@
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="Speaker Verification",
funasr/bin/tp_infer.py
@@ -1,41 +1,19 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
from optparse import Option
import sys
import json
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.datasets.preprocessor import LMPreprocessor
from funasr.tasks.asr import ASRTaskAligner as ASRTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.text.token_id_converter import TokenIDConverter
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.torch_utils.device_funcs import to_device
class Speech2Timestamp:
@@ -50,8 +28,8 @@
    ):
        assert check_argument_types()
        # 1. Build ASR model
        tp_model, tp_train_args = ASRTask.build_model_from_file(
            timestamp_infer_config, timestamp_model_file, device=device
        tp_model, tp_train_args = build_model_from_file(
            timestamp_infer_config, timestamp_model_file, cmvn_file=None, device=device, task_name="asr", mode="tp"
        )
        if 'cuda' in device:
            tp_model = tp_model.cuda()  # force model to cuda
@@ -65,7 +43,6 @@
        tp_model.to(dtype=getattr(torch, dtype)).eval()
        logging.info(f"Decoding device={device}, dtype={dtype}")
        self.tp_model = tp_model
        self.tp_train_args = tp_train_args
@@ -113,8 +90,6 @@
            enc = enc[0]
        # c. Forward Predictor
        _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1)
        _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len,
                                                                           text_lengths.to(self.device) + 1)
        return us_alphas, us_peaks
funasr/bin/tp_inference_launch.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -8,46 +8,25 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
import argparse
import logging
from optparse import Option
import sys
import json
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.datasets.preprocessor import LMPreprocessor
from funasr.tasks.asr import ASRTaskAligner as ASRTask
from funasr.torch_utils.device_funcs import to_device
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.text.token_id_converter import TokenIDConverter
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.bin.tp_infer import Speech2Timestamp
def inference_tp(
    batch_size: int,
@@ -141,16 +120,15 @@
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speechtext2timestamp.tp_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=preprocessor,
            collate_fn=ASRTask.build_collate_fn(speechtext2timestamp.tp_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        
        tp_result_list = []
@@ -182,14 +160,13 @@
    return _forward
def inference_launch(mode, **kwargs):
    if mode == "tp_norm":
        return inference_tp(**kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
    parser = config_argparse.ArgumentParser(
@@ -306,7 +283,6 @@
    inference_pipeline = inference_launch(**kwargs)
    return inference_pipeline(kwargs["data_path_and_name_and_type"])
if __name__ == "__main__":
funasr/bin/train.py
@@ -1,4 +1,6 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
funasr/bin/vad_infer.py
@@ -1,42 +1,23 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
import json
import math
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import math
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.vad import VADTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.torch_utils.device_funcs import to_device
class Speech2VadSegment:
@@ -64,8 +45,8 @@
        assert check_argument_types()
        # 1. Build vad model
        vad_model, vad_infer_args = VADTask.build_model_from_file(
            vad_infer_config, vad_model_file, device
        vad_model, vad_infer_args = build_model_from_file(
            vad_infer_config, vad_model_file, None, device, task_name="vad"
        )
        frontend = None
        if vad_infer_args.frontend is not None:
@@ -135,6 +116,7 @@
                    segments[batch_num] += segments_part[batch_num]
        return fbanks, segments
class Speech2VadSegmentOnline(Speech2VadSegment):
    """Speech2VadSegmentOnline class
@@ -146,13 +128,13 @@
        [[10, 230], [245, 450], ...]
    """
    def __init__(self, **kwargs):
        super(Speech2VadSegmentOnline, self).__init__(**kwargs)
        vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
        self.frontend = None
        if self.vad_infer_args.frontend is not None:
            self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)
    @torch.no_grad()
    def __call__(
@@ -198,5 +180,3 @@
            # in_cache.update(batch['in_cache'])
            # in_cache = {key: value for key, value in batch['in_cache'].items()}
        return fbanks, segments, in_cache
funasr/bin/vad_inference_launch.py
@@ -1,58 +1,34 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
import argparse
import logging
import os
import sys
import json
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import math
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.vad import VADTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.bin.vad_infer import Speech2VadSegment, Speech2VadSegmentOnline
def inference_vad(
        batch_size: int,
@@ -74,7 +50,6 @@
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    logging.basicConfig(
        level=log_level,
@@ -112,16 +87,14 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = VADTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="vad",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
@@ -157,6 +130,7 @@
    return _forward
def inference_vad_online(
        batch_size: int,
        ngpu: int,
@@ -175,7 +149,6 @@
        **kwargs,
):
    assert check_argument_types()
    logging.basicConfig(
        level=log_level,
@@ -214,16 +187,14 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = VADTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="vad",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
@@ -273,8 +244,6 @@
    return _forward
def inference_launch(mode, **kwargs):
    if mode == "offline":
        return inference_vad(**kwargs)
@@ -283,6 +252,7 @@
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
    parser = config_argparse.ArgumentParser(
@@ -405,5 +375,6 @@
    inference_pipeline = inference_launch(**kwargs)
    return inference_pipeline(kwargs["data_path_and_name_and_type"])
if __name__ == "__main__":
    main()
funasr/build_utils/build_args.py
@@ -41,7 +41,7 @@
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of noise scp file.",
            help="The path of cmvn file.",
        )
    elif args.task_name == "pretrain":
@@ -75,12 +75,29 @@
            default=None,
            help="The number of input dimension of the feature",
        )
        task_parser.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The path of cmvn file.",
        )
    elif args.task_name == "diar":
        from funasr.build_utils.build_diar_model import class_choices_list
        for class_choices in class_choices_list:
            class_choices.add_arguments(task_parser)
    elif args.task_name == "sv":
        from funasr.build_utils.build_sv_model import class_choices_list
        for class_choices in class_choices_list:
            class_choices.add_arguments(task_parser)
        task_parser.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )
    else:
        raise NotImplementedError("Not supported task: {}".format(args.task_name))
funasr/build_utils/build_asr_model.py
@@ -20,15 +20,18 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
@@ -42,6 +45,7 @@
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
@@ -89,6 +93,7 @@
        paraformer_bert=ParaformerBert,
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        neatcontextual_paraformer=NeatContextualParaformer,
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
        rnnt=TransducerModel,
@@ -258,17 +263,22 @@
def build_asr_model(args):
    # token_list
    if args.token_list is not None:
        with open(args.token_list) as f:
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        args.token_list = list(token_list)
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
    else:
        token_list = None
        vocab_size = None
    # frontend
    if args.input_size is None:
    if hasattr(args, "input_size") and args.input_size is None:
        frontend_class = frontend_choices.get_class(args.frontend)
        if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
            frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
@@ -279,7 +289,7 @@
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size
        input_size = args.input_size if hasattr(args, "input_size") else None
    # data augmentation for spectrogram
    if args.specaug is not None:
@@ -291,6 +301,9 @@
    # normalization layer
    if args.normalize is not None:
        normalize_class = normalize_choices.get_class(args.normalize)
        if args.model == "mfcca":
            normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
        else:
        normalize = normalize_class(**args.normalize_conf)
    else:
        normalize = None
@@ -325,7 +338,8 @@
            token_list=token_list,
            **args.model_conf,
        )
    elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
    elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer",
                        "contextual_paraformer", "neatcontextual_paraformer"]:
        # predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)
funasr/build_utils/build_diar_model.py
@@ -178,14 +178,18 @@
def build_diar_model(args):
    # token_list
    if args.token_list is not None:
        with open(args.token_list) as f:
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        # Overwriting token_list to keep it as "portable".
        args.token_list = list(token_list)
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
    else:
        raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
    else:
        vocab_size = None
    # frontend
    if args.input_size is None:
@@ -205,7 +209,7 @@
    encoder_class = encoder_choices.get_class(args.encoder)
    encoder = encoder_class(input_size=input_size, **args.encoder_conf)
    if args.model_name == "sond":
    if args.model == "sond":
        # data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
@@ -243,11 +247,7 @@
        # decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder.output_size(),
            **args.decoder_conf,
        )
        decoder = decoder_class(**args.decoder_conf)
        # logger aggregator
        if getattr(args, "label_aggregator", None) is not None:
funasr/build_utils/build_lm_model.py
@@ -34,10 +34,14 @@
def build_lm_model(args):
    # token_list
    if args.token_list is not None:
        with open(args.token_list) as f:
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        args.token_list = list(token_list)
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
    else:
@@ -47,6 +51,7 @@
    lm_class = lm_choices.get_class(args.lm)
    lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
    args.model = args.model if hasattr(args, "model") else "lm"
    model_class = model_choices.get_class(args.model)
    model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf)
funasr/build_utils/build_model.py
@@ -1,9 +1,10 @@
from funasr.build_utils.build_asr_model import build_asr_model
from funasr.build_utils.build_diar_model import build_diar_model
from funasr.build_utils.build_lm_model import build_lm_model
from funasr.build_utils.build_pretrain_model import build_pretrain_model
from funasr.build_utils.build_punc_model import build_punc_model
from funasr.build_utils.build_sv_model import build_sv_model
from funasr.build_utils.build_vad_model import build_vad_model
from funasr.build_utils.build_diar_model import build_diar_model
def build_model(args):
@@ -19,6 +20,8 @@
        model = build_vad_model(args)
    elif args.task_name == "diar":
        model = build_diar_model(args)
    elif args.task_name == "sv":
        model = build_sv_model(args)
    else:
        raise NotImplementedError("Not supported task: {}".format(args.task_name))
funasr/build_utils/build_model_from_file.py
New file
@@ -0,0 +1,193 @@
import argparse
import logging
import os
from pathlib import Path
from typing import Union
import torch
import yaml
from typeguard import check_argument_types
from funasr.build_utils.build_model import build_model
from funasr.models.base_model import FunASRModel
def build_model_from_file(
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
        task_name: str = "asr",
        mode: str = "paraformer",
):
    """Build a model from a training config and, optionally, restore its weights.

    This method is used for inference or fine-tuning.

    Args:
        config_file: The yaml file saved when training. When omitted,
            ``config.yaml`` next to ``model_file`` is used.
        model_file: The model file saved when training. When omitted, the
            model is built from the config only and no weights are loaded.
        cmvn_file: When given, stored under the ``cmvn_file`` key of the
            loaded config, replacing any value read from the yaml file.
        device: Device type, "cpu", "cuda", or "cuda:N".
        task_name: Stored on the config namespace as ``task_name`` before
            it is handed to ``build_model``.
        mode: Forwarded to ``convert_tf2torch`` when a TensorFlow checkpoint
            has to be converted; together with ``task_name == "diar"`` the
            value "sond" also enables filtering of the loaded state dict.

    Returns:
        A ``(model, args)`` tuple: the built model and the
        ``argparse.Namespace`` created from the config.

    Raises:
        RuntimeError: If the built model is not a ``FunASRModel``.
    """
    assert check_argument_types()
    if config_file is None:
        assert model_file is not None, (
            "The argument 'model_file' must be provided "
            "if the argument 'config_file' is not specified."
        )
        config_file = Path(model_file).parent / "config.yaml"
    else:
        config_file = Path(config_file)

    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
    if cmvn_file is not None:
        args["cmvn_file"] = cmvn_file
    args = argparse.Namespace(**args)
    args.task_name = task_name
    model = build_model(args)
    if not isinstance(model, FunASRModel):
        raise RuntimeError(
            f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
        )
    model.to(device)

    if model_file is None:
        # No checkpoint was requested. Loading an empty state dict would fail
        # under strict key matching, so return the freshly built model as is.
        return model, args

    logging.info("model_file is {}".format(model_file))
    if device == "cuda":
        device = f"cuda:{torch.cuda.current_device()}"
    model_dir = os.path.dirname(model_file)
    model_name = os.path.basename(model_file)

    # File names containing "model.ckpt-" or ".bin" are routed through
    # convert_tf2torch; the converted weights are cached beside the
    # checkpoint in a ".pb" file.
    model_name_pth = None
    if "model.ckpt-" in model_name or ".bin" in model_name:
        if ".bin" in model_name:
            model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
        else:
            model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
        if os.path.exists(model_name_pth):
            logging.info("model_file is load from pth: {}".format(model_name_pth))
            model_dict = torch.load(model_name_pth, map_location=device)
        else:
            model_dict = convert_tf2torch(model, model_file, mode)
    else:
        model_dict = torch.load(model_file, map_location=device)

    # The state dict is loaded exactly once, after the task-specific handling
    # below, so that the sond filtering and the vad encoder-only routing are
    # not pre-empted by an earlier whole-model load.
    if task_name == "diar" and mode == "sond":
        model_dict = fileter_model_dict(model_dict, model.state_dict())
    if task_name == "vad":
        model.encoder.load_state_dict(model_dict)
    else:
        model.load_state_dict(model_dict)

    # Cache the converted weights so later calls can skip the conversion.
    if model_name_pth is not None and not os.path.exists(model_name_pth):
        torch.save(model_dict, model_name_pth)
        logging.info("model_file is saved to pth: {}".format(model_name_pth))

    return model, args
def convert_tf2torch(
        model,
        ckpt,
        mode,
):
    """Convert a TensorFlow checkpoint into a torch state-dict update.

    Each relevant sub-module of ``model`` is asked to translate its own
    variables through its ``convert_tf2torch(var_dict_tf, var_dict_torch)``
    method, and the partial results are merged into one dictionary.

    Args:
        model: Model whose sub-modules provide ``convert_tf2torch``.
        ckpt: Checkpoint location handed to ``load_tf_dict``.
        mode: One of "paraformer", "uniasr", "sond", "sv" or "tp"; selects
            which sub-modules are converted.

    Returns:
        A dict merging the converted entries of all selected sub-modules.

    Raises:
        AssertionError: If ``mode`` is not one of the supported values.
    """
    assert mode in ("paraformer", "uniasr", "sond", "sv", "tp")
    logging.info("start convert tf model to torch model")
    from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
    var_dict_tf = load_tf_dict(ckpt)
    var_dict_torch = model.state_dict()
    var_dict_torch_update = dict()

    def _merge(converter):
        # Run one sub-module converter and fold its result into the output.
        var_dict_torch_update.update(converter(var_dict_tf, var_dict_torch))

    if mode == "uniasr":
        for name in ("encoder", "predictor", "decoder",
                     "encoder2", "predictor2", "decoder2", "stride_conv"):
            _merge(getattr(model, name).convert_tf2torch)
    elif mode == "sond":
        # Compare the variable, not the literal "mode": the literal never
        # equals "sond", which left this branch unreachable.
        # Sub-modules set to None on the model are skipped.
        for name in ("encoder", "speaker_encoder", "cd_scorer", "ci_scorer", "decoder"):
            module = getattr(model, name)
            if module is not None:
                _merge(module.convert_tf2torch)
    elif mode == "sv":
        # Same literal-comparison fix as for "sond".
        for name in ("encoder", "pooling_layer", "decoder"):
            _merge(getattr(model, name).convert_tf2torch)
    else:
        # "paraformer" and "tp" share the same conversion steps.
        for name in ("encoder", "predictor", "decoder"):
            _merge(getattr(model, name).convert_tf2torch)
        # bias_encoder
        _merge(model.clas_convert_tf2torch)

    return var_dict_torch_update
def fileter_model_dict(src_dict: dict, dest_dict: dict):
    """Keep only the entries of ``src_dict`` whose keys exist in ``dest_dict``.

    Keys of ``src_dict`` that are absent from ``dest_dict`` are dropped and
    reported at info level; keys of ``dest_dict`` that were not supplied by
    ``src_dict`` are reported at warning level.

    Args:
        src_dict: Source mapping (e.g. a loaded checkpoint).
        dest_dict: Mapping whose keys define what is retained.

    Returns:
        An ``OrderedDict`` with the retained entries in ``src_dict`` order.
    """
    from collections import OrderedDict
    kept = OrderedDict()
    for name in src_dict:
        if name not in dest_dict:
            logging.info("{} is no longer needed in this model.".format(name))
            continue
        kept[name] = src_dict[name]
    for name in dest_dict:
        if name in kept:
            continue
        logging.warning("{} is missed in checkpoint.".format(name))
    return kept
funasr/build_utils/build_streaming_iterator.py
New file
@@ -0,0 +1,67 @@
import numpy as np
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from funasr.datasets.iterable_dataset import IterableESPnetDataset
from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
from funasr.datasets.small_datasets.preprocessor import build_preprocess
def build_streaming_iterator(
        task_name,
        preprocess_args,
        data_path_and_name_and_type,
        key_file: str = None,
        batch_size: int = 1,
        fs: dict = None,
        mc: bool = False,
        dtype: str = np.float32,
        num_workers: int = 1,
        use_collate_fn: bool = True,
        preprocess_fn=None,
        ngpu: int = 0,
        train: bool = False,
) -> DataLoader:
    """Build DataLoader using iterable dataset

    An explicitly supplied ``preprocess_fn`` is used as is. Otherwise, when
    ``preprocess_args`` is given, ``task_name`` is written onto it and a
    preprocessor is created with ``build_preprocess(preprocess_args, train)``.

    Unless ``use_collate_fn`` is false, a ``CommonCollateFn`` is installed:
    with ``int_pad_value=0`` for the "punc" and "lm" tasks, and with
    ``float_pad_value=0.0, int_pad_value=-1`` for every other task.

    The loader uses batch size 1 whenever the dataset reports
    ``apply_utt2category``; otherwise ``batch_size`` is used. Memory pinning
    is enabled when ``ngpu > 0``.
    """
    # NOTE(review): the default of `dtype` is np.float32 while the annotation
    # says str - confirm whether check_argument_types accepts the default.
    assert check_argument_types()

    # preprocess
    if preprocess_fn is None and preprocess_args is not None:
        preprocess_args.task_name = task_name
        preprocess_fn = build_preprocess(preprocess_args, train)

    # collate
    if not use_collate_fn:
        collate_fn = None
    elif task_name in ["punc", "lm"]:
        collate_fn = CommonCollateFn(int_pad_value=0)
    else:
        collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
    loader_options = {} if collate_fn is None else dict(collate_fn=collate_fn)

    dataset = IterableESPnetDataset(
        data_path_and_name_and_type,
        float_dtype=dtype,
        fs=fs,
        mc=mc,
        preprocess=preprocess_fn,
        key_file=key_file,
    )

    loader_options.update(batch_size=1 if dataset.apply_utt2category else batch_size)

    return DataLoader(
        dataset=dataset,
        pin_memory=ngpu > 0,
        num_workers=num_workers,
        **loader_options,
    )
funasr/build_utils/build_sv_model.py
New file
@@ -0,0 +1,258 @@
import logging
import torch
from typeguard import check_return_type
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.base_model import FunASRModel
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.sv_decoder import DenseDecoder
from funasr.models.e2e_sv import ESPnetSVModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.pooling.statistic_pooling import StatisticPooling
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
# Registries of selectable components for the speaker-verification model.
# build_sv_model looks a class up by name via `<registry>.get_class(...)` and
# instantiates it with the matching `*_conf` keyword arguments.
# NOTE(review): `type_check` presumably restricts each registry to subclasses
# of the given base class, and `optional=True` presumably allows selecting no
# component at all - confirm against ClassChoices.
# Feature-extraction frontends.
frontend_choices = ClassChoices(
    name="frontend",
    classes=dict(
        default=DefaultFrontend,
        sliding_window=SlidingWindow,
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
    ),
    type_check=AbsFrontend,
    default="default",
)
# Spectrogram data augmentation.
specaug_choices = ClassChoices(
    name="specaug",
    classes=dict(
        specaug=SpecAug,
    ),
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)
# Feature normalization layers.
normalize_choices = ClassChoices(
    "normalize",
    classes=dict(
        global_mvn=GlobalMVN,
        utterance_mvn=UtteranceMVN,
    ),
    type_check=AbsNormalize,
    default=None,
    optional=True,
)
# Top-level model classes.
model_choices = ClassChoices(
    "model",
    classes=dict(
        espnet=ESPnetSVModel,
    ),
    type_check=FunASRModel,
    default="espnet",
)
# Pre-encoder input blocks.
preencoder_choices = ClassChoices(
    name="preencoder",
    classes=dict(
        sinc=LightweightSincConvs,
        linear=LinearProjection,
    ),
    type_check=AbsPreEncoder,
    default=None,
    optional=True,
)
# Encoders.
encoder_choices = ClassChoices(
    "encoder",
    classes=dict(
        resnet34=ResNet34,
        resnet34_sp_l2reg=ResNet34_SP_L2Reg,
        rnn=RNNEncoder,
    ),
    type_check=AbsEncoder,
    default="resnet34",
)
# Post-encoder blocks.
postencoder_choices = ClassChoices(
    name="postencoder",
    classes=dict(
        hugging_face_transformers=HuggingFaceTransformersPostEncoder,
    ),
    type_check=AbsPostEncoder,
    default=None,
    optional=True,
)
# Pooling layers; note the registry name is "pooling_type", so the builder
# reads `args.pooling_type` and `args.pooling_type_conf`.
pooling_choices = ClassChoices(
    name="pooling_type",
    classes=dict(
        statistic=StatisticPooling,
    ),
    type_check=torch.nn.Module,
    default="statistic",
)
# Decoders.
decoder_choices = ClassChoices(
    "decoder",
    classes=dict(
        dense=DenseDecoder,
    ),
    type_check=AbsDecoder,
    default="dense",
)
# All component registries of the speaker-verification task, in one list.
# Each comment names the pair of options derived from the registry's name.
class_choices_list = [
    # --frontend and --frontend_conf
    frontend_choices,
    # --specaug and --specaug_conf
    specaug_choices,
    # --normalize and --normalize_conf
    normalize_choices,
    # --model and --model_conf
    model_choices,
    # --preencoder and --preencoder_conf
    preencoder_choices,
    # --encoder and --encoder_conf
    encoder_choices,
    # --postencoder and --postencoder_conf
    postencoder_choices,
    # --pooling_type and --pooling_type_conf
    pooling_choices,
    # --decoder and --decoder_conf
    decoder_choices,
]
def build_sv_model(args):
    """Assemble a speaker-verification model from a configuration namespace.

    Components are created in order (frontend, specaug, normalization,
    pre-encoder, encoder, post-encoder, pooling layer, decoder) and handed
    to the model class selected by ``args.model`` ("espnet" when ``args``
    has no such attribute).

    Args:
        args: Namespace holding the component names and their ``*_conf``
            keyword dictionaries. ``args.token_list`` is either a path to a
            text file with one entry per line or a list/tuple of entries;
            when it is a path it is replaced on ``args`` by the loaded list.

    Returns:
        The constructed model, initialized via ``initialize`` when
        ``args.init`` is not None.

    Raises:
        RuntimeError: If ``args.token_list`` is neither a str nor a
            list/tuple.
    """
    # 0. Token list; its length is logged as the speaker number.
    listed = args.token_list
    if isinstance(listed, str):
        with open(listed, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        # Replace the path by the loaded entries to keep args "portable".
        args.token_list = list(token_list)
    elif isinstance(listed, (tuple, list)):
        token_list = list(listed)
    else:
        raise RuntimeError("token_list must be str or list")
    vocab_size = len(token_list)
    logging.info(f"Speaker number: {vocab_size}")

    # 1. Frontend
    if args.input_size is None:
        # Features are extracted inside the model.
        frontend = frontend_choices.get_class(args.frontend)(**args.frontend_conf)
        input_size = frontend.output_size()
    else:
        # Features come from the data loader, so no frontend is built.
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size

    # 2. Data augmentation for spectrogram
    specaug = None
    if args.specaug is not None:
        specaug = specaug_choices.get_class(args.specaug)(**args.specaug_conf)

    # 3. Normalization layer
    normalize = None
    if args.normalize is not None:
        normalize = normalize_choices.get_class(args.normalize)(**args.normalize_conf)

    # 4. Pre-encoder input block (getattr keeps configs without the key working)
    preencoder = None
    if getattr(args, "preencoder", None) is not None:
        preencoder = preencoder_choices.get_class(args.preencoder)(**args.preencoder_conf)
        input_size = preencoder.output_size()

    # 5. Encoder
    encoder = encoder_choices.get_class(args.encoder)(input_size=input_size, **args.encoder_conf)

    # 6. Post-encoder block (getattr keeps configs without the key working)
    encoder_output_size = encoder.output_size()
    postencoder = None
    if getattr(args, "postencoder", None) is not None:
        postencoder = postencoder_choices.get_class(args.postencoder)(
            input_size=encoder_output_size, **args.postencoder_conf
        )
        encoder_output_size = postencoder.output_size()

    # 7. Pooling layer; only "pooling_dim" and "eps" are taken from the config.
    pooling_class = pooling_choices.get_class(args.pooling_type)
    pooling_options = {"pooling_dim": (2, 3), "eps": 1e-12}
    if hasattr(args, "pooling_type_conf"):
        for option in ("pooling_dim", "eps"):
            if option in args.pooling_type_conf:
                pooling_options[option] = args.pooling_type_conf[option]
    pooling_layer = pooling_class(**pooling_options)
    if args.pooling_type == "statistic":
        # The decoder input size is doubled for the "statistic" pooling type.
        encoder_output_size *= 2

    # 8. Decoder
    decoder = decoder_choices.get_class(args.decoder)(
        vocab_size=vocab_size,
        encoder_output_size=encoder_output_size,
        **args.decoder_conf,
    )

    # 9. Build model
    try:
        model_class = model_choices.get_class(args.model)
    except AttributeError:
        model_class = model_choices.get_class("espnet")
    components = dict(
        vocab_size=vocab_size,
        token_list=token_list,
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        preencoder=preencoder,
        encoder=encoder,
        postencoder=postencoder,
        pooling_layer=pooling_layer,
        decoder=decoder,
    )
    model = model_class(**components, **args.model_conf)

    # FIXME(kamo): Should be done in model?
    # 10. Initialize
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model
funasr/build_utils/build_vad_model.py
@@ -50,6 +50,10 @@
def build_vad_model(args):
    # frontend
    if not hasattr(args, "cmvn_file"):
        args.cmvn_file = None
    if not hasattr(args, "init"):
        args.init = None
    if args.input_size is None:
        frontend_class = frontend_choices.get_class(args.frontend)
        if args.frontend == 'wav_frontend':
funasr/models/e2e_asr_contextual_paraformer.py
@@ -43,9 +43,7 @@
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
@@ -72,6 +70,8 @@
        crit_attn_weight: float = 0.0,
        crit_attn_smooth: float = 0.0,
        bias_encoder_dropout_rate: float = 0.0,
        preencoder: Optional[AbsPreEncoder] = None,
        postencoder: Optional[AbsPostEncoder] = None,
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
funasr/models/e2e_asr_mfcca.py
@@ -53,7 +53,7 @@
            encoder: AbsEncoder,
            decoder: AbsDecoder,
            ctc: CTC,
            rnnt_decoder: None,
            rnnt_decoder: None = None,
            ctc_weight: float = 0.5,
            ignore_id: int = -1,
            lsm_weight: float = 0.0,
funasr/models/e2e_uni_asr.py
@@ -50,9 +50,7 @@
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
@@ -80,6 +78,8 @@
        loss_weight_model1: float = 0.5,
        enable_maas_finetune: bool = False,
        freeze_encoder2: bool = False,
        preencoder: Optional[AbsPreEncoder] = None,
        postencoder: Optional[AbsPostEncoder] = None,
        encoder1_encoder2_joint_training: bool = True,
    ):
        assert check_argument_types()
funasr/models/e2e_vad.py
@@ -5,6 +5,7 @@
from torch import nn
import math
from funasr.models.encoder.fsmn_encoder import FSMN
from funasr.models.base_model import FunASRModel
class VadStateMachine(Enum):
@@ -211,7 +212,7 @@
        return int(self.frame_size_ms)
class E2EVadModel(nn.Module):
class E2EVadModel(FunASRModel):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
tests/test_sv_inference_pipeline.py
tests/test_vad_inference_pipeline.py
@@ -37,7 +37,7 @@
        rec_result = inference_pipeline(
            audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav')
        logger.info("vad inference result: {0}".format(rec_result))
        assert rec_result["text"] == [[80, 2340], [2620, 6200], [6480, 23670], [23950, 26250], [26780, 28990],
        assert rec_result["text"] == [[70, 2340], [2620, 6200], [6480, 23670], [23950, 26250], [26780, 28990],
                                      [29950, 31430], [31750, 37600], [38210, 46900], [47310, 49630], [49910, 56460],
                                      [56740, 59540], [59820, 70450]]