| | |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import copy |
| | | import torch |
| | | import numpy as np |
| | | import torch.nn.functional as F |
| | |
| | | punc_array = punctuations |
| | | else: |
| | | punc_array = torch.cat([punc_array, punctuations], dim=0) |
| | | # post processing when using word level punc model |
| | | if jieba_usr_dict: |
| | | len_tokens = len(tokens) |
| | | new_punc_array = copy.copy(punc_array).tolist() |
| | | # for i, (token, punc_id) in enumerate(zip(tokens[::-1], punc_array.tolist()[::-1])): |
| | | for i, token in enumerate(tokens[::-1]): |
| | | if '\u0e00' <= token[0] <= '\u9fa5': # ignore en words |
| | | if len(token) > 1: |
| | | num_append = len(token) - 1 |
| | | ind_append = len_tokens - i - 1 |
| | | for _ in range(num_append): |
| | | new_punc_array.insert(ind_append, 1) |
| | | punc_array = torch.tensor(new_punc_array) |
| | | |
| | | result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array} |
| | | results.append(result_i) |
| | | return results, meta_data |