zhuyunfeng
2023-05-09 b15db52e4e67da8a133a67e8ffa415386de48b40
funasr/bin/asr_inference_paraformer_streaming.py
@@ -8,6 +8,7 @@
import codecs
import tempfile
import requests
import yaml
from pathlib import Path
from typing import Optional
from typing import Sequence
@@ -19,6 +20,7 @@
import numpy as np
import torch
import torchaudio
from typeguard import check_argument_types
from funasr.fileio.datadir_writer import DatadirWriter
@@ -202,10 +204,15 @@
        assert check_argument_types()
        results = []
        cache_en = cache["encoder"]
        if speech.shape[1] < 16 * 60 and cache["is_final"]:
            cache["last_chunk"] = True
            feats = cache["feats"]
        if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
            if cache_en["start_idx"] == 0:
                return []
            cache_en["tail_chunk"] = True
            feats = cache_en["feats"]
            feats_len = torch.tensor([feats.shape[1]])
            self.asr_model.frontend = None
            results = self.infer(feats, feats_len, cache)
            return results
        else:
            if self.frontend is not None:
                feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"])
@@ -232,7 +239,7 @@
                        feats_len = torch.tensor([feats_chunk2.shape[1]])
                        results_chunk2 = self.infer(feats_chunk2, feats_len, cache)
                        return results_chunk1 + results_chunk2
                        return [" ".join(results_chunk1 + results_chunk2)]
                results = self.infer(feats, feats_len, cache)
@@ -292,12 +299,9 @@
                # Change integer-ids to tokens
                token = self.converter.ids2tokens(token_int)
                token = " ".join(token)
                if self.tokenizer is not None:
                    text = self.tokenizer.tokens2text(token)
                else:
                    text = None
                results.append(text)
                results.append(token)
        # assert check_return_type(results)
        return results
@@ -460,13 +464,23 @@
        array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
        return array
    def _read_yaml(yaml_path: Union[str, Path]) -> Dict:
        if not Path(yaml_path).exists():
            raise FileExistsError(f'The {yaml_path} does not exist.')
        with open(str(yaml_path), 'rb') as f:
            data = yaml.load(f, Loader=yaml.Loader)
        return data
    def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
        if len(cache) > 0:
            return cache
        cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
        config = _read_yaml(asr_train_config)
        enc_output_size = config["encoder_conf"]["output_size"]
        feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
        cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
                    "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
                    "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))}
                    "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
        cache["encoder"] = cache_en
        cache_de = {"decode_fsmn": None}
@@ -476,9 +490,12 @@
    def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
        if len(cache) > 0:
            cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
            config = _read_yaml(asr_train_config)
            enc_output_size = config["encoder_conf"]["output_size"]
            feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
            cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
                        "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))}
                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
            cache["encoder"] = cache_en
            cache_de = {"decode_fsmn": None}
@@ -499,6 +516,8 @@
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
            raw_inputs = _load_bytes(data_path_and_name_and_type[0])
            raw_inputs = torch.tensor(raw_inputs)
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, np.ndarray):
                raw_inputs = torch.tensor(raw_inputs)
@@ -515,13 +534,32 @@
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        raw_inputs = torch.unsqueeze(raw_inputs, axis=0)
        input_lens = torch.tensor([raw_inputs.shape[1]])
        asr_result_list = []
        cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
        cache["encoder"]["is_final"] = is_final
        asr_result = speech2text(cache, raw_inputs, input_lens)
        item = {'key': "utt", 'value': asr_result}
        item = {}
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
            sample_offset = 0
            speech_length = raw_inputs.shape[1]
            stride_size =  chunk_size[1] * 960
            cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
            final_result = ""
            for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
                if sample_offset + stride_size >= speech_length - 1:
                    stride_size = speech_length - sample_offset
                    cache["encoder"]["is_final"] = True
                else:
                    cache["encoder"]["is_final"] = False
                input_lens = torch.tensor([stride_size])
                asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
                if len(asr_result) != 0:
                    final_result += " ".join(asr_result) + " "
            item = {'key': "utt", 'value': final_result.strip()}
        else:
            input_lens = torch.tensor([raw_inputs.shape[1]])
            cache["encoder"]["is_final"] = is_final
            asr_result = speech2text(cache, raw_inputs, input_lens)
            item = {'key': "utt", 'value': " ".join(asr_result)}
        asr_result_list.append(item)
        if is_final:
            cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1)
@@ -708,14 +746,4 @@
if __name__ == "__main__":
    # Run this module's entry point when it is executed as a script.
    main()