From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/ct_transformer/model.py |  181 +++++++++++++++++++++++++++++++--------------
 1 files changed, 125 insertions(+), 56 deletions(-)

diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index 8c3f043..abc5dfd 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -3,6 +3,7 @@
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 
+import copy
 import torch
 import numpy as np
 import torch.nn.functional as F
@@ -17,7 +18,10 @@
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
 
-
+try:
+    import jieba
+except:
+    pass
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
 else:
@@ -34,6 +38,7 @@
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """
+
     def __init__(
         self,
         encoder: str = None,
@@ -55,8 +60,7 @@
         punc_size = len(punc_list)
         if punc_weight is None:
             punc_weight = [1] * punc_size
-        
-        
+
         self.embed = torch.nn.Embedding(vocab_size, embed_unit)
         encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(**encoder_conf)
@@ -69,8 +73,10 @@
         self.sos = sos
         self.eos = eos
         self.sentence_end_id = sentence_end_id
-        
-        
+        self.jieba_usr_dict = None
+        if kwargs.get("jieba_usr_dict", None) is not None:
+            jieba.load_userdict(kwargs["jieba_usr_dict"])
+            self.jieba_usr_dict = jieba
 
     def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs):
         """Compute loss value from buffer sequences.
@@ -104,12 +110,16 @@
 
         """
         y = y.unsqueeze(0)
-        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
+        h, _, cache = self.encoder.forward_one_step(
+            self.embed(y), self._target_mask(y), cache=state
+        )
         h = self.decoder(h[:, -1])
         logp = h.log_softmax(dim=-1).squeeze(0)
         return logp, cache
 
-    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
+    def batch_score(
+        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
+    ) -> Tuple[torch.Tensor, List[Any]]:
         """Score new token batch.
 
         Args:
@@ -131,10 +141,14 @@
             batch_state = None
         else:
             # transpose state of [batch, layer] into [layer, batch]
-            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
+            batch_state = [
+                torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)
+            ]
 
         # batch decoding
-        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
+        h, _, states = self.encoder.forward_one_step(
+            self.embed(ys), self._target_mask(ys), cache=batch_state
+        )
         h = self.decoder(h[:, -1])
         logp = h.log_softmax(dim=-1)
 
@@ -164,12 +178,12 @@
         batch_size = text.size(0)
         # For data parallel
         if max_length is None:
-            text = text[:, :text_lengths.max()]
-            punc = punc[:, :text_lengths.max()]
+            text = text[:, : text_lengths.max()]
+            punc = punc[:, : text_lengths.max()]
         else:
             text = text[:, :max_length]
             punc = punc[:, :max_length]
-    
+
         if self.with_vad():
             # Should be VadRealtimeTransformer
             assert vad_indexes is not None
@@ -177,21 +191,29 @@
         else:
             # Should be TargetDelayTransformer,
             y, _ = self.punc_forward(text, text_lengths)
-    
+
         # Calc negative log likelihood
         # nll: (BxL,)
         if self.training == False:
             _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
             from sklearn.metrics import f1_score
-            f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
-                                indices.squeeze(-1).detach().cpu().numpy(),
-                                average='micro')
+
+            f1_score = f1_score(
+                punc.view(-1).detach().cpu().numpy(),
+                indices.squeeze(-1).detach().cpu().numpy(),
+                average="micro",
+            )
             nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
             return nll, text_lengths
         else:
             self.punc_weight = self.punc_weight.to(punc.device)
-            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
-                                  ignore_index=self.ignore_id)
+            nll = F.cross_entropy(
+                y.view(-1, y.shape[-1]),
+                punc.view(-1),
+                self.punc_weight,
+                reduction="none",
+                ignore_index=self.ignore_id,
+            )
         # nll: (BxL,) -> (BxL,)
         if max_length is None:
             nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
@@ -203,7 +225,6 @@
         # nll: (BxL,) -> (B, L)
         nll = nll.view(batch_size, -1)
         return nll, text_lengths
-
 
     def forward(
         self,
@@ -218,19 +239,20 @@
         ntokens = y_lengths.sum()
         loss = nll.sum() / ntokens
         stats = dict(loss=loss.detach())
-    
+
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
         return loss, stats, weight
-    
-    def inference(self,
-                 data_in,
-                 data_lengths=None,
-                 key: list = None,
-                 tokenizer=None,
-                 frontend=None,
-                 **kwargs,
-                 ):
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
         assert len(data_in) == 1
         text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
         vad_indexes = kwargs.get("vad_indexes", None)
@@ -238,20 +260,14 @@
         # text_lengths = data_lengths[0] if data_lengths is not None else None
         split_size = kwargs.get("split_size", 20)
 
-        jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
-        if jieba_usr_dict and isinstance(jieba_usr_dict, str):
-            import jieba
-            jieba.load_userdict(jieba_usr_dict)
-            jieba_usr_dict = jieba
-            kwargs["jieba_usr_dict"] = "jieba_usr_dict"
-        tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
+        tokens = split_words(text, jieba_usr_dict=self.jieba_usr_dict)
         tokens_int = tokenizer.encode(tokens)
 
         mini_sentences = split_to_mini_sentence(tokens, split_size)
         mini_sentences_id = split_to_mini_sentence(tokens_int, split_size)
         assert len(mini_sentences) == len(mini_sentences_id)
         cache_sent = []
-        cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
+        cache_sent_id = torch.from_numpy(np.array([], dtype="int32"))
         new_mini_sentence = ""
         new_mini_sentence_punc = []
         cache_pop_trigger_limit = 200
@@ -265,15 +281,13 @@
             mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
             data = {
                 "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
-                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
+                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype="int32")),
             }
             data = to_device(data, kwargs["device"])
             # y, _ = self.wrapped_model(**data)
             y, _ = self.punc_forward(**data)
             _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
