dyyzhmm
2023-03-15 d836e403ada705b798f39442398ba74e7d2d2683
Merge pull request #2 from alibaba-damo-academy/main

merge from official
21个文件已修改
18个文件已添加
10236 ■■■■■ 已修改文件
docs/images/damo.png 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py 26 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/README.md 25 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.py 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer_online.py 33 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer_online.py 33 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer.py 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer_vad_punc.py 10 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/build_trainer.py 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/eend_ola_inference.py 413 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/tp_inference.py 59 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/vad_inference_online.py 345 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_asr_paraformer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_diar_eend_ola.py 242 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_tp.py 175 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_vad.py 40 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/eend_ola_feature.py 51 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/wav_frontend.py 282 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/pooling/statistic_pooling.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/modules/eend_ola/encoder.py 16 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/readme.md 11 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/third_party/install_openblas.sh 39 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/triton_gpu/client/aishell_test.txt 7176 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/triton_gpu/client/decode_manifest_triton_wo_cuts.py 561 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/encoder/config.pbtxt 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/config.pbtxt 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/asr.py 82 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/diar.py 329 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/timestamp_tools.py 56 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
tests/test_asr_inference_pipeline.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
tests/test_asr_vad_punc_inference_pipeline.py 32 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
tests/test_lm_pipeline.py 25 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
tests/test_punctuation_pipeline.py 43 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
tests/test_sv_inference_pipeline.py 47 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
tests/test_vad_inference_pipeline.py 35 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
docs/images/damo.png

egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py
@@ -5,7 +5,7 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
inference_pipline = pipeline(
inference_pipeline = pipeline(
    task=Tasks.punctuation,
    model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
    model_revision="v1.0.0",
@@ -17,7 +17,7 @@
cache_out = []
rec_result_all="outputs:"
for vad in vads:
    rec_result = inference_pipline(text_in=vad, cache=cache_out)
    rec_result = inference_pipeline(text_in=vad, cache=cache_out)
    #print(rec_result)
    cache_out = rec_result['cache']
    rec_result_all += rec_result['text']
egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py
New file
@@ -0,0 +1,26 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
# Initialize the inference pipeline.
# When raw audio is used as input, use the config file sond.yaml and set mode to "sond_demo".
inference_diar_pipeline = pipeline(
    mode="sond_demo",
    num_workers=0,
    task=Tasks.speaker_diarization,
    diar_model_config="sond.yaml",
    # This recipe directory targets the zh-cn AliMeeting 16 kHz SOND model, so load that
    # model and a matching 16 kHz speaker-verification model. The en-us CallHome 8 kHz
    # models were a copy-paste from the CallHome recipe.
    model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
    sv_model="damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch",
    sv_model_revision="master",
)
# audio_list is the input: the first file is the recording to diarize, and each
# following file is a voiceprint-enrollment recording for a speaker.
audio_list = [[
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/record.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_A.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_B.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_B1.wav"
]]
results = inference_diar_pipeline(audio_in=audio_list)
for rst in results:
    print(rst["value"])
egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/README.md
New file
@@ -0,0 +1,25 @@
# ModelScope Model
## How to finetune and infer using a pretrained ModelScope Model
### Inference
You can run inference directly with the pretrained model or with your own finetuned model.
- Setting parameters in `infer.py`
    - <strong>audio_in:</strong> # supports wav, URL, bytes, and parsed audio formats.
    - <strong>text_in:</strong> # supports a text string or a text URL.
    - <strong>output_dir:</strong> # must be set if the input format is wav.scp.
- Then you can run the pipeline to infer with:
```shell
python infer.py
```
Inference-related parameters can be modified in vad.yaml:
- max_end_silence_time: the trailing-silence duration used to decide that a sentence has ended; the range is 500 ms–6000 ms and the default is 800 ms.
- speech_noise_thres: the balance between speech and silence scores; the range is (-1, 1).
    - The closer the value is to -1, the more likely noise is to be judged as speech.
    - The closer the value is to 1, the more likely speech is to be judged as noise.
egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.py
New file
@@ -0,0 +1,12 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
# Timestamp prediction: run the model on audio_in together with the space-separated
# transcript in text_in; output_dir='./tmp' is passed through to the pipeline.
inference_pipeline = pipeline(
    task=Tasks.speech_timestamp,
    model='damo/speech_timestamp_prediction-v1-16k-offline',
    output_dir='./tmp')
rec_result = inference_pipeline(
    audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav',
    text_in='一 个 东 太 平 洋 国 家 为 什 么 跑 到 西 太 平 洋 来 了 呢')
print(rec_result)
egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.py
@@ -7,7 +7,7 @@
    inference_pipline = pipeline(
        task=Tasks.voice_activity_detection,
        model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
        model_revision=None,
        model_revision='v1.2.0',
        output_dir=output_dir,
        batch_size=1,
    )
egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer_online.py
New file
@@ -0,0 +1,33 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import soundfile
def iter_chunks(total_len, chunk_size):
    """Yield ``(start, end, is_final)`` sample ranges that cover ``[0, total_len)``.

    Every chunk is ``chunk_size`` samples long except the last, which absorbs the
    remainder. A chunk whose end reaches ``total_len - 1`` or beyond is the final
    one, so exactly one chunk is flagged final and nothing is yielded after it.
    An empty input yields nothing.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive, got {}".format(chunk_size))
    start = 0
    while start < total_len:
        end = start + chunk_size
        is_final = end >= total_len - 1
        if is_final:
            # Stretch the last chunk to the end of the signal.
            end = total_len
        yield start, end, is_final
        start = end


if __name__ == '__main__':
    output_dir = None
    inference_pipeline = pipeline(
        task=Tasks.voice_activity_detection,
        model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
        model_revision='v1.2.0',
        output_dir=output_dir,
        batch_size=1,
        mode='online',
    )
    speech, sample_rate = soundfile.read("./vad_example_16k.wav")
    # 160 samples per 10 ms assuming 16 kHz audio, i.e. 100 ms per chunk.
    step = 160 * 10
    # The same dict is passed on every call; 'in_cache' presumably carries the
    # streaming state from one chunk to the next.
    param_dict = {'in_cache': dict()}
    for sample_offset, chunk_end, is_final in iter_chunks(speech.shape[0], step):
        param_dict['is_final'] = is_final
        segments_result = inference_pipeline(audio_in=speech[sample_offset:chunk_end],
                                             param_dict=param_dict)
        print(segments_result)
egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py
@@ -7,8 +7,8 @@
    inference_pipline = pipeline(
        task=Tasks.voice_activity_detection,
        model="damo/speech_fsmn_vad_zh-cn-8k-common",
        model_revision=None,
        output_dir='./output_dir',
        model_revision='v1.2.0',
        output_dir=output_dir,
        batch_size=1,
    )
    segments_result = inference_pipline(audio_in=audio_in)
egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer_online.py
New file
@@ -0,0 +1,33 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import soundfile
def iter_chunks(total_len, chunk_size):
    """Yield ``(start, end, is_final)`` sample ranges that cover ``[0, total_len)``.

    Every chunk is ``chunk_size`` samples long except the last, which absorbs the
    remainder. A chunk whose end reaches ``total_len - 1`` or beyond is the final
    one, so exactly one chunk is flagged final and nothing is yielded after it.
    An empty input yields nothing.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive, got {}".format(chunk_size))
    start = 0
    while start < total_len:
        end = start + chunk_size
        is_final = end >= total_len - 1
        if is_final:
            # Stretch the last chunk to the end of the signal.
            end = total_len
        yield start, end, is_final
        start = end


if __name__ == '__main__':
    output_dir = None
    inference_pipeline = pipeline(
        task=Tasks.voice_activity_detection,
        model="damo/speech_fsmn_vad_zh-cn-8k-common",
        model_revision='v1.2.0',
        output_dir=output_dir,
        batch_size=1,
        mode='online',
    )
    speech, sample_rate = soundfile.read("./vad_example_8k.wav")
    # 80 samples per 10 ms assuming 8 kHz audio, i.e. 100 ms per chunk.
    step = 80 * 10
    # The same dict is passed on every call; 'in_cache' presumably carries the
    # streaming state from one chunk to the next.
    param_dict = {'in_cache': dict()}
    for sample_offset, chunk_end, is_final in iter_chunks(speech.shape[0], step):
        param_dict['is_final'] = is_final
        segments_result = inference_pipeline(audio_in=speech[sample_offset:chunk_end],
                                             param_dict=param_dict)
        print(segments_result)
funasr/bin/asr_inference_paraformer.py
@@ -42,7 +42,7 @@
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import time_stamp_lfr6_pl, time_stamp_sentence
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
class Speech2Text:
@@ -245,7 +245,7 @@
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        if isinstance(self.asr_model, BiCifParaformer):
            _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
            _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
                                                                                   pre_token_length)  # test no bias cif2
        results = []
@@ -291,7 +291,10 @@
                    text = None
                if isinstance(self.asr_model, BiCifParaformer):
                    timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
                    _, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
                                                            us_peaks[i],
                                                            copy.copy(token),
                                                            vad_offset=begin_time)
                    results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
                else:
                    results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -44,11 +44,10 @@
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
from funasr.bin.vad_inference import Speech2VadSegment
from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.bin.punctuation_infer import Text2Punc
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.utils.timestamp_tools import time_stamp_sentence
header_colors = '\033[95m'
end_colors = '\033[0m'
@@ -257,7 +256,7 @@
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        if isinstance(self.asr_model, BiCifParaformer):
            _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
            _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
                                                                                   pre_token_length)  # test no bias cif2
        results = []
@@ -303,7 +302,10 @@
                    text = None
                if isinstance(self.asr_model, BiCifParaformer):
                    timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
                    _, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
                                                            us_peaks[i],
                                                            copy.copy(token),
                                                            vad_offset=begin_time)
                    results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
                else:
                    results.append((text, token, token_int, enc_len_batch_total, lfr_factor))
funasr/bin/build_trainer.py
@@ -28,7 +28,9 @@
    elif mode == "uniasr":
        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
    elif mode == "mfcca":
        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
    elif mode == "tp":
        from funasr.tasks.asr import ASRTaskAligner as ASRTask
    else:
        raise ValueError("Unknown mode: {}".format(mode))
    parser = ASRTask.get_parser()
funasr/bin/eend_ola_inference.py
New file
@@ -0,0 +1,413 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.tasks.diar import EENDOLADiarTask
from funasr.torch_utils.device_funcs import to_device
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
class Speech2Diarization:
    """Speaker diarization with an EEND-OLA model.

    Loads the model from a training config and checkpoint via
    ``EENDOLADiarTask.build_model_from_file``. When the config defines a
    frontend, it also builds a ``WavFrontendMel23`` feature extractor.
    Calling the instance runs ``diar_model.estimate_sequential`` on one batch.

    Examples:
        >>> speech2diar = Speech2Diarization("config.yaml", "model.pb")
        >>> results = speech2diar(speech, speech_lengths)
    """
    def __init__(
            self,
            diar_train_config: Union[Path, str] = None,
            diar_model_file: Union[Path, str] = None,
            device: str = "cpu",
            dtype: str = "float32",
            streaming: bool = False,
    ):
        """Load the model and optional frontend, and seed the RNGs.

        Args:
            diar_train_config: Training configuration file of the diarization model.
            diar_model_file: Model parameter file.
            device: Device passed to model loading and used for the inputs.
            dtype: Name of a ``torch`` dtype (e.g. ``"float32"``) the model is cast to.
            streaming: Accepted for compatibility with callers that pass a
                ``streaming`` flag; without this parameter such calls raise
                ``TypeError``. It is stored on the instance but not used for inference.
        """
        assert check_argument_types()
        # 1. Build Diarization model
        diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file(
            config_file=diar_train_config,
            model_file=diar_model_file,
            device=device
        )
        frontend = None
        if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
            frontend = WavFrontendMel23(**diar_train_args.frontend_conf)
        # Seed numpy / torch / CUDA RNGs with the training seed (set up seed for EDA).
        np.random.seed(diar_train_args.seed)
        torch.manual_seed(diar_train_args.seed)
        torch.cuda.manual_seed(diar_train_args.seed)
        os.environ['PYTORCH_SEED'] = str(diar_train_args.seed)
        logging.info("diar_model: {}".format(diar_model))
        logging.info("diar_train_args: {}".format(diar_train_args))
        diar_model.to(dtype=getattr(torch, dtype)).eval()
        self.diar_model = diar_model
        self.diar_train_args = diar_train_args
        self.device = device
        self.dtype = dtype
        self.frontend = frontend
        self.streaming = streaming

    @torch.no_grad()
    def __call__(
            self,
            speech: Union[torch.Tensor, np.ndarray],
            speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ):
        """Run diarization on one batch.

        Args:
            speech: Input waveform batch, or precomputed features when no
                frontend is configured.
            speech_lengths: Lengths of the items in ``speech``.

        Returns:
            Whatever ``diar_model.estimate_sequential`` returns for the batch.
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            # Features were already extracted here, so the model's own frontend is
            # disabled (presumably to avoid extracting features twice; confirm).
            self.diar_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
        batch = {"speech": feats, "speech_lengths": feats_len}
        batch = to_device(batch, device=self.device)
        results = self.diar_model.estimate_sequential(**batch)
        return results

    @staticmethod
    def from_pretrained(
            model_tag: Optional[str] = None,
            **kwargs: Optional[Any],
    ):
        """Build Speech2Diarization instance from the pretrained model.

        Args:
            model_tag (Optional[str]): Model tag of the pretrained models.
                Currently, the tags of espnet_model_zoo are supported.
            **kwargs: Forwarded to ``Speech2Diarization``; when ``model_tag`` is
                given, they are updated with the downloaded model's files.

        Returns:
            Speech2Diarization: Speech2Diarization instance.
        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader
            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))
        return Speech2Diarization(**kwargs)
