From a7d7a0f3a2e7cd44a337ced34e3536b12ccb534e Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 11 三月 2024 19:24:44 +0800
Subject: [PATCH] Dev gzf (#1467)

---
 funasr/models/ct_transformer/model.py |  118 +++++++++++++++++++++++++++++++++++++++++++++++++---------
 1 files changed, 99 insertions(+), 19 deletions(-)

diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index 5fb3ed4..88ee867 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -1,22 +1,34 @@
-from typing import Any
-from typing import List
-from typing import Tuple
-from typing import Optional
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import copy
+import torch
 import numpy as np
 import torch.nn.functional as F
-
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.train_utils.device_funcs import force_gatherable
-from funasr.train_utils.device_funcs import to_device
-import torch
-import torch.nn as nn
-from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
-from funasr.utils.load_utils import load_audio_text_image_video
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Any, List, Tuple, Optional
 
 from funasr.register import tables
+from funasr.train_utils.device_funcs import to_device
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
 
 @tables.register("model_classes", "CTTransformer")
-class CTTransformer(nn.Module):
+class CTTransformer(torch.nn.Module):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
@@ -45,11 +57,11 @@
             punc_weight = [1] * punc_size
         
         
-        self.embed = nn.Embedding(vocab_size, embed_unit)
+        self.embed = torch.nn.Embedding(vocab_size, embed_unit)
         encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(**encoder_conf)
 
-        self.decoder = nn.Linear(att_unit, punc_size)
+        self.decoder = torch.nn.Linear(att_unit, punc_size)
         self.encoder = encoder
         self.punc_list = punc_list
         self.punc_weight = punc_weight
@@ -211,7 +223,7 @@
         loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
         return loss, stats, weight
     
-    def generate(self,
+    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
@@ -321,20 +333,88 @@
                 elif new_mini_sentence[-1] == ",":
                     new_mini_sentence_out = new_mini_sentence[:-1] + "."
                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
-                elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�" and len(new_mini_sentence[-1].encode())==0:
+                elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�" and len(new_mini_sentence[-1].encode())!=1:
                     new_mini_sentence_out = new_mini_sentence + "銆�"
                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
+                    if len(punctuations): punctuations[-1] = 2
                 elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
                     new_mini_sentence_out = new_mini_sentence + "."
                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
-            # keep a punctuations array for punc segment
+                    if len(punctuations): punctuations[-1] = 2
+            # keep a punctuations array for punc segment 
             if punc_array is None:
                 punc_array = punctuations
             else:
                 punc_array = torch.cat([punc_array, punctuations], dim=0)
+        # post processing when using word level punc model
+        if jieba_usr_dict:
+            len_tokens = len(tokens)
+            new_punc_array = copy.copy(punc_array).tolist()
+            # for i, (token, punc_id) in enumerate(zip(tokens[::-1], punc_array.tolist()[::-1])):
+            for i, token in enumerate(tokens[::-1]):
+                if '\u0e00' <= token[0] <= '\u9fa5': # ignore en words
+                    if len(token) > 1:
+                        num_append = len(token) - 1
+                        ind_append = len_tokens - i - 1
+                        for _ in range(num_append):
+                            new_punc_array.insert(ind_append, 1)
+            punc_array = torch.tensor(new_punc_array)
         
         result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
         results.append(result_i)
-    
         return results, meta_data
 
+    def export(
+        self,
+        **kwargs,
+    ):
+
+        is_onnx = kwargs.get("type", "onnx") == "onnx"
+        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
+        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
+
+        self.forward = self._export_forward
+        
+        return self
+
+    def export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
+        """Compute loss value from buffer sequences.
+
+        Args:
+            input (torch.Tensor): Input ids. (batch, len)
+            hidden (torch.Tensor): Target ids. (batch, len)
+
+        """
+        x = self.embed(inputs)
+        h, _ = self.encoder(x, text_lengths)
+        y = self.decoder(h)
+        return y
+
+    def export_dummy_inputs(self):
+        length = 120
+        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
+        text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
+        return (text_indexes, text_lengths)
+
+    def export_input_names(self):
+        return ['inputs', 'text_lengths']
+
+    def export_output_names(self):
+        return ['logits']
+
+    def export_dynamic_axes(self):
+        return {
+            'inputs': {
+                0: 'batch_size',
+                1: 'feats_length'
+            },
+            'text_lengths': {
+                0: 'batch_size',
+            },
+            'logits': {
+                0: 'batch_size',
+                1: 'logits_length'
+            },
+        }
+    def export_name(self):
+        return "model.onnx"
\ No newline at end of file

--
Gitblit v1.9.1