From 831c48a886a5b879b46fc422cdc9c0898f6c34ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 15 一月 2024 11:07:33 +0800
Subject: [PATCH] download configuration.json
---
funasr/models/ct_transformer/model.py | 29 +++++++++++++++++++++--------
1 files changed, 21 insertions(+), 8 deletions(-)
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index 24a6aea..7187f45 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -11,6 +11,7 @@
import torch
import torch.nn as nn
from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
+from funasr.utils.load_utils import load_audio_text_image_video
from funasr.register import tables
@@ -59,7 +60,7 @@
- def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
+ def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs):
"""Compute loss value from buffer sequences.
Args:
@@ -219,13 +220,19 @@
**kwargs,
):
assert len(data_in) == 1
-
+ text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
vad_indexes = kwargs.get("vad_indexes", None)
- text = data_in[0]
- text_lengths = data_lengths[0] if data_lengths is not None else None
+ # text = data_in[0]
+ # text_lengths = data_lengths[0] if data_lengths is not None else None
split_size = kwargs.get("split_size", 20)
-
- tokens = split_words(text)
+
+ jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
+ if jieba_usr_dict and isinstance(jieba_usr_dict, str):
+ import jieba
+ jieba.load_userdict(jieba_usr_dict)
+ jieba_usr_dict = jieba
+ kwargs["jieba_usr_dict"] = "jieba_usr_dict"
+ tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
tokens_int = tokenizer.encode(tokens)
mini_sentences = split_to_mini_sentence(tokens, split_size)
@@ -238,6 +245,7 @@
cache_pop_trigger_limit = 200
results = []
meta_data = {}
+ punc_array = None
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
@@ -319,8 +327,13 @@
elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
new_mini_sentence_out = new_mini_sentence + "."
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
-
- result_i = {"key": key[0], "text": new_mini_sentence_out}
+ # keep a punctuations array for punc segment
+ if punc_array is None:
+ punc_array = punctuations
+ else:
+ punc_array = torch.cat([punc_array, punctuations], dim=0)
+
+ result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
results.append(result_i)
return results, meta_data
--
Gitblit v1.9.1