-            punctuations = indices
-            if indices.size()[0] != 1:
-                punctuations = torch.squeeze(indices)
+            punctuations = torch.squeeze(indices, dim=1)
             assert punctuations.size()[0] == len(mini_sentence)
 
             # Search for the last Period/QuestionMark as cache
@@ -281,20 +295,27 @@
                 sentenceEnd = -1
                 last_comma_index = -1
                 for i in range(len(punctuations) - 2, 1, -1):
-                    if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
+                    if (
+                        self.punc_list[punctuations[i]] == "銆�"
+                        or self.punc_list[punctuations[i]] == "锛�"
+                    ):
                         sentenceEnd = i
                         break
                     if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
                         last_comma_index = i
 
-                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+                if (
+                    sentenceEnd < 0
+                    and len(mini_sentence) > cache_pop_trigger_limit
+                    and last_comma_index >= 0
+                ):
                     # The sentence it too long, cut off at a comma.
                     sentenceEnd = last_comma_index
                     punctuations[sentenceEnd] = self.sentence_end_id
-                cache_sent = mini_sentence[sentenceEnd + 1:]
-                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
-                mini_sentence = mini_sentence[0:sentenceEnd + 1]
-                punctuations = punctuations[0:sentenceEnd + 1]
+                cache_sent = mini_sentence[sentenceEnd + 1 :]
+                cache_sent_id = mini_sentence_id[sentenceEnd + 1 :]
+                mini_sentence = mini_sentence[0 : sentenceEnd + 1]
+                punctuations = punctuations[0 : sentenceEnd + 1]
 
             # if len(punctuations) == 0:
             #    continue
@@ -303,13 +324,20 @@
             new_mini_sentence_punc += [int(x) for x in punctuations_np]
             words_with_punc = []
             for i in range(len(mini_sentence)):
-                if (i==0 or self.punc_list[punctuations[i-1]] == "銆�" or self.punc_list[punctuations[i-1]] == "锛�") and len(mini_sentence[i][0].encode()) == 1:
+                if (
+                    i == 0
+                    or self.punc_list[punctuations[i - 1]] == "銆�"
+                    or self.punc_list[punctuations[i - 1]] == "锛�"
+                ) and len(mini_sentence[i][0].encode()) == 1:
                     mini_sentence[i] = mini_sentence[i].capitalize()
                 if i == 0:
                     if len(mini_sentence[i][0].encode()) == 1:
                         mini_sentence[i] = " " + mini_sentence[i]
                 if i > 0:
-                    if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
+                    if (
+                        len(mini_sentence[i][0].encode()) == 1
+                        and len(mini_sentence[i - 1][0].encode()) == 1
+                    ):
                         mini_sentence[i] = " " + mini_sentence[i]
                 words_with_punc.append(mini_sentence[i])
                 if self.punc_list[punctuations[i]] != "_":
@@ -329,23 +357,64 @@
             if mini_sentence_i == len(mini_sentences) - 1:
                 if new_mini_sentence[-1] == "锛�" or new_mini_sentence[-1] == "銆�":
                     new_mini_sentence_out = new_mini_sentence[:-1] + "銆�"
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
+                        self.sentence_end_id
+                    ]
                 elif new_mini_sentence[-1] == ",":
                     new_mini_sentence_out = new_mini_sentence[:-1] + "."
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
-                elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�" and len(new_mini_sentence[-1].encode())==0:
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
+                        self.sentence_end_id
+                    ]
+                elif (
+                    new_mini_sentence[-1] != "銆�"
+                    and new_mini_sentence[-1] != "锛�"
+                    and len(new_mini_sentence[-1].encode()) != 1
+                ):
                     new_mini_sentence_out = new_mini_sentence + "銆�"
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
-                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
+                        self.sentence_end_id
+                    ]
+                    if len(punctuations):
+                        punctuations[-1] = 2
+                elif (
+                    new_mini_sentence[-1] != "."
+                    and new_mini_sentence[-1] != "?"
+                    and len(new_mini_sentence[-1].encode()) == 1
+                ):
                     new_mini_sentence_out = new_mini_sentence + "."
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
+                        self.sentence_end_id
+                    ]
+                    if len(punctuations):
+                        punctuations[-1] = 2
             # keep a punctuations array for punc segment
             if punc_array is None:
                 punc_array = punctuations
             else:
                 punc_array = torch.cat([punc_array, punctuations], dim=0)
+
+        # post processing when using word level punc model
+        if self.jieba_usr_dict is not None:
+            punc_array = punc_array.reshape(-1)
+            len_tokens = len(tokens)
+            new_punc_array = copy.copy(punc_array).tolist()
+            # for i, (token, punc_id) in enumerate(zip(tokens[::-1], punc_array.tolist()[::-1])):
+            for i, token in enumerate(tokens[::-1]):
+                if "\u0e00" <= token[0] <= "\u9fa5":  # ignore en words
+                    if len(token) > 1:
+                        num_append = len(token) - 1
+                        ind_append = len_tokens - i - 1
+                        for _ in range(num_append):
+                            new_punc_array.insert(ind_append, 1)
+            punc_array = torch.tensor(new_punc_array)
+
         result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
         results.append(result_i)
-    
         return results, meta_data
 
+    def export(self, **kwargs):
+
+        from .export_meta import export_rebuild_model
+
+        models = export_rebuild_model(model=self, **kwargs)
+        return models

--
Gitblit v1.9.1