zhifu gao
2023-04-13 33681507e1d3468f6a8670e82c47a1b69cb4a394
Merge pull request #342 from alibaba-damo-academy/dev_cmz

fix task.py with no dest_sample_rate task; fix bug in train and infer
5个文件已修改
1个文件已删除
57 ■■■■ 已修改文件
funasr/bin/punc_train_vadrealtime.py 44 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punctuation_infer_vadrealtime.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/large_datasets/utils/tokenize.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/preprocessor.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/abs_task.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punc_train_vadrealtime.py
File was deleted
funasr/bin/punctuation_infer_vadrealtime.py
@@ -90,7 +90,7 @@
            data = {
                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
                "vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')),
                "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
            }
            data = to_device(data, self.device)
            y, _ = self.wrapped_model(**data)
funasr/datasets/large_datasets/utils/tokenize.py
@@ -47,8 +47,8 @@
    length = len(text)
    for i in range(length):
        x = text[i]
        if i == length-1 and "punc" in data and text[i].startswith("vad:"):
            vad = x[-1][4:]
        if i == length-1 and "punc" in data and x.startswith("vad:"):
            vad = x[4:]
            if len(vad) == 0:
                vad = -1
            else:
funasr/datasets/preprocessor.py
@@ -786,6 +786,7 @@
    ) -> Dict[str, np.ndarray]:
        for i in range(self.num_tokenizer):
            text_name = self.text_name[i]
+            #import pdb; pdb.set_trace()
            if text_name in data and self.tokenizer[i] is not None:
                text = data[text_name]
                text = self.text_cleaner(text)
@@ -800,7 +801,7 @@
                    data[self.vad_name] = np.array([vad], dtype=np.int64)
                text_ints = self.token_id_converter[i].tokens2ids(tokens)
                data[text_name] = np.array(text_ints, dtype=np.int64)
        return data
def split_to_mini_sentence(words: list, word_limit: int = 20):
    assert word_limit > 1
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -159,7 +159,7 @@
            data = {
                "input": mini_sentence_id[None,:],
                "text_lengths": np.array([text_length], dtype='int32'),
                "vad_mask": self.vad_mask(text_length, len(cache) - 1)[None, None, :, :].astype(np.float32),
                "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32),
                "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
            }
            try:
funasr/tasks/abs_task.py
@@ -1587,6 +1587,8 @@
                dest_sample_rate = args.frontend_conf["fs"]
            else:
                dest_sample_rate = 16000
        else:
            dest_sample_rate = 16000
        dataset = ESPnetDataset(
            iter_options.data_path_and_name_and_type,