游雁
2023-04-21 3cd3473bf7a3b41484baa86d9092248d78e7af39
funasr/bin/asr_inference_uniasr.py
@@ -37,16 +37,13 @@
from funasr.models.frontend.wav_frontend import WavFrontend
# ANSI SGR escape sequences: '\033[95m' selects the bright-magenta foreground
# colour and '\033[0m' resets all attributes back to the terminal default.
# NOTE(review): presumably wrapped around header text printed to a terminal;
# the usage sites are not visible in this excerpt.
header_colors = '\033[95m'
end_colors = '\033[0m'
class Speech2Text:
    """Speech2Text class
    Examples:
        >>> import soundfile
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]
@@ -261,6 +258,7 @@
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            token = list(filter(lambda x: x != "<gbg>", token))
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
@@ -381,6 +379,8 @@
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
@@ -398,6 +398,19 @@
    else:
        device = "cpu"
    
    # Translate the optional "decoding_model" entry of param_dict into the
    # (decoding_ind, decoding_mode) pair:
    #   "fast"    -> decoding_ind = 0, decoding_mode = "model1"
    #   "normal"  -> decoding_ind = 0, decoding_mode = "model2"
    #   "offline" -> decoding_ind = 1, decoding_mode = "model2"
    # Any other value raises NotImplementedError.
    # NOTE(review): when param_dict is None or has no "decoding_model" key,
    # this statement binds neither name; confirm both are already defined
    # (e.g. as parameters with defaults) before they are read later on.
    if param_dict is not None and "decoding_model" in param_dict:
        if param_dict["decoding_model"] == "fast":
            decoding_ind = 0
            decoding_mode = "model1"
        elif param_dict["decoding_model"] == "normal":
            decoding_ind = 0
            decoding_mode = "model2"
        elif param_dict["decoding_model"] == "offline":
            decoding_ind = 1
            decoding_mode = "model2"
        else:
            # Reject unknown names explicitly instead of silently falling back.
            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
    # 1. Set random-seed
    set_all_random_seed(seed)
@@ -440,18 +453,6 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        # Map the optional "decoding_model" entry of param_dict onto the
        # decoding_ind / decoding_mode attributes of speech2text:
        #   "fast"    -> decoding_ind = 0, decoding_mode = "model1"
        #   "normal"  -> decoding_ind = 0, decoding_mode = "model2"
        #   "offline" -> decoding_ind = 1, decoding_mode = "model2"
        # Any other value raises NotImplementedError. When param_dict is None
        # or lacks the key, both attributes are left as they were.
        if param_dict is not None and "decoding_model" in param_dict:
            if param_dict["decoding_model"] == "fast":
                speech2text.decoding_ind = 0
                speech2text.decoding_mode = "model1"
            elif param_dict["decoding_model"] == "normal":
                speech2text.decoding_ind = 0
                speech2text.decoding_mode = "model2"
            elif param_dict["decoding_model"] == "offline":
                speech2text.decoding_ind = 1
                speech2text.decoding_mode = "model2"
            else:
                # Reject unknown names explicitly instead of silently falling back.
                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
@@ -505,13 +506,13 @@
                    ibest_writer["score"][key] = str(hyp.score)
    
                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
                        ibest_writer["text"][key] = " ".join(word_lists)
        return asr_result_list
    
    return _forward