def inference_modelscope(
        diar_train_config: str,
        diar_model_file: str,
        output_dir: Optional[str] = None,
        batch_size: int = 1,
        dtype: str = "float32",
        ngpu: int = 0,
        num_workers: int = 0,
        log_level: Union[int, str] = "INFO",
        key_file: Optional[str] = None,
        model_tag: Optional[str] = None,
        allow_variable_data_keys: bool = True,
        streaming: bool = False,
        param_dict: Optional[dict] = None,
        **kwargs,
):
    """Build a reusable EEND-OLA diarization callable.

    The model is loaded once here. The returned ``_forward`` function runs
    diarization over either ``data_path_and_name_and_type`` or ``raw_inputs``.
    It returns a list of ``{"key": utt_id, "value": text}`` dicts, where
    ``value`` holds one ``SPEAKER ...`` line per segment. When an output
    directory is set, the same text is also written to ``<dir>/result.txt``.

    Args:
        streaming: Accepted for interface compatibility; not forwarded to the model.
        kwargs: Extra options (e.g. ``seed``, ``smooth_size``) are accepted and ignored.

    Raises:
        NotImplementedError: if ``batch_size > 1`` or ``ngpu > 1``.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("param_dict: {}".format(param_dict))
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Build speech2diar
    # `streaming` is deliberately not forwarded: Speech2Diarization does not use it,
    # and a constructor without a `streaming` parameter would raise TypeError.
    speech2diar_kwargs = dict(
        diar_train_config=diar_train_config,
        diar_model_file=diar_model_file,
        device=device,
        dtype=dtype,
    )
    logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
    speech2diar = Speech2Diarization.from_pretrained(
        model_tag=model_tag,
        **speech2diar_kwargs,
    )
    speech2diar.diar_model.eval()

    def output_results_str(results: dict, uttid: str):
        """Format per-speaker segments of one utterance as ``SPEAKER ...`` lines.

        ``results`` maps a speaker label to a sequence of segments whose first two
        items are the boundaries. The input dict is not modified.
        """
        # Recording id: the utterance id with its last "-<suffix>" removed.
        mid = uttid.rsplit("-", 1)[0]
        # NOTE(review): in RTTM the 5th field is a duration (tdur), but the segment
        # end time is written there; confirm what downstream consumers expect.
        template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
        rst = []
        for spk, segs in results.items():
            for seg in segs:
                # Dividing by 100 presumably converts 10 ms frame indices to seconds (TODO confirm).
                rst.append(template.format(mid, seg[0] / 100, seg[1] / 100, spk))
        return "\n".join(rst)

    def _forward(
            data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
            raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
            output_dir_v2: Optional[str] = None,
            param_dict: Optional[dict] = None,
    ):
        # 2. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = EENDOLADiarTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False),
            collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        # 3. Start for-loop
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        output_writer = None
        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)
            output_writer = open("{}/result.txt".format(output_path), "w", encoding="utf-8")
        result_list = []
        try:
            for keys, batch in loader:
                assert isinstance(batch, dict), type(batch)
                assert all(isinstance(s, str) for s in keys), keys
                _bs = len(next(iter(batch.values())))
                assert len(keys) == _bs, f"{len(keys)} != {_bs}"
                results = speech2diar(**batch)
                # Only batch_size == 1 is supported (checked above), so keys[0] is the utterance.
                key, value = keys[0], output_results_str(results, keys[0])
                result_list.append({"key": key, "value": value})
                if output_writer is not None and value:
                    # Terminate each utterance's lines so consecutive utterances
                    # do not merge onto one line in result.txt.
                    output_writer.write(value + "\n")
                    output_writer.flush()
        finally:
            # Close the file even if inference raises part-way through.
            if output_writer is not None:
                output_writer.close()
        return result_list

    return _forward
def inference(
        data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
        diar_train_config: Optional[str],
        diar_model_file: Optional[str],
        output_dir: Optional[str] = None,
        batch_size: int = 1,
        dtype: str = "float32",
        ngpu: int = 0,
        seed: int = 0,
        num_workers: int = 1,
        log_level: Union[int, str] = "INFO",
        key_file: Optional[str] = None,
        model_tag: Optional[str] = None,
        allow_variable_data_keys: bool = True,
        streaming: bool = False,
        smooth_size: int = 83,
        dur_threshold: int = 10,
        out_format: str = "vad",
        **kwargs,
):
    """Run EEND-OLA diarization over file-based inputs in a single call.

    Builds the pipeline with ``inference_modelscope``, forwarding every option
    including extra ``**kwargs``. The pipeline is then applied to
    ``data_path_and_name_and_type`` with no raw inputs.

    Returns:
        The list of ``{"key": ..., "value": ...}`` dicts produced by the pipeline.
    """
    pipeline_options = dict(
        diar_train_config=diar_train_config,
        diar_model_file=diar_model_file,
        output_dir=output_dir,
        batch_size=batch_size,
        dtype=dtype,
        ngpu=ngpu,
        seed=seed,
        num_workers=num_workers,
        log_level=log_level,
        key_file=key_file,
        model_tag=model_tag,
        allow_variable_data_keys=allow_variable_data_keys,
        streaming=streaming,
        smooth_size=smooth_size,
        dur_threshold=dur_threshold,
        out_format=out_format,
    )
    diarize = inference_modelscope(**pipeline_options, **kwargs)
    return diarize(data_path_and_name_and_type, raw_inputs=None)
def get_parser():
    """Build the command-line parser for EEND-OLA speaker diarization inference.

    Returns:
        A ``config_argparse.ArgumentParser`` with logging, device, input-data and
        model options. Argument names use ``_`` as separator.
    """
    parser = config_argparse.ArgumentParser(
        description="Speaker diarization inference with EEND-OLA",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--output_dir", type=str, required=False)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )
    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=False,
        action="append",
    )
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--diar_train_config",
        type=str,
        help="diarization training configuration",
    )
    group.add_argument(
        "--diar_model_file",
        type=str,
        help="diarization model parameter file",
    )
    group.add_argument(
        "--dur_threshold",
        type=int,
        default=10,
        help="The threshold for short segments in number frames"
    )
    # Registered on the model group (like --dur_threshold) so it is listed with the
    # other post-processing options in --help; the parsed namespace is unchanged.
    group.add_argument(
        "--smooth_size",
        type=int,
        default=83,
        help="The smoothing window length in number frames"
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If specify this option, *_train_config and "
             "*_file will be overwritten",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    parser.add_argument("--streaming", type=str2bool, default=False)
    return parser
def main(cmd=None):
    """Command-line entry point: parse args, pick a GPU, run inference, print results.

    When ``--output_dir`` ends with a numeric ``.<N>`` suffix (e.g. ``exp/out.3``),
    ``N`` is used as a 1-based job id to choose a GPU from ``--gpuid_list``
    round-robin. Otherwise, including when ``--output_dir`` is not given,
    job 1 is assumed and the first listed GPU id is used.
    """
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)
    logging.info("args: {}".format(kwargs))

    gpuid_list = args.gpuid_list.split(",")  # always at least one element
    jobid = 1
    if args.output_dir is not None:
        suffix = args.output_dir.split(".")[-1]
        # Previously int() raised ValueError for dirs without a numeric suffix.
        if suffix.isdigit():
            jobid = int(suffix)
    gpuid = gpuid_list[(jobid - 1) % len(gpuid_list)]
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = gpuid

    results_list = inference(**kwargs)
    for results in results_list:
        print("{} {}".format(results["key"], results["value"]))


if __name__ == "__main__":
    main()
funasr/bin/tp_inference.py
@@ -28,6 +28,8 @@
from funasr.utils.types import str_or_none
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.text.token_id_converter import TokenIDConverter
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
header_colors = '\033[95m'
end_colors = '\033[0m'
@@ -37,61 +39,6 @@
    'audio_fs': 16000,
    'model_fs': 16000
}
def time_stamp_lfr6_advance(us_alphas, us_cif_peak, char_list):
    """Convert upsampled CIF peaks into per-token start/end timestamps.

    Args:
        us_alphas: CIF weights. When ``us_cif_peak`` is 2-D only row 0 is taken;
            the resulting ``alphas`` is not used further in this function.
        us_cif_peak: CIF peak values (a torch tensor; ``torch.where`` is applied).
            Frames with a value above ``1.0 - 1e-4`` are treated as firing positions.
        char_list: Token strings; a trailing ``'</s>'`` is dropped. After that, the
            number of peaks must equal ``len(char_list) + 1`` (asserted).

    Returns:
        ``(res_str, res)``:
        - ``res_str`` is a ``"token start end;"`` string in seconds that
          includes ``<sil>`` entries.
        - ``res`` is a list of ``[start_ms, end_ms]`` integer pairs for
          non-``<sil>`` tokens only.
    """
    # Leading/trailing gap (in frames) beyond which an explicit <sil> entry is emitted.
    START_END_THRESHOLD = 5
    # Longest token span (in frames); longer gaps are split into token + <sil>.
    MAX_TOKEN_DURATION = 12
    # Seconds per upsampled frame: 10.0 * 6 / 1000 / 3 = 0.02 (presumably 10 ms shift x LFR 6, upsampled 3x).
    TIME_RATE = 10.0 * 6 / 1000 / 3  #  3 times upsampled
    if len(us_cif_peak.shape) == 2:
        alphas, cif_peak = us_alphas[0], us_cif_peak[0]  # support inference batch_size=1 only
    else:
        alphas, cif_peak = us_alphas, us_cif_peak
    num_frames = cif_peak.shape[0]
    if char_list[-1] == '</s>':
        char_list = char_list[:-1]
    # char_list = [i for i in text]
    timestamp_list = []
    new_char_list = []
    # for bicif model trained with large data, cif2 actually fires when a character starts
    # so treat the frames between two peaks as the duration of the former token
    fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 3.2  # total offset
    num_peak = len(fire_place)
    assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    # begin silence
    if fire_place[0] > START_END_THRESHOLD:
        # char_list.insert(0, '<sil>')
        timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
        new_char_list.append('<sil>')
    # tokens timestamp
    for i in range(len(fire_place)-1):
        new_char_list.append(char_list[i])
        if MAX_TOKEN_DURATION < 0 or fire_place[i+1] - fire_place[i] < MAX_TOKEN_DURATION:
            timestamp_list.append([fire_place[i]*TIME_RATE, fire_place[i+1]*TIME_RATE])
        else:
            # cut the duration to token and sil of the 0-weight frames last long
            _split = fire_place[i] + MAX_TOKEN_DURATION
            timestamp_list.append([fire_place[i]*TIME_RATE, _split*TIME_RATE])
            timestamp_list.append([_split*TIME_RATE, fire_place[i+1]*TIME_RATE])
            new_char_list.append('<sil>')
    # tail token and end silence
    # new_char_list.append(char_list[-1])
    if num_frames - fire_place[-1] > START_END_THRESHOLD:
        # Long tail: end the last entry at the midpoint between the last peak and
        # the final frame, and label the remainder <sil>.
        _end = (num_frames + fire_place[-1]) * 0.5
        # _end = fire_place[-1]
        timestamp_list[-1][1] = _end*TIME_RATE
        timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
        new_char_list.append("<sil>")
    else:
        # Short tail: extend the last entry to the final frame.
        timestamp_list[-1][1] = num_frames*TIME_RATE
    assert len(new_char_list) == len(timestamp_list)
    res_str = ""
    for char, timestamp in zip(new_char_list, timestamp_list):
        # Add 0.0005 then keep the first 5 characters: three decimals below 10 s,
        # fewer for larger values.
        res_str += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
    res = []
    for char, timestamp in zip(new_char_list, timestamp_list):
        if char != '<sil>':
            # Seconds -> integer milliseconds (truncated).
            res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
    return res_str, res
class SpeechText2Timestamp:
@@ -315,7 +262,7 @@
            for batch_id in range(_bs):
                key = keys[batch_id]
                token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id])
                ts_str, ts_list = time_stamp_lfr6_advance(us_alphas[batch_id], us_cif_peak[batch_id], token)
                ts_str, ts_list = ts_prediction_lfr6_standard(us_alphas[batch_id], us_cif_peak[batch_id], token, force_time_shift=-3.0)
                logging.warning(ts_str)
                item = {'key': key, 'value': ts_str, 'timestamp':ts_list}
                tp_result_list.append(item)
funasr/bin/vad_inference_online.py
New file
@@ -0,0 +1,345 @@
import argparse
import logging
import sys
import json
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.tasks.vad import VADTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.models.frontend.wav_frontend import WavFrontendOnline
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.bin.vad_inference import Speech2VadSegment
# ANSI terminal escape sequences (bright magenta / reset).
header_colors = '\033[95m'
end_colors = '\033[0m'

# Module-wide defaults for language and sampling rates.
global_asr_language: str = 'zh-cn'
global_sample_rate: Union[int, Dict[Any, int]] = dict(
    audio_fs=16000,
    model_fs=16000,
)
class Speech2VadSegmentOnline(Speech2VadSegment):
    """Streaming (chunk-by-chunk) VAD segmentation built on Speech2VadSegment.

    The parent's frontend is replaced by ``WavFrontendOnline``, and the model is
    driven through ``vad_model.forward_online``. Detector state is carried
    between chunks through the ``in_cache`` dict that each call returns.

    Examples:
        >>> speech2segment = Speech2VadSegmentOnline(
        ...     vad_infer_config="vad_config.yml", vad_model_file="vad.pt")
        >>> fbanks, segments, in_cache = speech2segment(
        ...     speech, speech_lengths, in_cache=in_cache, is_final=False)
    """

    def __init__(self, **kwargs):
        super(Speech2VadSegmentOnline, self).__init__(**kwargs)
        vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
        self.frontend = None
        if self.vad_infer_args.frontend is not None:
            self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)

    @torch.no_grad()
    def __call__(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
            in_cache: Optional[Dict[str, torch.Tensor]] = None, is_final: bool = False
    ) -> Tuple[torch.Tensor, List, Dict[str, torch.Tensor]]:
        """Run VAD on one chunk of audio.

        Args:
            speech: input samples for this chunk; ``speech.shape[0]`` is the batch size.
            speech_lengths: valid lengths, forwarded to the online frontend.
            in_cache: cache returned by the previous call for the same stream.
                A fresh dict is used when omitted. The previous shared mutable
                default could leak state between unrelated streams.
            is_final: ``True`` for the last chunk of a stream.

        Returns:
            ``(fbanks, segments, in_cache)``:
                fbanks: from ``frontend.get_fbank()``.
                segments: from ``vad_model.forward_online``, or one empty list
                    per batch item when the frontend produced no frames for
                    this chunk.
                in_cache: the updated cache to pass to the next call.

        Raises:
            Exception: if no frontend is configured.
        """
        assert check_argument_types()
        if in_cache is None:
            in_cache = dict()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        batch_size = speech.shape[0]
        # Independent lists (not ``[[]] * n``, which aliases one list n times).
        segments = [[] for _ in range(batch_size)]
        if self.frontend is None:
            raise Exception("Need to extract feats first, please configure frontend configuration")
        feats, _ = self.frontend.forward(speech, speech_lengths, is_final)
        fbanks, _ = self.frontend.get_fbank()
        # The online frontend may not emit any frame for a short chunk.
        if feats.shape[0]:
            batch = {
                "feats": feats,
                "waveform": self.frontend.get_waveforms(),
                "in_cache": in_cache,
                "is_final": is_final
            }
            # a. To device (moves feats, waveform and the cache together)
            batch = to_device(batch, device=self.device)
            segments, in_cache = self.vad_model.forward_online(**batch)
        return fbanks, segments, in_cache
def inference(
        batch_size: int,
        ngpu: int,
        log_level: Union[int, str],
        data_path_and_name_and_type,
        vad_infer_config: Optional[str],
        vad_model_file: Optional[str],
        vad_cmvn_file: Optional[str] = None,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        key_file: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        num_workers: int = 1,
        **kwargs,
):
    """Build the streaming VAD pipeline and run it once.

    Every argument except ``data_path_and_name_and_type`` and ``raw_inputs`` is
    used to build the pipeline via ``inference_modelscope``. Those two are
    handed to the returned callable, whose result is returned unchanged.
    """
    build_kwargs = dict(
        batch_size=batch_size,
        ngpu=ngpu,
        log_level=log_level,
        vad_infer_config=vad_infer_config,
        vad_model_file=vad_model_file,
        vad_cmvn_file=vad_cmvn_file,
        key_file=key_file,
        allow_variable_data_keys=allow_variable_data_keys,
        output_dir=output_dir,
        dtype=dtype,
        seed=seed,
        num_workers=num_workers,
    )
    run_pipeline = inference_modelscope(**build_kwargs, **kwargs)
    return run_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
        batch_size: int,
        ngpu: int,
        log_level: Union[int, str],
        vad_infer_config: Optional[str],
        vad_model_file: Optional[str],
        vad_cmvn_file: Optional[str] = None,
        key_file: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        num_workers: int = 1,
        **kwargs,
):
    """Build a streaming VAD model and return a per-chunk inference callable.

    Raises:
        NotImplementedError: if ``batch_size > 1`` or ``ngpu > 1``.

    Returns:
        ``_forward(data_path_and_name_and_type, raw_inputs=None,
        output_dir_v2=None, fs=None, param_dict=None)``.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    device = "cuda" if ngpu >= 1 and torch.cuda.is_available() else "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2vadsegment
    speech2vadsegment_kwargs = dict(
        vad_infer_config=vad_infer_config,
        vad_model_file=vad_model_file,
        vad_cmvn_file=vad_cmvn_file,
        device=device,
        dtype=dtype,
    )
    logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
    speech2vadsegment = Speech2VadSegmentOnline(**speech2vadsegment_kwargs)

    def _forward(
            data_path_and_name_and_type,
            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
            output_dir_v2: Optional[str] = None,
            fs: dict = None,
            param_dict: dict = None,
    ):
        """Run one chunk through the streaming VAD model.

        Args:
            data_path_and_name_and_type: loader spec. Built from ``raw_inputs``
                when None and ``raw_inputs`` is given.
            raw_inputs: in-memory waveform (numpy array or torch tensor).
            output_dir_v2: overrides ``output_dir`` for the text writer.
            fs: accepted for interface compatibility; not used here.
            param_dict: streaming state. ``in_cache`` (default ``{}``) and
                ``is_final`` (default ``False``) are read from it. The cache
                returned by the model is written back to
                ``param_dict['in_cache']``. When ``param_dict`` is None, a fresh
                cache is used and nothing is written back.

        Returns:
            A list of ``{'key': ..., 'value': <JSON-encoded segments>}`` dicts,
            one per utterance that produced segments in this call.
        """
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = VADTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        # FIXME(kamo): The output format should be discussed about
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        writer = None
        ibest_writer = None
        if output_path is not None:
            writer = DatadirWriter(output_path)
            ibest_writer = writer["1best_recog"]

        # 4. Streaming state carried between calls.
        if param_dict is None:
            batch_in_cache, is_final = dict(), False
        else:
            batch_in_cache = param_dict.get('in_cache', dict())
            is_final = param_dict.get('is_final', False)

        # 5. Start for-loop
        vad_results = []
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            batch['in_cache'] = batch_in_cache
            batch['is_final'] = is_final

            # do vad segment
            _, results, out_cache = speech2vadsegment(**batch)
            # Previously this assignment crashed with TypeError when param_dict was None.
            if param_dict is not None:
                param_dict['in_cache'] = out_cache
            if not results:
                continue
            for key, segment in zip(keys, results):
                if not segment:
                    continue
                vad_results.append({'key': key, 'value': json.dumps(segment)})
                if ibest_writer is not None:
                    ibest_writer["text"][key] = "{}".format(segment)
        return vad_results

    return _forward
