zhifu gao
2023-04-14 2cc6b504537fd4f54a7063f0b102917c61f05456
Merge pull request #352 from alibaba-damo-academy/dev_lhn2

support wav_file input
2个文件已修改
29 ■■■■■ 已修改文件
funasr/bin/asr_inference_paraformer_streaming.py 17 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/predictor/cif.py 12 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer_streaming.py
@@ -19,6 +19,7 @@
 import numpy as np
 import torch
+import torchaudio
 from typeguard import check_argument_types
 from funasr.fileio.datadir_writer import DatadirWriter
@@ -607,17 +608,21 @@
     ):
         # 3. Build data-iterator
-        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
-            raw_inputs = _load_bytes(data_path_and_name_and_type[0])
-            raw_inputs = torch.tensor(raw_inputs)
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, np.ndarray):
-                raw_inputs = torch.tensor(raw_inputs)
         is_final = False
         if param_dict is not None and "cache" in param_dict:
             cache = param_dict["cache"]
         if param_dict is not None and "is_final" in param_dict:
             is_final = param_dict["is_final"]
+        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
+            raw_inputs = _load_bytes(data_path_and_name_and_type[0])
+            raw_inputs = torch.tensor(raw_inputs)
+        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+            is_final = True
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, np.ndarray):
+                raw_inputs = torch.tensor(raw_inputs)
         # 7 .Start for-loop
         # FIXME(kamo): The output format should be discussed about
         asr_result_list = []
funasr/models/predictor/cif.py
@@ -234,6 +234,7 @@
         last_fire_place = len_time - 1
         last_fire_remainds = 0.0
         pre_alphas_length = 0
+        last_fire = False
 
         mask_chunk_peak_predictor = None
         if cache is not None:
@@ -251,10 +252,15 @@
             if cif_peak[0][len_time - 1 - i] > self.threshold or cif_peak[0][len_time - 1 - i] == self.threshold:
                 last_fire_place = len_time - 1 - i
                 last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold
+                last_fire = True
                 break
-        last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
-        cache["cif_hidden"] = hidden[:, last_fire_place:, :]
-        cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
+        if last_fire:
+           last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
+           cache["cif_hidden"] = hidden[:, last_fire_place:, :]
+           cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
+        else:
+           cache["cif_hidden"] = hidden
+           cache["cif_alphas"] = alphas
         token_num_int = token_num.floor().type(torch.int32).item()
         return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak