Update processing of OOV (out-of-vocabulary) characters in hotwords for ONNX inference
| | |
| | | model = ContextualParaformer(model_dir, batch_size=1) |
| | | |
| | | wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())] |
| | | hotwords = '随机热词 各种热词 魔搭 阿里巴巴' |
| | | hotwords = '随机热词 各种热词 魔搭 阿里巴巴 仏' |
| | | |
| | | result = model(wav_path, hotwords) |
| | | print(result) |
| | |
| | | hotwords_length = torch.Tensor(hotwords_length).to(torch.int32) |
| | | # hotwords.append('<s>') |
| | | def word_map(word): |
| | | return torch.tensor([self.vocab[i] for i in word]) |
| | | hotwords = [] |
| | | for c in word: |
| | | if c not in self.vocab.keys(): |
| | | hotwords.append(8403) |
| | | logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word)) |
| | | else: |
| | | hotwords.append(self.vocab[c]) |
| | | return torch.tensor(hotwords) |
| | | hotword_int = [word_map(i) for i in hotwords] |
| | | # import pdb; pdb.set_trace() |
| | | hotword_int.append(torch.tensor([1])) |