| | |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | |
| | | |
# ANSI SGR escape sequences used to colorize console/log output.
header_colors = '\033[95m'  # bright magenta: opens a highlighted span
end_colors = '\033[0m'  # reset: closes the highlighted span
| | | |
| | | |
| | | class Speech2Text: |
| | | """Speech2Text class |
| | | |
| | | Examples: |
| | | >>> import soundfile |
| | | >>> speech2text = Speech2Text("asr_config.yml", "asr.pth") |
| | | >>> speech2text = Speech2Text("asr_config.yml", "asr.pb") |
| | | >>> audio, rate = soundfile.read("speech.wav") |
| | | >>> speech2text(audio) |
| | | [(text, token, token_int, hypothesis object), ...] |
| | |
| | | |
| | | # Change integer-ids to tokens |
| | | token = self.converter.ids2tokens(token_int) |
| | | token = list(filter(lambda x: x != "<gbg>", token)) |
| | | |
| | | if self.tokenizer is not None: |
| | | text = self.tokenizer.tokens2text(token) |
| | |
| | | **kwargs, |
| | | ): |
| | | assert check_argument_types() |
| | | ncpu = kwargs.get("ncpu", 1) |
| | | torch.set_num_threads(ncpu) |
| | | if batch_size > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | if word_lm_train_config is not None: |
| | |
| | | else: |
| | | device = "cpu" |
| | | |
        # Map the caller-supplied "decoding_model" option onto the pair
        # (decoding_ind, decoding_mode) consumed by the decoder.
        # NOTE(review): "fast" and "normal" both select decoding_ind = 0 and
        # differ only in decoding_mode; the same mapping is repeated later for
        # the speech2text attributes, so it looks intentional — confirm.
        if param_dict is not None and "decoding_model" in param_dict:
            if param_dict["decoding_model"] == "fast":
                decoding_ind = 0
                decoding_mode = "model1"
            elif param_dict["decoding_model"] == "normal":
                decoding_ind = 0
                decoding_mode = "model2"
            elif param_dict["decoding_model"] == "offline":
                decoding_ind = 1
                decoding_mode = "model2"
            else:
                # Fail fast on any unrecognized option rather than decoding
                # with an unintended model.
                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
| | | |
| | | # 1. Set random-seed |
| | | set_all_random_seed(seed) |
| | | |
| | |
| | | if isinstance(raw_inputs, torch.Tensor): |
| | | raw_inputs = raw_inputs.numpy() |
| | | data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] |
        # Re-apply the "decoding_model" option per call, this time as
        # attributes on the already-built speech2text object (so one built
        # pipeline can be switched between modes between invocations).
        # NOTE(review): duplicates the fast/normal/offline mapping used at
        # construction time — keep the two tables in sync if either changes.
        if param_dict is not None and "decoding_model" in param_dict:
            if param_dict["decoding_model"] == "fast":
                speech2text.decoding_ind = 0
                speech2text.decoding_mode = "model1"
            elif param_dict["decoding_model"] == "normal":
                speech2text.decoding_ind = 0
                speech2text.decoding_mode = "model2"
            elif param_dict["decoding_model"] == "offline":
                speech2text.decoding_ind = 1
                speech2text.decoding_mode = "model2"
            else:
                # Same fail-fast behavior as the constructor-time dispatch.
                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
| | | loader = ASRTask.build_streaming_iterator( |
| | | data_path_and_name_and_type, |
| | | dtype=dtype, |
| | |
| | | ibest_writer["score"][key] = str(hyp.score) |
| | | |
| | | if text is not None: |
| | | text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) |
| | | text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token) |
| | | item = {'key': key, 'value': text_postprocessed} |
| | | asr_result_list.append(item) |
| | | finish_count += 1 |
| | | asr_utils.print_progress(finish_count / file_count) |
| | | if writer is not None: |
| | | ibest_writer["text"][key] = text |
| | | ibest_writer["text"][key] = " ".join(word_lists) |
| | | return asr_result_list |
| | | |
| | | return _forward |