aky15
2023-04-12 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f
funasr/bin/asr_inference_uniasr.py
@@ -37,16 +37,13 @@
from funasr.models.frontend.wav_frontend import WavFrontend
header_colors = '\033[95m'
end_colors = '\033[0m'
class Speech2Text:
    """Speech2Text class
    Examples:
        >>> import soundfile
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]
@@ -261,6 +258,7 @@
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            token = list(filter(lambda x: x != "<gbg>", token))
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
@@ -397,6 +395,19 @@
        device = "cuda"
    else:
        device = "cpu"
    # Optional override: when the caller supplies param_dict["decoding_model"],
    # translate that preset name into a (decoding_ind, decoding_mode) pair.
    # When param_dict is None or lacks the key, neither name is assigned here.
    # NOTE(review): presumably both already hold values from earlier in the
    # enclosing function (not visible in this excerpt) -- confirm.
    if param_dict is not None and "decoding_model" in param_dict:
        if param_dict["decoding_model"] == "fast":
            # "fast" preset: index 0 with "model1".
            decoding_ind = 0
            decoding_mode = "model1"
        elif param_dict["decoding_model"] == "normal":
            # "normal" preset: index 0 with "model2".
            decoding_ind = 0
            decoding_mode = "model2"
        elif param_dict["decoding_model"] == "offline":
            # "offline" preset: index 1 with "model2".
            decoding_ind = 1
            decoding_mode = "model2"
        else:
            # Reject unknown preset names explicitly instead of ignoring them.
            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
    # 1. Set random-seed
    set_all_random_seed(seed)
@@ -433,6 +444,7 @@
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
@@ -492,13 +504,13 @@
                    ibest_writer["score"][key] = str(hyp.score)
    
                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
                        ibest_writer["text"][key] = " ".join(word_lists)
        return asr_result_list
    
    return _forward