def get_parser():
    """Build the command-line parser for streaming VAD decoding.

    Option names use ``_`` rather than ``-`` as the word separator, because
    ``-`` is confusing when the same keys are written in YAML.
    """
    parser = config_argparse.ArgumentParser(
        description="VAD Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # General runtime options.
    parser.add_argument(
        "--log_level",
        type=lambda level: level.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--output_dir", type=str, required=False)
    parser.add_argument("--ngpu", type=int, default=0, help="The number of gpus. 0 indicates CPU mode")
    parser.add_argument("--gpuid_list", type=str, default="", help="The visible gpus")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--dtype", default="float32", choices=["float16", "float32", "float64"], help="Data type")
    parser.add_argument("--num_workers", type=int, default=1, help="The number of workers used for DataLoader")

    data_group = parser.add_argument_group("Input data related")
    data_group.add_argument(
        "--data_path_and_name_and_type", type=str2triple_str, required=False, action="append"
    )
    # Example value: [{'key': <utterance id>, 'file': <path to wav>}].
    # NOTE(review): type=list turns a CLI string into a list of characters, so this
    # option is presumably only meaningful when the parser is driven programmatically.
    data_group.add_argument("--raw_inputs", type=list, default=None)
    data_group.add_argument("--key_file", type=str_or_none)
    data_group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

    model_group = parser.add_argument_group("The model configuration related")
    model_group.add_argument("--vad_infer_config", type=str, help="VAD infer configuration")
    model_group.add_argument("--vad_model_file", type=str, help="VAD model parameter file")
    model_group.add_argument("--vad_cmvn_file", type=str, help="Global cmvn file")

    infer_group = parser.add_argument_group("infer related")
    infer_group.add_argument("--batch_size", type=int, default=1, help="The batch size for inference")
    return parser
def main(cmd=None):
    """Command-line entry point: parse ``cmd`` (``sys.argv`` when None) and run ``inference``."""
    print(get_commandline_args(), file=sys.stderr)
    cli_args = vars(get_parser().parse_args(cmd))
    # Drop a "config" entry, if any, before forwarding the options.
    cli_args.pop("config", None)
    inference(**cli_args)


if __name__ == "__main__":
    main()
funasr/models/e2e_asr_paraformer.py
@@ -926,10 +926,10 @@
    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out,
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
                                                                                               encoder_out_mask,
                                                                                               token_num)
        return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
    def forward(
            self,
funasr/models/e2e_diar_eend_ola.py
New file
@@ -0,0 +1,242 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import Tuple
import numpy as np
import torch
import torch.nn as  nn
from typeguard import check_argument_types
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
# Provide `autocast` on every torch version, mirroring the shim used by the
# other models. The previous `pass` branch left `autocast` undefined on torch>=1.6.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0: a no-op context manager with the same call shape.
    @contextmanager
    def autocast(enabled=True):
        yield
def pad_attractor(att, max_n_speakers):
    """Append float32 zero rows to a 2-D ``att`` until it has ``max_n_speakers`` rows.

    ``att`` is returned unchanged (same object) when it already has at least
    ``max_n_speakers`` rows. It is never truncated here.
    """
    n_rows, dim = att.shape
    if n_rows >= max_n_speakers:
        return att
    filler = torch.zeros(max_n_speakers - n_rows, dim).to(torch.float32).to(att.device)
    return torch.cat([att, filler], dim=0)
class DiarEENDOLAModel(AbsESPnetModel):
    """EEND-OLA diarization model.

    Inference (``estimate_sequential``) runs:
    optional frontend -> transformer encoder -> encoder-decoder attractors
    -> frame/attractor scores -> LSTM post-net -> power-set label per frame
    -> per-speaker binary decisions.
    """

    def __init__(
            self,
            frontend: WavFrontendMel23,
            encoder: EENDOLATransformerEncoder,
            encoder_decoder_attractor: EncoderDecoderAttractor,
            n_units: int = 256,
            max_n_speaker: int = 8,
            attractor_loss_weight: float = 1.0,
            mapping_dict=None,
            **kwargs,
    ):
        """
        Args:
            frontend: feature extractor applied in ``estimate_sequential`` when not None.
            encoder: EEND-OLA transformer encoder.
            encoder_decoder_attractor: module providing ``estimate(emb)``.
            n_units: hidden size of the post-net LSTM.
            max_n_speaker: number of speaker slots. Attractors are padded or
                truncated to this many.
            attractor_loss_weight: stored on the instance; not used in this class.
            mapping_dict: power-set label mapping that must contain an ``'oov'``
                index. Generated with ``generate_mapping_dict`` when omitted.
        """
        assert check_argument_types()
        super().__init__()
        self.frontend = frontend
        self.encoder = encoder
        self.encoder_decoder_attractor = encoder_decoder_attractor
        self.attractor_loss_weight = attractor_loss_weight
        self.max_n_speaker = max_n_speaker
        if mapping_dict is None:
            mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
        # Always store the mapping. recover_y_from_powerlabel reads it, and a
        # caller-supplied mapping used to be dropped (AttributeError later).
        self.mapping_dict = mapping_dict
        # PostNet
        self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
        self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)

    def forward_encoder(self, xs, ilens):
        """Encode a list of variable-length inputs.

        The inputs are padded with -1 and masked by ``ilens``. The encoder
        output is reshaped to ``(batch, max_len, -1)``, and each item is
        returned trimmed back to its own length.
        """
        xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
        pad_shape = xs.shape
        xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
        xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
        emb = self.encoder(xs, xs_mask)
        emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
        emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
        return emb

    def forward_post_net(self, logits, ilens):
        """Run the LSTM post-net and output layer on each item (trimmed to ``ilens``)."""
        maxlen = torch.max(ilens).to(torch.int).item()
        logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
        # pack_padded_sequence requires a CPU lengths tensor, even when the data is on GPU.
        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu(), batch_first=True, enforce_sorted=False)
        outputs, (_, _) = self.PostNet(logits)
        outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
        outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
        outputs = [self.output_layer(output) for output in outputs]
        return outputs

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        # NOTE(review): this training forward refers to self.encode, self.ctc_weight,
        # self.interctc_weight, self._calc_ctc_loss and self._calc_att_loss, none of
        # which are defined in this class. It looks carried over from an ASR model;
        # confirm before training with it.
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
                speech.shape[0]
                == speech_lengths.shape[0]
                == text.shape[0]
                == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        # for data-parallel
        text = text[:, : text_lengths.max()]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()
        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc
        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic
                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
            loss_interctc = loss_interctc / len(intermediate_outs)
            # calculate whole encoder loss
            loss_ctc = (
                               1 - self.interctc_weight
                       ) * loss_ctc + self.interctc_weight * loss_interctc
        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def estimate_sequential(self,
                            speech: torch.Tensor,
                            speech_lengths: torch.Tensor,
                            n_speakers: int = None,
                            shuffle: bool = True,
                            threshold: float = 0.5,
                            **kwargs):
        """Estimate per-frame speaker activities.

        Args:
            speech: batch of inputs, passed through ``self.frontend`` first when set.
            speech_lengths: valid length of each item.
            n_speakers: if truthy, keep exactly this many attractors per item.
            shuffle: shuffle the frame order before attractor estimation.
            threshold: otherwise, keep attractors up to (not including) the first
                position whose probability is below this value, or all of them
                if none is.

        Returns:
            ``(ys, emb, attractors, raw_n_speakers)``:
                ys: per-item decision matrices (frames x kept speakers).
                emb: encoder embeddings.
                attractors: attractors padded/truncated to ``max_n_speaker``.
                raw_n_speakers: number of attractors kept before padding.

        Raises:
            NotImplementedError: if neither ``n_speakers`` nor ``threshold`` is given.
        """
        if self.frontend is not None:
            speech = self.frontend(speech)
        speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
        emb = self.forward_encoder(speech, speech_lengths)
        if shuffle:
            orders = [np.arange(e.shape[0]) for e in emb]
            for order in orders:
                np.random.shuffle(order)
            attractors, probs = self.encoder_decoder_attractor.estimate(
                [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
        else:
            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
        attractors_active = []
        for p, att in zip(probs, attractors):
            if n_speakers and n_speakers >= 0:
                att = att[:n_speakers, ]
                attractors_active.append(att)
            elif threshold is not None:
                # Indices (along dim 0) where the probability drops below the threshold.
                # The old `torch.nonzero(...)[0]` raised IndexError when there were none.
                below = torch.nonzero(p < threshold)[:, 0]
                n_spk = int(below[0]) if below.numel() > 0 else None
                att = att[:n_spk, ]
                attractors_active.append(att)
            else:
                # Previously constructed but never raised, which silently dropped the item.
                raise NotImplementedError('n_speakers or threshold has to be given.')
        raw_n_speakers = [att.shape[0] for att in attractors_active]
        attractors = [
            pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
            for att in attractors_active]
        ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
        logits = self.forward_post_net(ys, speech_lengths)
        ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in
              zip(logits, raw_n_speakers)]
        return ys, emb, attractors, raw_n_speakers

    def recover_y_from_powerlabel(self, logit, n_speaker):
        """Turn post-net logits into a (frames x n_speaker) float32 0/1 matrix.

        The argmax label of each frame is taken. Frames predicted as ``'oov'``
        copy the previous frame's label (label 0 for the first frame). Each
        label is mapped to an integer whose binary digits, least significant
        first, give per-speaker activity.
        """
        pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
        oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
        for i in oov_index:
            if i > 0:
                pred[i] = pred[i - 1]
            else:
                pred[i] = 0
        # NOTE(review): `self.reporter` is not assigned anywhere in this class, so this
        # line raises AttributeError unless it is attached externally; confirm where
        # `inv_mapping_func` is meant to come from.
        pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
        decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
        decisions = torch.from_numpy(
            np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
            torch.float32)
        decisions = decisions[:, :n_speaker]
        return decisions
funasr/models/e2e_tp.py
New file
@@ -0,0 +1,175 @@
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
import numpy as np
from typeguard import check_argument_types
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.predictor.cif import mae_loss
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.predictor.cif import CifPredictorV3
if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
    # torch<1.6.0 has no automatic mixed precision: use a no-op stand-in
    # that accepts the same `enabled` argument.
    @contextmanager
    def autocast(enabled=True):
        yield
else:
    from torch.cuda.amp import autocast
class TimestampPredictor(AbsESPnetModel):
    """Timestamp prediction model: optional frontend + encoder + CIF predictor.

    Author: Speech Lab, Alibaba Group, China

    Training optimises only the predictor's token-count loss (``mae_loss``
    between the reference token count and the predicted one).
    ``calc_predictor_timestamp`` exposes the predictor's down-/up-sampled
    alphas and peaks for inference.
    """

    def __init__(
            self,
            frontend: Optional[AbsFrontend],
            encoder: AbsEncoder,
            predictor: CifPredictorV3,
            predictor_bias: int = 0,
            token_list=None,
            extract_feats_in_collect_stats: bool = True,
    ):
        """
        Args:
            frontend: feature extractor, or None if inputs are already features.
            encoder: acoustic encoder.
            predictor: CIF predictor.
            predictor_bias: when 1, targets get sos/eos added and lengths +1 in ``forward``.
            token_list: stored on the instance.
            extract_feats_in_collect_stats: whether ``collect_feats`` runs the
                frontend (True) or returns the raw inputs as dummy stats (False).
                ``collect_feats`` reads this attribute, which was previously never set.
        """
        assert check_argument_types()
        super().__init__()
        self.frontend = frontend
        self.encoder = encoder
        self.encoder.interctc_use_conditioning = False
        self.predictor = predictor
        self.predictor_bias = predictor_bias
        self.criterion_pre = mae_loss()
        self.token_list = token_list
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Predictor + token-count loss.

        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)

        Returns:
            ``(loss, stats, weight)`` after ``force_gatherable``.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
                speech.shape[0]
                == speech_lengths.shape[0]
                == text.shape[0]
                == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            _, text = add_sos_eos(text, 1, 2, -1)
            text_lengths = text_lengths + self.predictor_bias

        # 2. Predictor: only the predicted token count is used for the loss.
        _, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
        loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
        loss = loss_pre

        stats = dict()
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
            self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )

        Returns:
            ``(encoder_out, encoder_out_lens)``.
        """
        # Feature extraction runs outside mixed precision.
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        # 2. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        return encoder_out, encoder_out_lens

    def _extract_feats(
            self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the frontend if present, otherwise pass the inputs through."""
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # for data-parallel
        speech = speech[:, : speech_lengths.max()]
        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        """Return ``predictor.get_upsample_timestamp`` outputs for a masked encoder output."""
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
                                                                                           encoder_out_mask,
                                                                                           token_num)
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks

    def collect_feats(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Return features for statistics collection (or the raw inputs as dummies)."""
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}
funasr/models/e2e_vad.py
@@ -215,6 +215,7 @@
        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.next_seg = True
        self.output_data_buf = []
        self.output_data_buf_offset = 0
@@ -244,6 +245,7 @@
        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.next_seg = True
        self.output_data_buf = []
        self.output_data_buf_offset = 0
@@ -441,7 +443,7 @@
                        - 1)) / self.vad_opts.noise_frame_num_used_for_snr
        return frame_state
    def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                is_final: bool = False
                ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
@@ -470,6 +472,42 @@
            self.AllResetDetection()
        return segments, in_cache
    def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                is_final: bool = False
                ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
        self.waveform = waveform  # compute decibel for each frame
        self.ComputeDecibel()
        self.ComputeScores(feats, in_cache)
        if not is_final:
            self.DetectCommonFrames()
        else:
            self.DetectLastFrames()
        segments = []
        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
            segment_batch = []
            if len(self.output_data_buf) > 0:
                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
                    if not self.output_data_buf[i].contain_seg_start_point:
                        continue
                    if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
                        continue
                    start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
                    if self.output_data_buf[i].contain_seg_end_point:
                        end_ms = self.output_data_buf[i].end_ms
                        self.next_seg = True
                        self.output_data_buf_offset += 1
                    else:
                        end_ms = -1
                        self.next_seg = False
                    segment = [start_ms, end_ms]
                    segment_batch.append(segment)
            if segment_batch:
                segments.append(segment_batch)
        if is_final:
            # reset class variables and clear the dict for the next query
            self.AllResetDetection()
        return segments, in_cache
    def DetectCommonFrames(self) -> int:
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
funasr/models/frontend/eend_ola_feature.py
New file
@@ -0,0 +1,51 @@
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This module is for computing audio features
import librosa
import numpy as np
def transform(Y, dtype=np.float32, sr=8000, n_mels=23):
    """Convert an STFT matrix into mean-normalised log10 mel-band energies.

    Args:
        Y: STFT matrix whose second axis has ``n_fft // 2 + 1`` frequency bins;
            complex values are reduced to magnitudes.
        dtype: dtype of the returned array.
        sr: sampling rate used to build the mel filterbank (default 8000).
        n_mels: number of mel bands (default 23).

    Returns:
        Array of shape ``(Y.shape[0], n_mels)``: log10 mel power with the mean
        over axis 0 subtracted.
    """
    Y = np.abs(Y)
    n_fft = 2 * (Y.shape[1] - 1)
    # Keyword arguments: librosa >= 0.10 makes these parameters keyword-only.
    mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
    Y = np.dot(Y ** 2, mel_basis.T)
    # Floor the power before the log to avoid log10(0).
    Y = np.log10(np.maximum(Y, 1e-10))
    Y = Y - np.mean(Y, axis=0)
    return Y.astype(dtype)
def subsample(Y, T, subsampling=1):
    """Keep every ``subsampling``-th row of ``Y`` and of ``T``.

    Returns:
        ``(Y_subsampled, T_subsampled)``.
    """
    every_nth = slice(None, None, subsampling)
    return Y[every_nth], T[every_nth]
def splice(Y, context_size=0):
    """Concatenate every row of ``Y`` with ``context_size`` rows on each side.

    The input is zero-padded at both ends along axis 0, so edge rows get zero
    context. The result is a read-only strided view of shape
    ``(Y.shape[0], Y.shape[1] * (2 * context_size + 1))``.
    """
    n_rows, n_cols = Y.shape
    window = 2 * context_size + 1
    padded = np.ascontiguousarray(
        np.pad(Y, [(context_size, context_size), (0, 0)], 'constant'))
    # Advancing one output row moves one padded row forward; each output row
    # reads `window` consecutive padded rows as one flat vector.
    row_step = Y.itemsize * n_cols
    return np.lib.stride_tricks.as_strided(
        padded,
        shape=(n_rows, n_cols * window),
        strides=(row_step, Y.itemsize),
        writeable=False,
    )
def stft(
        data,
        frame_size=1024,
        frame_shift=256):
    """Short-time Fourier transform of ``data``, frames along axis 0.

    The FFT size is ``frame_size`` rounded up to a power of two. When
    ``len(data)`` is an exact multiple of ``frame_shift`` the last frame that
    librosa returns is dropped.
    """
    n_fft = 1 << (frame_size - 1).bit_length()
    frames = librosa.stft(data, n_fft=n_fft, win_length=frame_size,
                          hop_length=frame_shift).T
    drop_last = len(data) % frame_shift == 0
    return frames[:-1] if drop_last else frames
funasr/models/frontend/wav_frontend.py
@@ -1,6 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
from abc import ABC
from typing import Tuple
import numpy as np
@@ -33,9 +33,9 @@
    means = np.array(means_list).astype(np.float)
    vars = np.array(vars_list).astype(np.float)
    cmvn = np.array([means, vars])
    cmvn = torch.as_tensor(cmvn)
    return cmvn
    cmvn = torch.as_tensor(cmvn)
    return cmvn
def apply_cmvn(inputs, cmvn_file):  # noqa
    """
@@ -78,21 +78,22 @@
class WavFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.
    """
    def __init__(
        self,
        cmvn_file: str = None,
        fs: int = 16000,
        window: str = 'hamming',
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        filter_length_min: int = -1,
        filter_length_max: int = -1,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        snip_edges: bool = True,
        upsacle_samples: bool = True,
            self,
            cmvn_file: str = None,
            fs: int = 16000,
            window: str = 'hamming',
            n_mels: int = 80,
            frame_length: int = 25,
            frame_shift: int = 10,
            filter_length_min: int = -1,
            filter_length_max: int = -1,
            lfr_m: int = 1,
            lfr_n: int = 1,
            dither: float = 1.0,
            snip_edges: bool = True,
            upsacle_samples: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
@@ -135,11 +136,11 @@
                              window_type=self.window,
                              sample_frequency=self.fs,
                              snip_edges=self.snip_edges)
            if self.lfr_m != 1 or self.lfr_n != 1:
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            if self.cmvn_file is not None:
                mat = apply_cmvn(mat, self.cmvn_file)
                mat = apply_cmvn(mat, self.cmvn_file)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
@@ -170,7 +171,6 @@
                              energy_floor=0.0,
                              window_type=self.window,
                              sample_frequency=self.fs)
            feat_length = mat.size(0)
            feats.append(mat)
@@ -204,3 +204,243 @@
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens
class WavFrontendOnline(AbsFrontend):
    """Conventional frontend structure for streaming ASR/VAD.

    Audio is fed chunk by chunk through :meth:`forward`. State kept between
    calls:

    * ``input_cache``: trailing samples not yet consumed by a complete fbank
      frame; they are prepended to the next chunk.
    * ``lfr_splice_cache``: fbank frames held back as context for low frame
      rate (LFR) splicing of the next chunk (one entry per batch item).
    * ``reserve_waveforms``: waveform samples carried over and prepended to
      the next chunk's waveforms (exposed through :meth:`get_waveforms`).

    All caches are cleared after a call with ``is_final=True``.
    """
    def __init__(
            self,
            cmvn_file: str = None,
            fs: int = 16000,
            window: str = 'hamming',
            n_mels: int = 80,
            frame_length: int = 25,
            frame_shift: int = 10,
            filter_length_min: int = -1,
            filter_length_max: int = -1,
            lfr_m: int = 1,
            lfr_n: int = 1,
            dither: float = 1.0,
            snip_edges: bool = True,
            upsacle_samples: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        self.fs = fs
        self.window = window
        self.n_mels = n_mels
        self.frame_length = frame_length  # ms
        self.frame_shift = frame_shift  # ms
        self.frame_sample_length = int(self.frame_length * self.fs / 1000)
        self.frame_shift_sample_length = int(self.frame_shift * self.fs / 1000)
        self.filter_length_min = filter_length_min
        self.filter_length_max = filter_length_max
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples
        self.waveforms = None
        self.reserve_waveforms = None
        self.fbanks = None
        self.fbanks_lens = None
        self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
        self.input_cache = None
        self.lfr_splice_cache = []

    def output_size(self) -> int:
        """Feature dimension after LFR splicing: ``n_mels * lfr_m``."""
        return self.n_mels * self.lfr_m

    @staticmethod
    def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor:
        """
        Apply CMVN with mvn data.

        Row 0 of ``cmvn`` is added to ``inputs`` and the result is multiplied
        by row 1; only the first ``inputs.shape[1]`` columns are used. Returns
        a new float32 tensor and leaves ``inputs`` unchanged.
        """
        _, dim = inputs.shape
        cmvn = torch.as_tensor(cmvn)
        shift = cmvn[0:1, :dim].to(device=inputs.device, dtype=inputs.dtype)
        scale = cmvn[1:2, :dim].to(device=inputs.device, dtype=inputs.dtype)
        return ((inputs + shift) * scale).type(torch.float32)

    @staticmethod
    def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Apply lfr with data.

        ``inputs`` already includes the cached frames from the previous call.
        Every ``lfr_n`` frames, ``lfr_m`` consecutive frames are concatenated
        into one output frame. When not final, splicing stops at the first
        window that would run past the end; those trailing frames are returned
        as the new splice cache. When final, the last window(s) are padded by
        repeating the last frame.

        Returns:
            ``(lfr_outputs, lfr_splice_cache, splice_idx)``. ``lfr_outputs`` is
            float32 of shape ``(N, lfr_m * D)`` (N may be 0).
            ``splice_idx`` is the index of the first frame kept in the cache.
        """
        LFR_inputs = []
        T = inputs.shape[0]  # include the right context
        T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n))  # minus the right context: (lfr_m - 1) // 2
        splice_idx = T_lfr
        for i in range(T_lfr):
            start = i * lfr_n
            if lfr_m <= T - start:
                LFR_inputs.append(inputs[start:start + lfr_m].reshape(1, -1))
            elif is_final:  # process last LFR frame: pad with the last frame
                num_padding = lfr_m - (T - start)
                frame = torch.cat([inputs[start:].reshape(-1)] + [inputs[-1]] * num_padding)
                LFR_inputs.append(frame.reshape(1, -1))
            else:
                # not enough frames yet: keep them for the next call
                splice_idx = i
                break
        splice_idx = min(T - 1, splice_idx * lfr_n)
        lfr_splice_cache = inputs[splice_idx:, :]
        if LFR_inputs:
            LFR_outputs = torch.vstack(LFR_inputs)
        else:
            # too few frames for a single LFR frame; torch.vstack([]) would raise
            LFR_outputs = inputs.new_zeros((0, lfr_m * inputs.shape[1]))
        return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx

    @staticmethod
    def compute_frame_num(sample_length: int, frame_sample_length: int, frame_shift_sample_length: int) -> int:
        """Number of complete frames in ``sample_length`` samples (0 if none)."""
        frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
        return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0

    def forward_fbank(
            self,
            input: torch.Tensor,
            input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute kaldi fbank features for the cached samples plus ``input``.

        Samples that do not complete a frame are kept in ``self.input_cache``.

        Returns:
            ``(waveforms, feats_pad, feats_lens)``. ``waveforms`` holds exactly
            the samples covered by the produced frames. All three are empty
            tensors when no complete frame is available yet.
        """
        batch_size = input.size(0)
        if self.input_cache is None:
            # Empty (B, 0) slice of ``input``: same dtype and device, so the
            # concatenation below also works for non-CPU inputs.
            self.input_cache = input[:, :0]
        input = torch.cat((self.input_cache, input), dim=1)
        frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length)
        # Samples from the start of the next (not yet complete) frame onwards
        # are carried over to the next call.
        self.input_cache = input[:, frame_num * self.frame_shift_sample_length:]
        waveforms = torch.empty(0)
        feats_pad = torch.empty(0)
        feats_lens = torch.empty(0)
        if frame_num:
            waveforms = []
            feats = []
            feats_lens = []
            # number of samples covered by frame_num frames
            used_samples = (frame_num - 1) * self.frame_shift_sample_length + self.frame_sample_length
            for i in range(batch_size):
                waveform = input[i]
                # we need accurate wave samples that used for fbank extracting
                waveforms.append(waveform[:used_samples])
                waveform = waveform * (1 << 15)
                waveform = waveform.unsqueeze(0)
                mat = kaldi.fbank(waveform,
                                  num_mel_bins=self.n_mels,
                                  frame_length=self.frame_length,
                                  frame_shift=self.frame_shift,
                                  dither=self.dither,
                                  energy_floor=0.0,
                                  window_type=self.window,
                                  sample_frequency=self.fs)
                feats.append(mat)
                feats_lens.append(mat.size(0))
            waveforms = torch.stack(waveforms)
            feats_lens = torch.as_tensor(feats_lens)
            feats_pad = pad_sequence(feats,
                                     batch_first=True,
                                     padding_value=0.0)
        self.fbanks = feats_pad
        # independent copy: forward() later updates feats_lens in place
        self.fbanks_lens = feats_lens.clone()
        return waveforms, feats_pad, feats_lens

    def get_fbank(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the fbank features and lengths from the last forward_fbank call."""
        return self.fbanks, self.fbanks_lens

    def forward_lfr_cmvn(
            self,
            input: torch.Tensor,
            input_lengths: torch.Tensor,
            is_final: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply LFR (if enabled) and CMVN (if configured) to each batch item.

        Updates ``self.lfr_splice_cache[i]`` when LFR is enabled.

        Returns:
            ``(feats_pad, feats_lens, lfr_splice_frame_idxs)``.
        """
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        lfr_splice_frame_idxs = []
        for i in range(batch_size):
            mat = input[i, :input_lengths[i], :]
            # Without LFR no frame is held back, so the splice index is simply
            # the number of frames.
            lfr_splice_frame_idx = mat.shape[0]
            if self.lfr_m != 1 or self.lfr_n != 1:
                # update self.lfr_splice_cache in self.apply_lfr
                mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, is_final)
            if self.cmvn_file is not None:
                mat = self.apply_cmvn(mat, self.cmvn)
            feats.append(mat)
            feats_lens.append(mat.size(0))
            lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        lfr_splice_frame_idxs = torch.as_tensor(lfr_splice_frame_idxs)
        return feats_pad, feats_lens, lfr_splice_frame_idxs

    def forward(
            self, input: torch.Tensor, input_lengths: torch.Tensor, is_final: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract streaming features for one chunk of audio.

        Args:
            input: waveform chunk of shape ``(1, samples)``.
            input_lengths: unused by the fbank step; kept for interface parity.
            is_final: True for the last chunk. Pending frames are then flushed
                (padded) and all caches are reset afterwards.

        Returns:
            ``(feats, feats_lengths)``. ``feats`` is an empty tensor when too few
            frames are available yet.
        """
        batch_size = input.shape[0]
        assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
        waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths)  # input shape: B T D
        if feats.shape[0]:
            self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat((self.reserve_waveforms, waveforms), dim=1)
            if not self.lfr_splice_cache:
                # initialise the splice cache: repeat the first frame as left context
                for i in range(batch_size):
                    self.lfr_splice_cache.append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
            # LFR needs at least lfr_m frames (cache + new). On the final chunk
            # always flush: apply_lfr pads the missing frames.
            if is_final or feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
                lfr_splice_cache_tensor = torch.stack(self.lfr_splice_cache)  # B T D
                feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
                feats_lengths += lfr_splice_cache_tensor[0].shape[0]
                frame_from_waveforms = int((self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
                minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
                feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
                if self.lfr_m == 1:
                    self.reserve_waveforms = None
                else:
                    reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
                    self.reserve_waveforms = self.waveforms[:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length]
                    sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length
                    self.waveforms = self.waveforms[:, :sample_length]
            else:
                # Not enough frames yet (and not final): stash them.
                # The trailing frame overlap is re-read from input_cache next
                # time, so drop it here. Explicit stop index: `:-0` would be empty.
                overlap = self.frame_sample_length - self.frame_shift_sample_length
                self.reserve_waveforms = self.waveforms[:, :max(self.waveforms.shape[1] - overlap, 0)]
                for i in range(batch_size):
                    self.lfr_splice_cache[i] = torch.cat((self.lfr_splice_cache[i], feats[i]), dim=0)
                return torch.empty(0), feats_lengths
        elif is_final:
            self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
            if self.lfr_splice_cache:
                # flush the frames still held in the splice cache
                feats = torch.stack(self.lfr_splice_cache)
                feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
                feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
        if is_final:
            self.cache_reset()
        return feats, feats_lengths

    def get_waveforms(self):
        """Return the waveform samples aligned with the last emitted features."""
        return self.waveforms

    def cache_reset(self):
        """Clear all streaming caches before the next utterance."""
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []
funasr/models/pooling/statistic_pooling.py
@@ -82,7 +82,7 @@
    tt = xs_pad.shape[2]
    num_chunk = int(math.ceil(tt / pooling_stride))
    pad = pooling_size // 2
    if xs_pad.shape == 4:
    if len(xs_pad.shape) == 4:
        features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
    else:
        features = F.pad(xs_pad, (pad, pad), "reflect")
funasr/modules/eend_ola/encoder.py
@@ -1,5 +1,5 @@
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
@@ -81,10 +81,16 @@
        return self.dropout(x)
class TransformerEncoder(nn.Module):
    def __init__(self, idim, n_layers, n_units,
                 e_units=2048, h=8, dropout_rate=0.1, use_pos_emb=False):
        super(TransformerEncoder, self).__init__()
class EENDOLATransformerEncoder(nn.Module):
    def __init__(self,
                 idim: int,
                 n_layers: int,
                 n_units: int,
                 e_units: int = 2048,
                 h: int = 8,
                 dropout_rate: float = 0.1,
                 use_pos_emb: bool = False):
        super(EENDOLATransformerEncoder, self).__init__()
        self.lnorm_in = nn.LayerNorm(n_units)
        self.n_layers = n_layers
        self.dropout = nn.Dropout(dropout_rate)
funasr/runtime/onnxruntime/readme.md
@@ -29,6 +29,7 @@
## 依赖
- fftw3
- openblas
- onnxruntime
## 导出onnx格式模型文件
@@ -47,18 +48,22 @@
## Building Guidance for Linux/Unix
```
git clone https://github.com/RapidAI/RapidASR.git
cd RapidASR/cpp_onnx/
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/funasr/runtime/onnxruntime
mkdir build
cd build
# download an appropriate onnxruntime from https://github.com/microsoft/onnxruntime/releases/tag/v1.14.0
# here we get a copy of onnxruntime for linux 64
wget https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
# ls
# onnxruntime-linux-x64-1.14.0  onnxruntime-linux-x64-1.14.0.tgz
# install fftw3-dev
# ubuntu: apt install libfftw3-dev
# centos: yum install fftw fftw-devel
#install openblas
bash ./third_party/install_openblas.sh
# build
cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=$(pwd)/onnxruntime-linux-x64-1.14.0
funasr/runtime/onnxruntime/third_party/install_openblas.sh
New file
@@ -0,0 +1,39 @@
#!/usr/bin/env bash
# Download OpenBLAS $OPENBLAS_VERSION and build it (USE_THREAD=0,
# USE_LOCKING=1) into ./OpenBLAS/install.
#
# Environment variables:
#   WGET          wget command to use (default: wget)
#   PYTHON        interpreter used to parse the GitHub API response
#                 (default: python3, falling back to python)
#   DOWNLOAD_DIR  if it names an existing directory, the tarball is copied
#                 from there instead of being downloaded
OPENBLAS_VERSION=0.3.13
WGET=${WGET:-wget}
PYTHON=${PYTHON:-$(command -v python3 || command -v python || echo python)}
set -e
if ! command -v gfortran >/dev/null 2>&1; then
  echo "$0: gfortran is not installed.  Please install it, e.g. by:"
  echo " apt-get install gfortran"
  echo "(if on Debian or Ubuntu), or:"
  echo " yum install gcc-gfortran"
  echo "(if on RedHat/CentOS).  On a Mac, if brew is installed, it's:"
  echo " brew install gfortran"
  exit 1
fi
tarball=OpenBLAS-$OPENBLAS_VERSION.tar.gz
rm -rf xianyi-OpenBLAS-* OpenBLAS OpenBLAS-*.tar.gz
if [ -d "$DOWNLOAD_DIR" ]; then
  cp -p "$DOWNLOAD_DIR/$tarball" .
else
  url=$($WGET -qO- "https://api.github.com/repos/xianyi/OpenBLAS/releases/tags/v${OPENBLAS_VERSION}" | "$PYTHON" -c 'import sys,json;print(json.load(sys.stdin)["tarball_url"])')
  test -n "$url"
  $WGET -t3 -nv -O "$tarball" "$url"
fi
# The top-level directory inside the tarball depends on the GitHub owner and
# commit (e.g. xianyi-OpenBLAS-<sha>; it may differ now that the project has
# moved), so unpack into a fixed directory instead of guessing its name.
mkdir OpenBLAS
tar xzf "$tarball" -C OpenBLAS --strip-components=1
# With `set -e` a failed build aborts the script, so reaching the lines
# below means the build and install succeeded.
make PREFIX="$(pwd)/OpenBLAS/install" USE_LOCKING=1 USE_THREAD=0 -C OpenBLAS all install
echo "OpenBLAS is installed successfully."
rm "$tarball"
funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py
@@ -150,6 +150,7 @@
    def __init__(self, model_file, device_id=-1):
        device_id = str(device_id)
        sess_opt = SessionOptions()
        sess_opt.intra_op_num_threads = 4
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
funasr/runtime/triton_gpu/client/aishell_test.txt
New file
Diff too large
funasr/runtime/triton_gpu/client/decode_manifest_triton_wo_cuts.py
New file
@@ -0,0 +1,561 @@
#!/usr/bin/env python3
# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)
#                2023  Nvidia              (authors: Yuekai Zhang)
#                2023  Recurrent.ai    (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a manifest in nemo format and sends it to the server
for decoding, in parallel.
{'audio_filepath': '<wav path>', 'text': '<transcript>', 'duration': <seconds>}
{'audio_filepath': '<wav path>', 'text': '<transcript>', 'duration': <seconds>}
Usage:
# For aishell manifests:
apt-get install git-lfs
git-lfs install
git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests
sudo mkdir -p ./aishell-test-dev-manifests/aishell
tar xf ./aishell-test-dev-manifests/data_aishell.tar.gz -C ./aishell-test-dev-manifests/aishell # noqa
# cmd run
manifest_path='./client/aishell_test.txt'
serveraddr=localhost
num_task=60
python3 client/decode_manifest_triton_wo_cuts.py \
    --server-addr $serveraddr \
    --compute-cer \
    --model-name infer_pipeline \
    --num-tasks $num_task \
    --manifest-filename $manifest_path
"""
import argparse
import ast
import asyncio
import json
import math
import os
import time
import types
from pathlib import Path

import numpy as np
import tritonclient
import tritonclient.grpc.aio as grpcclient
from icefall.utils import store_transcripts, write_error_stats
from pydub import AudioSegment
from tritonclient.utils import np_to_triton_dtype
DEFAULT_MANIFEST_FILENAME = "./aishell_test.txt"  # noqa
# Directory that manifest ``audio_filepath`` entries are joined onto before
# the audio is loaded (absolute entries are used unchanged by os.path.join).
DEFAULT_ROOT = './'
def get_args():
    """Declare the client's command-line options and parse ``sys.argv``.

    Options are listed once as ``(flag, add_argument keyword arguments)``
    pairs and registered in that order.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    option_table = [
        ("--server-addr", {"type": str, "default": "localhost",
                           "help": "Address of the server"}),
        ("--server-port", {"type": int, "default": 8001,
                           "help": "Port of the server"}),
        ("--manifest-filename", {"type": str,
                                 "default": DEFAULT_MANIFEST_FILENAME,
                                 "help": "Path to the manifest for decoding"}),
        ("--model-name", {"type": str, "default": "transducer",
                          "help": "triton model_repo module name to request"}),
        ("--num-tasks", {"type": int, "default": 50,
                         "help": "Number of tasks to use for sending"}),
        ("--log-interval", {"type": int, "default": 5,
                            "help": "Controls how frequently we print the log."}),
        ("--compute-cer", {"action": "store_true", "default": False,
                           "help": "True to compute CER, e.g., for Chinese.\n"
                                   "        False to compute WER, e.g., for English words.\n"
                                   "        "}),
        ("--streaming", {"action": "store_true", "default": False,
                         "help": "True for streaming ASR.\n"
                                 "        "}),
        ("--simulate-streaming", {"action": "store_true", "default": False,
                                  "help": "True for strictly simulate streaming ASR.\n"
                                          "        Threads will sleep to simulate the real speaking scene.\n"
                                          "        "}),
        ("--chunk_size", {"type": int, "required": False, "default": 16,
                          "help": "chunk size default is 16"}),
        ("--context", {"type": int, "required": False, "default": -1,
                       "help": "subsampling context for wenet"}),
        ("--encoder_right_context", {"type": int, "required": False,
                                     "default": 2,
                                     "help": "encoder right context"}),
        ("--subsampling", {"type": int, "required": False, "default": 4,
                           "help": "subsampling rate"}),
        ("--stats_file", {"type": str, "required": False,
                          "default": "./stats.json",
                          "help": "output of stats anaylasis"}),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def load_manifest(fp):
    """Load a manifest with one dict literal per line.

    Lines are parsed with ``ast.literal_eval`` instead of ``eval``, so a
    manifest cannot execute code; anything other than a plain literal raises
    ``ValueError``. Blank lines are skipped. Each entry gets a sequential
    ``'id'`` key.

    Args:
        fp: path of the UTF-8 encoded manifest.

    Returns:
        List of the parsed dicts.
    """
    data = []
    with open(fp, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            dp = ast.literal_eval(line)
            dp['id'] = len(data)
            data.append(dp)
    return data
def split_dps(dps, num_tasks):
    """Split ``dps`` into exactly ``num_tasks`` contiguous, near-equal chunks.

    The first ``len(dps) % num_tasks`` chunks get one extra item. Every item
    appears exactly once. A chunk is empty only when ``len(dps) < num_tasks``.
    Callers index the result with ``range(num_tasks)``, so the count must be
    exact.

    Raises:
        ValueError: if ``num_tasks`` is not positive.
    """
    if num_tasks <= 0:
        raise ValueError(f"num_tasks must be positive, got {num_tasks}")
    base, extra = divmod(len(dps), num_tasks)
    dps_splited = []
    start = 0
    for k in range(num_tasks):
        size = base + (1 if k < extra else 0)
        dps_splited.append(dps[start:start + size])
        start += size
    return dps_splited
def load_audio(path, sample_rate=16000):
    """Load a wav file as mono float32 samples in [-1, 1).

    The audio is resampled to ``sample_rate`` and converted to 16-bit samples
    before scaling. This makes the division by 32768 correct for any input
    bit depth.

    Returns:
        ``(samples, duration_seconds)``.
    """
    audio = (AudioSegment.from_wav(path)
             .set_frame_rate(sample_rate)
             .set_channels(1)
             .set_sample_width(2))
    samples = np.array(audio.get_array_of_samples()) / 32768.0
    return samples.astype(np.float32), audio.duration_seconds
async def send(
    dps: list,
    name: str,
    triton_client: tritonclient.grpc.aio.InferenceServerClient,
    protocol_client: types.ModuleType,
    log_interval: int,
    compute_cer: bool,
    model_name: str,
):
    """Decode each manifest entry with a single offline inference request.

    Each waveform is zero-padded up to the next multiple of 10 seconds of
    16 kHz audio and sent as ``WAV``, with its true length as ``WAV_LENS``.

    Returns:
        ``(total_duration, results)``. ``results`` holds ``(id, ref, hyp)``
        tuples: character lists when ``compute_cer`` is set, word lists
        otherwise.
    """
    sample_rate = 16000
    pad_block = 10 * sample_rate  # padding granularity: 10 seconds of audio
    n_total = len(dps)
    total_duration = 0.0
    results = []
    for idx, dp in enumerate(dps):
        if idx % log_interval == 0:
            print(f"{name}: {idx}/{n_total}")
        waveform, duration = load_audio(
            os.path.join(DEFAULT_ROOT, dp['audio_filepath']))
        n_samples = len(waveform)
        n_blocks = int(n_samples / sample_rate // 10) + 1
        samples = np.zeros((1, pad_block * n_blocks), dtype=np.float32)
        samples[0, :n_samples] = waveform
        lengths = np.array([[n_samples]], dtype=np.int32)

        wav_in = protocol_client.InferInput(
            "WAV", samples.shape, np_to_triton_dtype(samples.dtype))
        lens_in = protocol_client.InferInput(
            "WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype))
        wav_in.set_data_from_numpy(samples)
        lens_in.set_data_from_numpy(lengths)
        requested = [protocol_client.InferRequestedOutput("TRANSCRIPTS")]

        response = await triton_client.infer(
            model_name,
            [wav_in, lens_in],
            request_id=str(10086 + idx),
            outputs=requested,
        )
        transcript = response.as_numpy("TRANSCRIPTS")[0]
        if type(transcript) is np.ndarray:
            hyp_text = b" ".join(transcript).decode("utf-8")
        else:
            # For wenet
            hyp_text = transcript.decode("utf-8")
        total_duration += duration

        ref_tokens = dp['text'].split()
        hyp_tokens = hyp_text.split()
        if compute_cer:
            ref_tokens = list("".join(ref_tokens))
            hyp_tokens = list("".join(hyp_tokens))
        results.append((dp['id'], ref_tokens, hyp_tokens))
    return total_duration, results
async def send_streaming(
    dps: list,
    name: str,
    triton_client: tritonclient.grpc.aio.InferenceServerClient,
    protocol_client: types.ModuleType,
    log_interval: int,
    compute_cer: bool,
    model_name: str,
    first_chunk_in_secs: float,
    other_chunk_in_secs: float,
    task_index: int,
    simulate_mode: bool = False,
):
    """Decode each manifest entry as a chunked stream.

    Each utterance is cut into a first chunk of ``first_chunk_in_secs`` and
    then chunks of ``other_chunk_in_secs``. Every chunk is zero-padded to its
    nominal size and sent within one Triton sequence whose id is derived from
    ``task_index``. With ``simulate_mode`` the coroutine sleeps for each
    chunk's audio duration before sending it, to imitate live input.

    Returns:
        ``(total_duration, results, latency_data)``. ``results`` holds
        ``(id, ref, hyp)`` tuples (characters when ``compute_cer``, words
        otherwise). ``latency_data`` holds ``(request_seconds,
        chunk_audio_seconds)`` per chunk.

    Raises:
        ValueError: if a chunk size is below one sample; the chunking loop
            would otherwise never terminate.
    """
    sample_rate = 16000
    first_chunk_samples = int(first_chunk_in_secs * sample_rate)
    other_chunk_samples = int(other_chunk_in_secs * sample_rate)
    total_duration = 0.0
    results = []
    latency_data = []
    for i, dp in enumerate(dps):
        if i % log_interval == 0:
            print(f"{name}: {i}/{len(dps)}")
        # Resolve the path the same way send() does.
        waveform, duration = load_audio(
            os.path.join(DEFAULT_ROOT, dp['audio_filepath']))
        wav_segs = []
        j = 0
        while j < len(waveform):
            stride = first_chunk_samples if j == 0 else other_chunk_samples
            if stride <= 0:
                raise ValueError(
                    "chunk sizes must cover at least one sample, got "
                    f"{first_chunk_in_secs}s / {other_chunk_in_secs}s"
                )
            wav_segs.append(waveform[j: j + stride])
            j += len(wav_segs[-1])
        sequence_id = task_index + 10086
        # Reset per utterance: an utterance without samples would otherwise
        # reuse the previous utterance's transcript (or hit an unbound name).
        decoding_results = ""
        for idx, seg in enumerate(wav_segs):
            chunk_len = len(seg)
            if simulate_mode:
                await asyncio.sleep(chunk_len / sample_rate)
            chunk_start = time.time()
            chunk_samples = first_chunk_samples if idx == 0 else other_chunk_samples
            input0_data = np.zeros((1, chunk_samples), dtype=np.float32)
            input0_data[0][0:chunk_len] = seg
            input1_data = np.array([[chunk_len]], dtype=np.int32)
            inputs = [
                protocol_client.InferInput(
                    "WAV",
                    input0_data.shape,
                    np_to_triton_dtype(input0_data.dtype),
                ),
                protocol_client.InferInput(
                    "WAV_LENS",
                    input1_data.shape,
                    np_to_triton_dtype(input1_data.dtype),
                ),
            ]
            inputs[0].set_data_from_numpy(input0_data)
            inputs[1].set_data_from_numpy(input1_data)
            outputs = [protocol_client.InferRequestedOutput("TRANSCRIPTS")]
            response = await triton_client.infer(
                model_name,
                inputs,
                outputs=outputs,
                sequence_id=sequence_id,
                sequence_start=idx == 0,
                sequence_end=idx == len(wav_segs) - 1,
            )
            decoding_results = response.as_numpy("TRANSCRIPTS")
            if type(decoding_results) == np.ndarray:
                decoding_results = b" ".join(decoding_results).decode("utf-8")
            else:
                # For wenet
                decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode(
                    "utf-8"
                )
            chunk_end = time.time() - chunk_start
            latency_data.append((chunk_end, chunk_len / sample_rate))
        total_duration += duration
        if compute_cer:
            ref = list("".join(dp['text'].split()))
            hyp = list("".join(decoding_results.split()))
            results.append((dp['id'], ref, hyp))
        else:
            results.append(
                (
                    dp['id'],
                    dp['text'].split(),
                    decoding_results.split(),
                )
            )  # noqa
    return total_duration, results, latency_data
async def main():
    """Decode a manifest against a Triton server and report RTF, latency and errors.

    The manifest is split into ``args.num_tasks`` shards that are decoded
    concurrently, either offline via ``send`` or chunk-by-chunk via
    ``send_streaming`` (real or simulated streaming).  Outputs:

    * ``rtf.txt`` -- RTF / duration / latency summary (also printed),
    * ``recogs-<name>.txt`` -- reference/hypothesis transcripts,
    * ``errs-<name>.txt`` -- error statistics,
    * ``args.stats_file`` -- server inference statistics as JSON (optional).

    Raises:
        ValueError: if both ``--streaming`` and ``--simulate_streaming`` are set.
    """
    args = get_args()
    filename = args.manifest_filename
    url = f"{args.server_addr}:{args.server_port}"
    num_tasks = args.num_tasks
    log_interval = args.log_interval
    compute_cer = args.compute_cer
    streaming_mode = args.streaming or args.simulate_streaming
    if args.streaming and args.simulate_streaming:
        # Previously an `assert`, which is stripped under `python -O`.
        raise ValueError(
            "--streaming and --simulate_streaming are mutually exclusive"
        )

    dps = load_manifest(filename)
    dps_list = split_dps(dps, num_tasks)
    triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)
    protocol_client = grpcclient

    if streaming_mode:
        frame_shift_ms = 10
        frame_length_ms = 25
        # Extra frames covering the part of the analysis window that extends
        # past the frame shift.
        add_frames = math.ceil(
            (frame_length_ms - frame_shift_ms) / frame_shift_ms
        )
        # decode_window_length: input sequence length of streaming encoder
        if args.context > 0:
            # decode window length calculation for wenet
            decode_window_length = (
                args.chunk_size - 1
            ) * args.subsampling + args.context
        else:
            # decode window length calculation for icefall
            decode_window_length = (
                args.chunk_size + 2 + args.encoder_right_context
            ) * args.subsampling + 3
        first_chunk_ms = (decode_window_length + add_frames) * frame_shift_ms
        first_chunk_in_secs = first_chunk_ms / 1000
        other_chunk_in_secs = (
            args.chunk_size * args.subsampling * frame_shift_ms / 1000
        )

    start_time = time.time()
    tasks = []
    for i in range(num_tasks):
        # Arguments shared by the offline and streaming senders.
        common_kwargs = dict(
            dps=dps_list[i],
            name=f"task-{i}",
            triton_client=triton_client,
            protocol_client=protocol_client,
            log_interval=log_interval,
            compute_cer=compute_cer,
            model_name=args.model_name,
        )
        if streaming_mode:
            # Only simulated streaming passes simulate_mode explicitly, exactly
            # as before; real streaming relies on send_streaming's default.
            extra_kwargs = (
                {"simulate_mode": True} if args.simulate_streaming else {}
            )
            coro = send_streaming(
                **common_kwargs,
                first_chunk_in_secs=first_chunk_in_secs,
                other_chunk_in_secs=other_chunk_in_secs,
                task_index=i,
                **extra_kwargs,
            )
        else:
            coro = send(**common_kwargs)
        tasks.append(asyncio.create_task(coro))
    ans_list = await asyncio.gather(*tasks)
    elapsed = time.time() - start_time

    results = []
    total_duration = 0.0
    latency_data = []
    for ans in ans_list:
        total_duration += ans[0]
        results += ans[1]
        if streaming_mode:
            latency_data += ans[2]

    # An empty manifest has zero duration; avoid ZeroDivisionError.
    rtf = elapsed / total_duration if total_duration > 0 else float("nan")
    s = f"RTF: {rtf:.4f}\n"
    s += f"total_duration: {total_duration:.3f} seconds\n"
    s += f"({total_duration/3600:.2f} hours)\n"
    s += (
        f"processing time: {elapsed:.3f} seconds "
        f"({elapsed/3600:.2f} hours)\n"
    )
    if streaming_mode and latency_data:
        # Per-chunk wall-clock latencies, converted once from seconds to ms so
        # every statistic below is in consistent units.
        latency_ms_arr = np.asarray(
            [chunk_end for (chunk_end, _chunk_duration) in latency_data],
            dtype=np.float64,
        ) * 1000.0
        latency_ms = float(latency_ms_arr.mean())
        # Variance in ms^2 (the old `var(seconds) * 1000` mixed units).
        latency_variance = float(np.var(latency_ms_arr))
        s += f"latency_variance: {latency_variance:.2f}\n"
        s += f"latency_50_percentile: {np.percentile(latency_ms_arr, 50):.2f}\n"
        s += f"latency_90_percentile: {np.percentile(latency_ms_arr, 90):.2f}\n"
        s += f"latency_99_percentile: {np.percentile(latency_ms_arr, 99):.2f}\n"
        s += f"average_latency_ms: {latency_ms:.2f}\n"
    print(s)
    with open("rtf.txt", "w") as f:
        f.write(s)

    name = Path(filename).stem.split(".")[0]
    results = sorted(results)
    store_transcripts(filename=f"recogs-{name}.txt", texts=results)
    with open(f"errs-{name}.txt", "w") as f:
        write_error_stats(f, "test-set", results, enable_log=True)
    with open(f"errs-{name}.txt", "r") as f:
        print(f.readline())  # WER
        print(f.readline())  # Detailed errors

    if args.stats_file:
        stats = await triton_client.get_inference_statistics(
            model_name="", as_json=True
        )
        with open(args.stats_file, "w") as f:
            json.dump(stats, f)


if __name__ == "__main__":
    asyncio.run(main())
funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/encoder/config.pbtxt
@@ -40,7 +40,7 @@
  },
  {
    name: "token_num"
    data_type: TYPE_INT64
    data_type: TYPE_INT32
    dims: [1]
    reshape: { shape: [ ] }
  }
funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/config.pbtxt
@@ -43,7 +43,7 @@
  },
  {
    name: "token_num"
    data_type: TYPE_INT64
    data_type: TYPE_INT32
    dims: [1]
    reshape: { shape: [ ] }
  }
funasr/tasks/asr.py
@@ -40,6 +40,7 @@
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.abs_encoder import AbsEncoder
@@ -124,6 +125,7 @@
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
    ),
    type_check=AbsESPnetModel,
    default="asr",
@@ -1245,9 +1247,87 @@
class ASRTaskAligner(ASRTaskParaformer):
    """ASR task variant that builds a timestamp-prediction (aligner) model.

    ``build_model`` assembles the model selected by ``--model`` from a
    frontend, an encoder and a predictor.  No decoder is instantiated, although
    the ``--decoder`` options remain registered.
    """

    # If you need more than one optimizers, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --model and --model_conf
        model_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --decoder and --decoder_conf (kept for config compatibility;
        # build_model does not use a decoder)
        decoder_choices,
        # --predictor and --predictor_conf: build_model reads args.predictor
        # and args.predictor_conf, so they must be registered as CLI options.
        # NOTE(review): assumes the inherited add_task_arguments registers
        # options from cls.class_choices_list -- confirm against the parent.
        predictor_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        """Build the aligner model described by ``args``.

        Args:
            args: Parsed task arguments.  ``token_list`` may be a path to a
                token file (one token per line) or a list/tuple of tokens.

        Returns:
            The constructed model, initialized with ``args.init`` if given.

        Raises:
            RuntimeError: if ``args.token_list`` is neither str nor list/tuple.
        """
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 3. Predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)

        # 4. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            # e.g. when args carries no `model` attribute.
            model_class = model_choices.get_class("asr")
        model = model_class(
            frontend=frontend,
            encoder=encoder,
            predictor=predictor,
            token_list=token_list,
            **args.model_conf,
        )

        # 5. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model

    @classmethod
    def required_data_names(
            cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the data names required in both training and inference."""
        retval = ("speech", "text")
        return retval
funasr/tasks/diar.py
@@ -20,19 +20,19 @@
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.layers.label_aggregation import LabelAggregate
from funasr.models.ctc import CTC
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.e2e_diar_sond import DiarSondModel
from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
@@ -41,17 +41,13 @@
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -70,6 +66,7 @@
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
        wav_frontend_mel23=WavFrontendMel23,
    ),
    type_check=AbsFrontend,
    default="default",
@@ -107,6 +104,7 @@
    "model",
    classes=dict(
        sond=DiarSondModel,
        eend_ola=DiarEENDOLAModel,
    ),
    type_check=AbsESPnetModel,
    default="sond",
@@ -126,6 +124,7 @@
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        ecapa_tdnn=ECAPA_TDNN,
        eend_ola_transformer=EENDOLATransformerEncoder,
    ),
    type_check=torch.nn.Module,
    default="resnet34",
@@ -176,6 +175,15 @@
    ),
    type_check=torch.nn.Module,
    default="fsmn",
)
# Registry of encoder-decoder attractor modules, exposed as the
# --encoder_decoder_attractor / --encoder_decoder_attractor_conf options.
# Only the EEND-OLA task uses it; "eda" is the sole (and default) choice.
_EDA_CLASSES = {"eda": EncoderDecoderAttractor}
encoder_decoder_attractor_choices = ClassChoices(
    "encoder_decoder_attractor",
    classes=_EDA_CLASSES,
    type_check=torch.nn.Module,
    default="eda",
)
@@ -594,3 +602,294 @@
            var_dict_torch_update.update(var_dict_torch_update_local)
        return var_dict_torch_update
class EENDOLADiarTask(AbsTask):
    """Task definition for EEND-OLA speaker diarization.

    Registers the command-line options, preprocessing, collation and model
    construction for a model selected by ``--model`` that is assembled from a
    frontend, an encoder and an encoder-decoder attractor.
    """

    # If you need more than 1 optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --model and --model_conf
        model_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --encoder_decoder_attractor and --encoder_decoder_attractor_conf
        encoder_decoder_attractor_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Register task-, preprocessing- and class-choice-related CLI options."""
        group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        # required = parser.get_default("required")
        # required += ["token_list"]

        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="whether to split text using <space>",
        )
        group.add_argument(
            "--seg_dict_file",
            type=str,
            default=None,
            help="seg_dict_file for text processing",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )

        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="char",
            choices=["char"],
            help="The text will be tokenized in the specified level token",
        )
        parser.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Scale the maximum amplitude to the given value.",
        )
        parser.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The file path of rir scp file.",
        )
        parser.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability for applying RIR convolution.",
        )
        parser.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of the cmvn file for the frontend.",
        )
        parser.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The file path of noise scp file.",
        )
        parser.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability applying Noise adding.",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of noise decibel level.",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
            cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Return a collate function padding floats with 0.0 and ints with -1."""
        assert check_argument_types()
        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    @classmethod
    def build_preprocess_fn(
            cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        """Return a ``CommonPreprocessor`` configured from ``args``.

        Returns None when ``args.use_preprocessor`` is false.  Optional
        attributes missing from ``args`` (e.g. older configs) fall back to the
        defaults shown below.
        """
        assert check_argument_types()
        if args.use_preprocessor:
            # NOTE(kamo): Check attribute existence for backward compatibility
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=None,
                non_linguistic_symbols=None,
                text_cleaner=None,
                g2p_type=None,
                split_with_space=getattr(args, "split_with_space", False),
                seg_dict_file=getattr(args, "seg_dict_file", None),
                rir_scp=getattr(args, "rir_scp", None),
                rir_apply_prob=getattr(args, "rir_apply_prob", 1.0),
                noise_scp=getattr(args, "noise_scp", None),
                noise_apply_prob=getattr(args, "noise_apply_prob", 1.0),
                noise_db_range=getattr(args, "noise_db_range", "13_15"),
                # Was guarded by hasattr(args, "rir_scp"), which raised
                # AttributeError when rir_scp existed but this option did not.
                speech_volume_normalize=getattr(
                    args, "speech_volume_normalize", None
                ),
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
            cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the names of the data entries each sample must provide."""
        if not inference:
            retval = ("speech", "profile", "binary_labels")
        else:
            # Recognition mode.  The trailing comma matters: ("speech") is the
            # plain string "speech", not a one-element tuple.
            retval = ("speech",)
        return retval

    @classmethod
    def optional_data_names(
            cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the names of optional data entries (none for this task)."""
        retval = ()
        assert check_return_type(retval)
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        """Build the diarization model described by ``args``.

        A frontend is built when ``args.input_size`` is None or the frontend
        is ``wav_frontend_mel23``; otherwise features are expected from the
        data loader and ``args.input_size`` is used as the encoder input size.
        """
        assert check_argument_types()

        # 1. frontend
        if args.input_size is None or args.frontend == "wav_frontend_mel23":
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                # getattr: configs written without --cmvn_file lack the key.
                frontend = frontend_class(
                    cmvn_file=getattr(args, "cmvn_file", None), **args.frontend_conf
                )
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 3. EncoderDecoderAttractor
        encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
        encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)

        # 4. Build model
        model_class = model_choices.get_class(args.model)
        model = model_class(
            frontend=frontend,
            encoder=encoder,
            encoder_decoder_attractor=encoder_decoder_attractor,
            **args.model_conf,
        )

        # 5. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model

    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
            cls,
            config_file: Union[Path, str] = None,
            model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            device: str = "cpu",
    ):
        """Build model from the files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved when training.  Defaults to
                ``config.yaml`` next to ``model_file``.
            model_file: The model file saved when training.
            cmvn_file: The cmvn file for front-end.  When given, it overrides
                any ``cmvn_file`` entry in the config.
            device: Device type, "cpu", "cuda", or "cuda:N".

        Returns:
            ``(model, args)``, where ``args`` is the Namespace built from the
            config.

        Raises:
            RuntimeError: if the built model is not an ``AbsESPnetModel``.
        """
        assert check_argument_types()
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)

        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        # The cmvn_file argument was documented but previously ignored;
        # build_model forwards args.cmvn_file to wav_frontend.
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
            )
        if model_file is not None:
            if device == "cuda":
                device = f"cuda:{torch.cuda.current_device()}"
            checkpoint = torch.load(model_file, map_location=device)
            # Accept both wrapped ({"state_dict": ...}) and bare state dicts.
            if "state_dict" in checkpoint.keys():
                model.load_state_dict(checkpoint["state_dict"])
            else:
                model.load_state_dict(checkpoint)
        model.to(device)
        return model, args
funasr/utils/timestamp_tools.py
@@ -5,55 +5,69 @@
from typing import Any, List, Tuple, Union
def ts_prediction_lfr6_standard(us_alphas,
                                us_peaks,
                                char_list,
                                vad_offset=0.0,
                                force_time_shift=-1.5
                                ):
    """Convert CIF peak positions into per-token timestamps.

    Args:
        us_alphas: Upsampled CIF weights.  They are only used to detect a
            leading batch dimension (2-D input); batch size 1 is assumed then.
        us_peaks: Upsampled CIF fire values, same layout as ``us_alphas``.
            Frames with a value above ``1 - 1e-4`` count as peaks.
        char_list: Decoded tokens.  A trailing ``'</s>'`` is dropped.  There
            must be exactly ``len(char_list) + 1`` peaks.
        vad_offset: Offset in milliseconds added to every timestamp, e.g. the
            start time of a VAD segment.
        force_time_shift: Frame offset applied to every peak position.

    Returns:
        ``(res_txt, res)``.  ``res_txt`` is a ``"tok start end;"`` string in
        seconds that includes ``<sil>`` segments.  ``res`` is a list of
        ``[start_ms, end_ms]`` for the non-silence tokens.  Empty input yields
        ``("", [])``, so callers can always unpack two values.
    """
    if len(char_list) and char_list[-1] == '</s>':
        char_list = char_list[:-1]
    if not len(char_list):
        # Same 2-tuple shape as the normal path (previously returned []).
        return "", []
    START_END_THRESHOLD = 5
    MAX_TOKEN_DURATION = 12
    TIME_RATE = 10.0 * 6 / 1000 / 3  # 3 times upsampled
    if len(us_alphas.shape) == 2:
        peaks = us_peaks[0]  # support inference batch_size=1 only
    else:
        peaks = us_peaks
    num_frames = peaks.shape[0]
    timestamp_list = []
    new_char_list = []
    # for bicif model trained with large data, cif2 actually fires when a character starts
    # so treat the frames between two peaks as the duration of the former token
    fire_place = torch.where(peaks > 1.0 - 1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    num_peak = len(fire_place)
    assert num_peak == len(char_list) + 1  # number of peaks is supposed to be number of tokens + 1
    # begin silence
    if fire_place[0] > START_END_THRESHOLD:
        timestamp_list.append([0.0, fire_place[0] * TIME_RATE])
        new_char_list.append('<sil>')
    # tokens timestamp
    for i in range(len(fire_place) - 1):
        new_char_list.append(char_list[i])
        if MAX_TOKEN_DURATION < 0 or fire_place[i + 1] - fire_place[i] <= MAX_TOKEN_DURATION:
            timestamp_list.append([fire_place[i] * TIME_RATE, fire_place[i + 1] * TIME_RATE])
        else:
            # cut the duration to token and sil of the 0-weight frames last long
            _split = fire_place[i] + MAX_TOKEN_DURATION
            timestamp_list.append([fire_place[i] * TIME_RATE, _split * TIME_RATE])
            timestamp_list.append([_split * TIME_RATE, fire_place[i + 1] * TIME_RATE])
            new_char_list.append('<sil>')
    # tail token and end silence
    if num_frames - fire_place[-1] > START_END_THRESHOLD:
        # Split the tail halfway between the last peak and the end of audio.
        _end = (num_frames + fire_place[-1]) * 0.5
        timestamp_list[-1][1] = _end * TIME_RATE
        timestamp_list.append([_end * TIME_RATE, num_frames * TIME_RATE])
        new_char_list.append("<sil>")
    else:
        timestamp_list[-1][1] = num_frames * TIME_RATE
    if vad_offset:  # add offset time in model with vad
        for i in range(len(timestamp_list)):
            timestamp_list[i][0] = timestamp_list[i][0] + vad_offset / 1000.0
            timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
    res_txt = ""
    for char, timestamp in zip(new_char_list, timestamp_list):
        res_txt += "{} {} {};".format(char, str(timestamp[0] + 0.0005)[:5], str(timestamp[1] + 0.0005)[:5])
    res = []
    for char, timestamp in zip(new_char_list, timestamp_list):
        if char != '<sil>':
            res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
    return res_txt, res
def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed):
tests/test_asr_inference_pipeline.py
@@ -451,8 +451,8 @@
    def test_uniasr_2pass_zhcn_16k_common_vocab8358_offline(self):
        inference_pipeline = pipeline(
            task=Tasks.auto_speech_recognition,
            model='damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline')
            task=Tasks.,
            model='damo/speech_UniASauto_speech_recognitionR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline')
        rec_result = inference_pipeline(
            audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav',
            param_dict={"decoding_model": "offline"})
tests/test_asr_vad_punc_inference_pipeline.py
New file
@@ -0,0 +1,32 @@
import unittest
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
class TestParaformerInferencePipelines(unittest.TestCase):
    """Smoke tests for the Paraformer + VAD + punctuation ModelScope pipeline."""

    def test_funasr_path(self):
        """Log the working directory and the path of the imported funasr package."""
        import funasr
        import os
        cwd = os.getcwd()
        pkg_path = funasr.__file__
        logger.info("run_dir:{0} ; funasr_path: {1}".format(cwd, pkg_path))

    def test_inference_pipeline(self):
        """Run VAD + ASR + punctuation on a remote sample and log the output."""
        pipeline_kwargs = dict(
            task=Tasks.auto_speech_recognition,
            model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
            model_revision="v1.2.1",
            vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
            vad_model_revision="v1.1.8",
            punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
            punc_model_revision="v1.1.6",
            ngpu=1,
        )
        asr_pipeline = pipeline(**pipeline_kwargs)
        sample_url = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
        output = asr_pipeline(audio_in=sample_url)
        logger.info("asr_vad_punc inference result: {0}".format(output))


if __name__ == '__main__':
    unittest.main()
tests/test_lm_pipeline.py
New file
@@ -0,0 +1,25 @@
import unittest
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
class TestTransformerInferencePipelines(unittest.TestCase):
    """Smoke tests for the Transformer language-model ModelScope pipeline."""

    def test_funasr_path(self):
        """Log the working directory and the path of the imported funasr package."""
        import funasr
        import os
        cwd, pkg_path = os.getcwd(), funasr.__file__
        logger.info("run_dir:{0} ; funasr_path: {1}".format(cwd, pkg_path))

    def test_inference_pipeline(self):
        """Score a short mixed-language sentence with the LM and log the result."""
        lm_pipeline = pipeline(
            task=Tasks.language_score_prediction,
            model='damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch',
        )
        sentence = "hello 大 家 好 呀"
        score = lm_pipeline(text_in=sentence)
        logger.info("lm inference result: {0}".format(score))


if __name__ == '__main__':
    unittest.main()
tests/test_punctuation_pipeline.py
New file
@@ -0,0 +1,43 @@
import unittest
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
class TestTransformerInferencePipelines(unittest.TestCase):
    """Smoke tests for the CT-Transformer punctuation ModelScope pipelines."""

    def test_funasr_path(self):
        """Log the working directory and the path of the imported funasr package."""
        import funasr
        import os
        cwd, pkg_path = os.getcwd(), funasr.__file__
        logger.info("run_dir:{0} ; funasr_path: {1}".format(cwd, pkg_path))

    def test_inference_pipeline(self):
        """Punctuate the bundled example text file (path relative to the cwd)."""
        punc_pipeline = pipeline(
            task=Tasks.punctuation,
            model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
            model_revision="v1.1.7",
        )
        example_path = "./egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt"
        output = punc_pipeline(text_in=example_path)
        logger.info("punctuation inference result: {0}".format(output))

    def test_vadrealtime_inference_pipeline(self):
        """Feed "|"-separated segments one at a time, threading the cache through."""
        punc_pipeline = pipeline(
            task=Tasks.punctuation,
            model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
            model_revision="v1.0.0",
        )
        text = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
        cache = []
        combined = "outputs:"
        for segment in text.split("|"):
            # Each call returns the cache to pass into the next segment.
            output = punc_pipeline(text_in=segment, cache=cache)
            cache = output['cache']
            combined += output['text']
        logger.info("punctuation inference result: {0}".format(combined))


if __name__ == '__main__':
    unittest.main()
tests/test_sv_inference_pipeline.py
New file
@@ -0,0 +1,47 @@
import unittest
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
# Module-level logger (modelscope's get_logger) used by the tests below to report results.
logger = get_logger()
class TestXVectorInferencePipelines(unittest.TestCase):
    """Smoke tests for the x-vector speaker-verification ModelScope pipeline.

    The tests only log similarity scores; they do not assert on them.
    """

    def test_funasr_path(self):
        """Log the working directory and which ``funasr`` package is being imported."""
        import funasr
        import os
        logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))

    @staticmethod
    def _scaled_similarity(emb_a, emb_b, threshold):
        """Return the cosine similarity of two embeddings rescaled to a 0-100 score.

        Cosine values at or below ``threshold`` map to 0. Values above it are
        scaled linearly, so a cosine of 1.0 maps to 100.
        """
        # numpy is not imported at this module's top level, and the original code
        # referenced ``np`` unbound (NameError). Import it locally so the module
        # stays self-contained.
        import numpy as np
        cos = np.sum(emb_a * emb_b) / (np.linalg.norm(emb_a) * np.linalg.norm(emb_b))
        return max(cos - threshold, 0.0) / (1.0 - threshold) * 100.0

    def test_inference_pipeline(self):
        """Embed enroll/same/different utterances and log their similarity to the enrollment."""
        inference_sv_pipeline = pipeline(
            task=Tasks.speaker_verification,
            model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
        )
        # Extract speaker embeddings for the different utterances.
        rec_result = inference_sv_pipeline(
            audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav')
        enroll = rec_result["spk_embedding"]
        rec_result = inference_sv_pipeline(
            audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav')
        same = rec_result["spk_embedding"]
        rec_result = inference_sv_pipeline(
            audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav')
        different = rec_result["spk_embedding"]
        # Cosine values below this threshold are reported as 0 similarity.
        sv_threshold = 0.9465
        # Similarity between the enrollment and an utterance by the same speaker.
        same_cos = self._scaled_similarity(enroll, same, sv_threshold)
        logger.info("Similarity: {}".format(same_cos))
        # Similarity between the enrollment and an utterance by a different speaker.
        diff_cos = self._scaled_similarity(enroll, different, sv_threshold)
        logger.info("Similarity: {}".format(diff_cos))
if __name__ == '__main__':
    # Run this module's test cases when the file is executed directly as a script.
    unittest.main()
tests/test_vad_inference_pipeline.py
New file
@@ -0,0 +1,35 @@
import unittest
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
# Module-level logger (modelscope's get_logger) used by the tests below to report results.
logger = get_logger()
class TestFSMNInferencePipelines(unittest.TestCase):
    """Smoke tests for the FSMN voice-activity-detection ModelScope pipelines.

    Each test builds a pipeline for one model, runs it on a remote example
    recording, and logs the result without asserting on it.
    """

    def _run_vad(self, model_id, audio_url):
        """Create a VAD pipeline for ``model_id``, run it on ``audio_url``, and log the output."""
        vad_pipeline = pipeline(
            task=Tasks.voice_activity_detection,
            model=model_id,
        )
        result = vad_pipeline(audio_in=audio_url)
        logger.info(f"vad inference result: {result}")

    def test_funasr_path(self):
        """Report the current directory and the file the ``funasr`` package was loaded from."""
        import funasr
        import os
        logger.info(f"run_dir:{os.getcwd()} ; funasr_path: {funasr.__file__}")

    def test_8k(self):
        """Run the 8 kHz model on its example recording."""
        self._run_vad(
            "damo/speech_fsmn_vad_zh-cn-8k-common",
            'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example_8k.wav',
        )

    def test_16k(self):
        """Run the 16 kHz model on its example recording."""
        self._run_vad(
            "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
            'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav',
        )
if __name__ == '__main__':
    # Run this module's test cases when the file is executed directly as a script.
    unittest.main()