游雁
2023-07-05 4e2fe544ae37174a3e09dfcdbbdae5abfe711e53
funasr/bin/asr_inference_launch.py
@@ -19,8 +19,8 @@
import numpy as np
import torch
import torchaudio
import soundfile
import yaml
from typeguard import check_argument_types
from funasr.bin.asr_infer import Speech2Text
from funasr.bin.asr_infer import Speech2TextMFCCA
@@ -79,7 +79,6 @@
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
@@ -239,7 +238,6 @@
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
@@ -257,8 +255,10 @@
    if param_dict is not None:
        hotword_list_or_file = param_dict.get('hotword')
        export_mode = param_dict.get("export_mode", False)
        clas_scale = param_dict.get('clas_scale', 1.0)
    else:
        hotword_list_or_file = None
        clas_scale = 1.0
    if kwargs.get("device", None) == "cpu":
        ngpu = 0
@@ -291,6 +291,7 @@
        penalty=penalty,
        nbest=nbest,
        hotword_list_or_file=hotword_list_or_file,
        clas_scale=clas_scale,
    )
    speech2text = Speech2TextParaformer(**speech2text_kwargs)
@@ -480,7 +481,6 @@
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
@@ -620,6 +620,22 @@
            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
            results_sorted = []
            
            if not len(sorted_data):
                key = keys[0]
                # no active segments after VAD
                if writer is not None:
                    # Write empty results
                    ibest_writer["token"][key] = ""
                    ibest_writer["token_int"][key] = ""
                    ibest_writer["vad"][key] = ""
                    ibest_writer["text"][key] = ""
                    ibest_writer["text_with_punc"][key] = ""
                    if use_timestamp:
                        ibest_writer["time_stamp"][key] = ""
                logging.info("decoding, utt: {}, empty speech".format(key))
                continue
            batch_size_token_ms = batch_size_token*60
            if speech2text.device == "cpu":
                batch_size_token_ms = 0
@@ -748,7 +764,6 @@
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
@@ -863,7 +878,13 @@
            raw_inputs = _load_bytes(data_path_and_name_and_type[0])
            raw_inputs = torch.tensor(raw_inputs)
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
            try:
                raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
            except:
                raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
                if raw_inputs.ndim == 2:
                    raw_inputs = raw_inputs[:, 0]
                raw_inputs = torch.tensor(raw_inputs)
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, np.ndarray):
                raw_inputs = torch.tensor(raw_inputs)
@@ -950,7 +971,6 @@
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
@@ -1119,7 +1139,6 @@
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
@@ -1307,7 +1326,6 @@
        right_context: Number of frames in right context AFTER subsampling.
        display_partial_hypotheses: Whether to display partial hypotheses.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
@@ -1350,10 +1368,7 @@
        left_context=left_context,
        right_context=right_context,
    )
    speech2text = Speech2TextTransducer.from_pretrained(
        model_tag=model_tag,
        **speech2text_kwargs,
    )
    speech2text = Speech2TextTransducer(**speech2text_kwargs)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
@@ -1457,7 +1472,6 @@
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
@@ -1606,6 +1620,8 @@
        return inference_mfcca(**kwargs)
    elif mode == "rnnt":
        return inference_transducer(**kwargs)
    elif mode == "bat":
        return inference_transducer(**kwargs)
    elif mode == "sa_asr":
        return inference_sa_asr(**kwargs)
    else: