From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/ct_transformer/model.py |  294 ++++++++++++++++++++++++++++++++++++++++++++++++++--------
 1 files changed, 251 insertions(+), 43 deletions(-)

diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index 31b2af2..abc5dfd 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -1,23 +1,48 @@
-from typing import Any
-from typing import List
-from typing import Tuple
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
+import copy
 import torch
-import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Any, List, Tuple, Optional
 
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables
+from funasr.train_utils.device_funcs import to_device
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
 
-@register_class("model_classes", "CTTransformer")
-class CTTransformer(nn.Module):
+try:
+    import jieba
+except:
+    pass
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+
+@tables.register("model_classes", "CTTransformer")
+class CTTransformer(torch.nn.Module):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """
+
     def __init__(
         self,
         encoder: str = None,
-        encoder_conf: str = None,
+        encoder_conf: dict = None,
         vocab_size: int = -1,
         punc_list: list = None,
         punc_weight: list = None,
@@ -27,6 +52,7 @@
         ignore_id: int = -1,
         sos: int = 1,
         eos: int = 2,
+        sentence_end_id: int = 3,
         **kwargs,
     ):
         super().__init__()
@@ -34,23 +60,25 @@
         punc_size = len(punc_list)
         if punc_weight is None:
             punc_weight = [1] * punc_size
-        
-        
-        self.embed = nn.Embedding(vocab_size, embed_unit)
-        encoder_class = registry_tables.encoder_classes.get(encoder.lower())
+
+        self.embed = torch.nn.Embedding(vocab_size, embed_unit)
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(**encoder_conf)
 
-        self.decoder = nn.Linear(att_unit, punc_size)
+        self.decoder = torch.nn.Linear(att_unit, punc_size)
         self.encoder = encoder
         self.punc_list = punc_list
         self.punc_weight = punc_weight
         self.ignore_id = ignore_id
         self.sos = sos
         self.eos = eos
-        
-        
+        self.sentence_end_id = sentence_end_id
+        self.jieba_usr_dict = None
+        if kwargs.get("jieba_usr_dict", None) is not None:
+            jieba.load_userdict(kwargs["jieba_usr_dict"])
+            self.jieba_usr_dict = jieba
 
-    def punc_forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
+    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs):
         """Compute loss value from buffer sequences.
 
         Args:
@@ -58,7 +86,7 @@
             hidden (torch.Tensor): Target ids. (batch, len)
 
         """
-        x = self.embed(input)
+        x = self.embed(text)
         # mask = self._target_mask(input)
         h, _, _ = self.encoder(x, text_lengths)
         y = self.decoder(h)
@@ -82,12 +110,16 @@
 
         """
         y = y.unsqueeze(0)
-        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
+        h, _, cache = self.encoder.forward_one_step(
+            self.embed(y), self._target_mask(y), cache=state
+        )
         h = self.decoder(h[:, -1])
         logp = h.log_softmax(dim=-1).squeeze(0)
         return logp, cache
 
-    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
+    def batch_score(
+        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
+    ) -> Tuple[torch.Tensor, List[Any]]:
         """Score new token batch.
 
         Args:
@@ -109,10 +141,14 @@
             batch_state = None
         else:
             # transpose state of [batch, layer] into [layer, batch]
-            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
+            batch_state = [
+                torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)
+            ]
 
         # batch decoding
-        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
+        h, _, states = self.encoder.forward_one_step(
+            self.embed(ys), self._target_mask(ys), cache=batch_state
+        )
         h = self.decoder(h[:, -1])
         logp = h.log_softmax(dim=-1)
 
@@ -142,12 +178,12 @@
         batch_size = text.size(0)
         # For data parallel
         if max_length is None:
-            text = text[:, :text_lengths.max()]
-            punc = punc[:, :text_lengths.max()]
+            text = text[:, : text_lengths.max()]
+            punc = punc[:, : text_lengths.max()]
         else:
             text = text[:, :max_length]
             punc = punc[:, :max_length]
-    
+
         if self.with_vad():
             # Should be VadRealtimeTransformer
             assert vad_indexes is not None
@@ -155,21 +191,29 @@
         else:
             # Should be TargetDelayTransformer,
             y, _ = self.punc_forward(text, text_lengths)
-    
+
         # Calc negative log likelihood
         # nll: (BxL,)
         if self.training == False:
             _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
             from sklearn.metrics import f1_score
-            f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
-                                indices.squeeze(-1).detach().cpu().numpy(),
-                                average='micro')
+
+            f1_score = f1_score(
+                punc.view(-1).detach().cpu().numpy(),
+                indices.squeeze(-1).detach().cpu().numpy(),
+                average="micro",
+            )
             nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
             return nll, text_lengths
         else:
             self.punc_weight = self.punc_weight.to(punc.device)
-            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
-                                  ignore_index=self.ignore_id)
+            nll = F.cross_entropy(
+                y.view(-1, y.shape[-1]),
+                punc.view(-1),
+                self.punc_weight,
+                reduction="none",
+                ignore_index=self.ignore_id,
+            )
         # nll: (BxL,) -> (BxL,)
         if max_length is None:
             nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
@@ -182,7 +226,6 @@
         nll = nll.view(batch_size, -1)
         return nll, text_lengths
 
-
     def forward(
         self,
         text: torch.Tensor,
@@ -191,22 +234,187 @@
         punc_lengths: torch.Tensor,
         vad_indexes: Optional[torch.Tensor] = None,
         vad_indexes_lengths: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+    ):
         nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
         ntokens = y_lengths.sum()
         loss = nll.sum() / ntokens
         stats = dict(loss=loss.detach())
-    
+
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
         return loss, stats, weight
-    
-    def generate(self,
-                  text: torch.Tensor,
-                  text_lengths: torch.Tensor,
-                  vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
-        if self.with_vad():
-            assert vad_indexes is not None
-            return self.punc_forward(text, text_lengths, vad_indexes)
-        else:
-            return self.punc_forward(text, text_lengths)
\ No newline at end of file
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+        assert len(data_in) == 1
+        text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
+        vad_indexes = kwargs.get("vad_indexes", None)
+        # text = data_in[0]
+        # text_lengths = data_lengths[0] if data_lengths is not None else None
+        split_size = kwargs.get("split_size", 20)
+
+        tokens = split_words(text, jieba_usr_dict=self.jieba_usr_dict)
+        tokens_int = tokenizer.encode(tokens)
+
+        mini_sentences = split_to_mini_sentence(tokens, split_size)
+        mini_sentences_id = split_to_mini_sentence(tokens_int, split_size)
+        assert len(mini_sentences) == len(mini_sentences_id)
+        cache_sent = []
+        cache_sent_id = torch.from_numpy(np.array([], dtype="int32"))
+        new_mini_sentence = ""
+        new_mini_sentence_punc = []
+        cache_pop_trigger_limit = 200
+        results = []
+        meta_data = {}
+        punc_array = None
+        for mini_sentence_i in range(len(mini_sentences)):
+            mini_sentence = mini_sentences[mini_sentence_i]
+            mini_sentence_id = mini_sentences_id[mini_sentence_i]
+            mini_sentence = cache_sent + mini_sentence
+            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+            data = {
+                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
+                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype="int32")),
+            }
+            data = to_device(data, kwargs["device"])
+            # y, _ = self.wrapped_model(**data)
+            y, _ = self.punc_forward(**data)
+            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
+            punctuations = torch.squeeze(indices, dim=1)
+            assert punctuations.size()[0] == len(mini_sentence)
+
+            # Search for the last Period/QuestionMark as cache
+            if mini_sentence_i < len(mini_sentences) - 1:
+                sentenceEnd = -1
+                last_comma_index = -1
+                for i in range(len(punctuations) - 2, 1, -1):
+                    if (
+                        self.punc_list[punctuations[i]] == "銆�"
+                        or self.punc_list[punctuations[i]] == "锛�"
+                    ):
+                        sentenceEnd = i
+                        break
+                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
+                        last_comma_index = i
+
+                if (
+                    sentenceEnd < 0
+                    and len(mini_sentence) > cache_pop_trigger_limit
+                    and last_comma_index >= 0
+                ):
+                    # The sentence it too long, cut off at a comma.
+                    sentenceEnd = last_comma_index
+                    punctuations[sentenceEnd] = self.sentence_end_id
+                cache_sent = mini_sentence[sentenceEnd + 1 :]
+                cache_sent_id = mini_sentence_id[sentenceEnd + 1 :]
+                mini_sentence = mini_sentence[0 : sentenceEnd + 1]
+                punctuations = punctuations[0 : sentenceEnd + 1]
+
+            # if len(punctuations) == 0:
+            #    continue
+
+            punctuations_np = punctuations.cpu().numpy()
+            new_mini_sentence_punc += [int(x) for x in punctuations_np]
+            words_with_punc = []
+            for i in range(len(mini_sentence)):
+                if (
+                    i == 0
+                    or self.punc_list[punctuations[i - 1]] == "銆�"
+                    or self.punc_list[punctuations[i - 1]] == "锛�"
+                ) and len(mini_sentence[i][0].encode()) == 1:
+                    mini_sentence[i] = mini_sentence[i].capitalize()
+                if i == 0:
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        mini_sentence[i] = " " + mini_sentence[i]
+                if i > 0:
+                    if (
+                        len(mini_sentence[i][0].encode()) == 1
+                        and len(mini_sentence[i - 1][0].encode()) == 1
+                    ):
+                        mini_sentence[i] = " " + mini_sentence[i]
+                words_with_punc.append(mini_sentence[i])
+                if self.punc_list[punctuations[i]] != "_":
+                    punc_res = self.punc_list[punctuations[i]]
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        if punc_res == "锛�":
+                            punc_res = ","
+                        elif punc_res == "銆�":
+                            punc_res = "."
+                        elif punc_res == "锛�":
+                            punc_res = "?"
+                    words_with_punc.append(punc_res)
+            new_mini_sentence += "".join(words_with_punc)
+            # Add Period for the end of the sentence
+            new_mini_sentence_out = new_mini_sentence
+            new_mini_sentence_punc_out = new_mini_sentence_punc
+            if mini_sentence_i == len(mini_sentences) - 1:
+                if new_mini_sentence[-1] == "锛�" or new_mini_sentence[-1] == "銆�":
+                    new_mini_sentence_out = new_mini_sentence[:-1] + "銆�"
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
+                        self.sentence_end_id
+                    ]
+                elif new_mini_sentence[-1] == ",":
+                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
+                        self.sentence_end_id
+                    ]
+                elif (
+                    new_mini_sentence[-1] != "銆�"
+                    and new_mini_sentence[-1] != "锛�"
+                    and len(new_mini_sentence[-1].encode()) != 1
+                ):
+                    new_mini_sentence_out = new_mini_sentence + "銆�"
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
+                        self.sentence_end_id
+                    ]
+                    if len(punctuations):
+                        punctuations[-1] = 2
+                elif (
+                    new_mini_sentence[-1] != "."
+                    and new_mini_sentence[-1] != "?"
+                    and len(new_mini_sentence[-1].encode()) == 1
+                ):
+                    new_mini_sentence_out = new_mini_sentence + "."
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
+                        self.sentence_end_id
+                    ]
+                    if len(punctuations):
+                        punctuations[-1] = 2
+            # keep a punctuations array for punc segment
+            if punc_array is None:
+                punc_array = punctuations
+            else:
+                punc_array = torch.cat([punc_array, punctuations], dim=0)
+
+        # post processing when using word level punc model
+        if self.jieba_usr_dict is not None:
+            punc_array = punc_array.reshape(-1)
+            len_tokens = len(tokens)
+            new_punc_array = copy.copy(punc_array).tolist()
+            # for i, (token, punc_id) in enumerate(zip(tokens[::-1], punc_array.tolist()[::-1])):
+            for i, token in enumerate(tokens[::-1]):
+                if "\u0e00" <= token[0] <= "\u9fa5":  # ignore en words
+                    if len(token) > 1:
+                        num_append = len(token) - 1
+                        ind_append = len_tokens - i - 1
+                        for _ in range(num_append):
+                            new_punc_array.insert(ind_append, 1)
+            punc_array = torch.tensor(new_punc_array)
+
+        result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
+        results.append(result_i)
+        return results, meta_data
+
+    def export(self, **kwargs):
+
+        from .export_meta import export_rebuild_model
+
+        models = export_rebuild_model(model=self, **kwargs)
+        return models

--
Gitblit v1.9.1