From c0e72dd1ba86c19205ee633673b2497d18a68077 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 11 一月 2024 17:36:59 +0800
Subject: [PATCH] Merge branch 'funasr1.0' of github.com:alibaba-damo-academy/FunASR into funasr1.0 add

---
 funasr/models/ct_transformer/model.py |   10 ++++++++--
 1 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index e32aa25..fbf1804 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -239,6 +239,7 @@
         cache_pop_trigger_limit = 200
         results = []
         meta_data = {}
+        punc_array = None
         for mini_sentence_i in range(len(mini_sentences)):
             mini_sentence = mini_sentences[mini_sentence_i]
             mini_sentence_id = mini_sentences_id[mini_sentence_i]
@@ -320,8 +321,13 @@
                 elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
                     new_mini_sentence_out = new_mini_sentence + "."
                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
-
-        result_i = {"key": key[0], "text": new_mini_sentence_out}
+            # keep a punctuations array for punc segment
+            if punc_array is None:
+                punc_array = punctuations
+            else:
+                punc_array = torch.cat([punc_array, punctuations], dim=0)
+        
+        result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
         results.append(result_i)
     
         return results, meta_data

--
Gitblit v1.9.1