From 94de39dde2e616a01683c518023d0fab72b4e103 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 19 二月 2024 22:21:50 +0800
Subject: [PATCH] aishell example

---
 funasr/models/ct_transformer/model.py |  113 +++++++++++++++++++++++++++++++++++---------------------
 1 files changed, 71 insertions(+), 42 deletions(-)

diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index a1aff47..1e53aa3 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -1,21 +1,34 @@
-from typing import Any
-from typing import List
-from typing import Tuple
-from typing import Optional
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import torch
 import numpy as np
 import torch.nn.functional as F
-
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.train_utils.device_funcs import force_gatherable
-from funasr.train_utils.device_funcs import to_device
-import torch
-import torch.nn as nn
-from funasr.models.ct_transformer.utils import split_to_mini_sentence
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Any, List, Tuple, Optional
 
 from funasr.register import tables
+from funasr.train_utils.device_funcs import to_device
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
 
 @tables.register("model_classes", "CTTransformer")
-class CTTransformer(nn.Module):
+class CTTransformer(torch.nn.Module):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
@@ -34,6 +47,7 @@
         ignore_id: int = -1,
         sos: int = 1,
         eos: int = 2,
+        sentence_end_id: int = 3,
         **kwargs,
     ):
         super().__init__()
@@ -43,21 +57,22 @@
             punc_weight = [1] * punc_size
         
         
-        self.embed = nn.Embedding(vocab_size, embed_unit)
-        encoder_class = tables.encoder_classes.get(encoder.lower())
+        self.embed = torch.nn.Embedding(vocab_size, embed_unit)
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(**encoder_conf)
 
-        self.decoder = nn.Linear(att_unit, punc_size)
+        self.decoder = torch.nn.Linear(att_unit, punc_size)
         self.encoder = encoder
         self.punc_list = punc_list
         self.punc_weight = punc_weight
         self.ignore_id = ignore_id
         self.sos = sos
         self.eos = eos
+        self.sentence_end_id = sentence_end_id
         
         
 
-    def punc_forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
+    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs):
         """Compute loss value from buffer sequences.
 
         Args:
@@ -65,7 +80,7 @@
             hidden (torch.Tensor): Target ids. (batch, len)
 
         """
-        x = self.embed(input)
+        x = self.embed(text)
         # mask = self._target_mask(input)
         h, _, _ = self.encoder(x, text_lengths)
         y = self.decoder(h)
@@ -208,7 +223,7 @@
         loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
         return loss, stats, weight
     
-    def generate(self,
+    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
@@ -216,22 +231,33 @@
                  frontend=None,
                  **kwargs,
                  ):
+        assert len(data_in) == 1
+        text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
         vad_indexes = kwargs.get("vad_indexes", None)
-        text = data_in
-        text_lengths = data_lengths
+        # text = data_in[0]
+        # text_lengths = data_lengths[0] if data_lengths is not None else None
         split_size = kwargs.get("split_size", 20)
-        
-        data = {"text": text}
-        result = self.preprocessor(data=data, uid="12938712838719")
-        split_text = self.preprocessor.pop_split_text_data(result)
-        mini_sentences = split_to_mini_sentence(split_text, split_size)
-        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
+
+        jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
+        if jieba_usr_dict and isinstance(jieba_usr_dict, str):
+            import jieba
+            jieba.load_userdict(jieba_usr_dict)
+            jieba_usr_dict = jieba
+            kwargs["jieba_usr_dict"] = "jieba_usr_dict"
+        tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
+        tokens_int = tokenizer.encode(tokens)
+
+        mini_sentences = split_to_mini_sentence(tokens, split_size)
+        mini_sentences_id = split_to_mini_sentence(tokens_int, split_size)
         assert len(mini_sentences) == len(mini_sentences_id)
         cache_sent = []
         cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
         new_mini_sentence = ""
         new_mini_sentence_punc = []
         cache_pop_trigger_limit = 200
+        results = []
+        meta_data = {}
+        punc_array = None
         for mini_sentence_i in range(len(mini_sentences)):
             mini_sentence = mini_sentences[mini_sentence_i]
             mini_sentence_id = mini_sentences_id[mini_sentence_i]
@@ -241,9 +267,9 @@
                 "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
                 "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
             }
-            data = to_device(data, self.device)
+            data = to_device(data, kwargs["device"])
             # y, _ = self.wrapped_model(**data)
-            y, _ = self.punc_forward(text, text_lengths)
+            y, _ = self.punc_forward(**data)
             _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
             punctuations = indices
             if indices.size()[0] != 1:
@@ -264,7 +290,7 @@
                 if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                     # The sentence it too long, cut off at a comma.
                     sentenceEnd = last_comma_index
-                    punctuations[sentenceEnd] = self.period
+                    punctuations[sentenceEnd] = self.sentence_end_id
                 cache_sent = mini_sentence[sentenceEnd + 1:]
                 cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
                 mini_sentence = mini_sentence[0:sentenceEnd + 1]
@@ -303,21 +329,24 @@
             if mini_sentence_i == len(mini_sentences) - 1:
                 if new_mini_sentence[-1] == "锛�" or new_mini_sentence[-1] == "銆�":
                     new_mini_sentence_out = new_mini_sentence[:-1] + "銆�"
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                 elif new_mini_sentence[-1] == ",":
                     new_mini_sentence_out = new_mini_sentence[:-1] + "."
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
-                elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�" and len(new_mini_sentence[-1].encode())==0:
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
+                elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�" and len(new_mini_sentence[-1].encode())!=1:
                     new_mini_sentence_out = new_mini_sentence + "銆�"
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
+                    if len(punctuations): punctuations[-1] = 2
                 elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
                     new_mini_sentence_out = new_mini_sentence + "."
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
-        
-        return new_mini_sentence_out, new_mini_sentence_punc_out
-        
-        # if self.with_vad():
-        #     assert vad_indexes is not None
-        #     return self.punc_forward(text, text_lengths, vad_indexes)
-        # else:
-        #     return self.punc_forward(text, text_lengths)
\ No newline at end of file
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
+                    if len(punctuations): punctuations[-1] = 2
+            # keep a punctuations array for punc segment
+            if punc_array is None:
+                punc_array = punctuations
+            else:
+                punc_array = torch.cat([punc_array, punctuations], dim=0)
+        result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
+        results.append(result_i)
+        return results, meta_data
+

--
Gitblit v1.9.1