From 94de39dde2e616a01683c518023d0fab72b4e103 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 19 二月 2024 22:21:50 +0800
Subject: [PATCH] aishell example
---
funasr/models/ct_transformer/model.py | 186 ++++++++++++++++++++++++++++++++++++++++-----
1 files changed, 163 insertions(+), 23 deletions(-)
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index 31b2af2..1e53aa3 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -1,14 +1,34 @@
-from typing import Any
-from typing import List
-from typing import Tuple
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import torch
-import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Any, List, Tuple, Optional
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables
+from funasr.train_utils.device_funcs import to_device
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
-@register_class("model_classes", "CTTransformer")
-class CTTransformer(nn.Module):
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
+
+@tables.register("model_classes", "CTTransformer")
+class CTTransformer(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
@@ -17,7 +37,7 @@
def __init__(
self,
encoder: str = None,
- encoder_conf: str = None,
+ encoder_conf: dict = None,
vocab_size: int = -1,
punc_list: list = None,
punc_weight: list = None,
@@ -27,6 +47,7 @@
ignore_id: int = -1,
sos: int = 1,
eos: int = 2,
+ sentence_end_id: int = 3,
**kwargs,
):
super().__init__()
@@ -36,21 +57,22 @@
punc_weight = [1] * punc_size
- self.embed = nn.Embedding(vocab_size, embed_unit)
- encoder_class = registry_tables.encoder_classes.get(encoder.lower())
+ self.embed = torch.nn.Embedding(vocab_size, embed_unit)
+ encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(**encoder_conf)
- self.decoder = nn.Linear(att_unit, punc_size)
+ self.decoder = torch.nn.Linear(att_unit, punc_size)
self.encoder = encoder
self.punc_list = punc_list
self.punc_weight = punc_weight
self.ignore_id = ignore_id
self.sos = sos
self.eos = eos
+ self.sentence_end_id = sentence_end_id
- def punc_forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
+ def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs):
"""Compute loss value from buffer sequences.
Args:
@@ -58,7 +80,7 @@
hidden (torch.Tensor): Target ids. (batch, len)
"""
- x = self.embed(input)
+ x = self.embed(text)
# mask = self._target_mask(input)
h, _, _ = self.encoder(x, text_lengths)
y = self.decoder(h)
@@ -191,7 +213,7 @@
punc_lengths: torch.Tensor,
vad_indexes: Optional[torch.Tensor] = None,
vad_indexes_lengths: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ ):
nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
ntokens = y_lengths.sum()
loss = nll.sum() / ntokens
@@ -201,12 +223,130 @@
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
- def generate(self,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
- if self.with_vad():
- assert vad_indexes is not None
- return self.punc_forward(text, text_lengths, vad_indexes)
- else:
- return self.punc_forward(text, text_lengths)
\ No newline at end of file
+ def inference(self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+ assert len(data_in) == 1
+ text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
+ vad_indexes = kwargs.get("vad_indexes", None)
+ # text = data_in[0]
+ # text_lengths = data_lengths[0] if data_lengths is not None else None
+ split_size = kwargs.get("split_size", 20)
+
+ jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
+ if jieba_usr_dict and isinstance(jieba_usr_dict, str):
+ import jieba
+ jieba.load_userdict(jieba_usr_dict)
+ jieba_usr_dict = jieba
+ kwargs["jieba_usr_dict"] = "jieba_usr_dict"
+ tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
+ tokens_int = tokenizer.encode(tokens)
+
+ mini_sentences = split_to_mini_sentence(tokens, split_size)
+ mini_sentences_id = split_to_mini_sentence(tokens_int, split_size)
+ assert len(mini_sentences) == len(mini_sentences_id)
+ cache_sent = []
+ cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
+ new_mini_sentence = ""
+ new_mini_sentence_punc = []
+ cache_pop_trigger_limit = 200
+ results = []
+ meta_data = {}
+ punc_array = None
+ for mini_sentence_i in range(len(mini_sentences)):
+ mini_sentence = mini_sentences[mini_sentence_i]
+ mini_sentence_id = mini_sentences_id[mini_sentence_i]
+ mini_sentence = cache_sent + mini_sentence
+ mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+ data = {
+ "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
+ "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
+ }
+ data = to_device(data, kwargs["device"])
+ # y, _ = self.wrapped_model(**data)
+ y, _ = self.punc_forward(**data)
+ _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
+ punctuations = indices
+ if indices.size()[0] != 1:
+ punctuations = torch.squeeze(indices)
+ assert punctuations.size()[0] == len(mini_sentence)
+
+ # Search for the last Period/QuestionMark as cache
+ if mini_sentence_i < len(mini_sentences) - 1:
+ sentenceEnd = -1
+ last_comma_index = -1
+ for i in range(len(punctuations) - 2, 1, -1):
+ if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
+ sentenceEnd = i
+ break
+ if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
+ last_comma_index = i
+
+ if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+ # The sentence it too long, cut off at a comma.
+ sentenceEnd = last_comma_index
+ punctuations[sentenceEnd] = self.sentence_end_id
+ cache_sent = mini_sentence[sentenceEnd + 1:]
+ cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
+ mini_sentence = mini_sentence[0:sentenceEnd + 1]
+ punctuations = punctuations[0:sentenceEnd + 1]
+
+ # if len(punctuations) == 0:
+ # continue
+
+ punctuations_np = punctuations.cpu().numpy()
+ new_mini_sentence_punc += [int(x) for x in punctuations_np]
+ words_with_punc = []
+ for i in range(len(mini_sentence)):
+ if (i==0 or self.punc_list[punctuations[i-1]] == "銆�" or self.punc_list[punctuations[i-1]] == "锛�") and len(mini_sentence[i][0].encode()) == 1:
+ mini_sentence[i] = mini_sentence[i].capitalize()
+ if i == 0:
+ if len(mini_sentence[i][0].encode()) == 1:
+ mini_sentence[i] = " " + mini_sentence[i]
+ if i > 0:
+ if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
+ mini_sentence[i] = " " + mini_sentence[i]
+ words_with_punc.append(mini_sentence[i])
+ if self.punc_list[punctuations[i]] != "_":
+ punc_res = self.punc_list[punctuations[i]]
+ if len(mini_sentence[i][0].encode()) == 1:
+ if punc_res == "锛�":
+ punc_res = ","
+ elif punc_res == "銆�":
+ punc_res = "."
+ elif punc_res == "锛�":
+ punc_res = "?"
+ words_with_punc.append(punc_res)
+ new_mini_sentence += "".join(words_with_punc)
+ # Add Period for the end of the sentence
+ new_mini_sentence_out = new_mini_sentence
+ new_mini_sentence_punc_out = new_mini_sentence_punc
+ if mini_sentence_i == len(mini_sentences) - 1:
+ if new_mini_sentence[-1] == "锛�" or new_mini_sentence[-1] == "銆�":
+ new_mini_sentence_out = new_mini_sentence[:-1] + "銆�"
+ new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
+ elif new_mini_sentence[-1] == ",":
+ new_mini_sentence_out = new_mini_sentence[:-1] + "."
+ new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
+ elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�" and len(new_mini_sentence[-1].encode())!=1:
+ new_mini_sentence_out = new_mini_sentence + "銆�"
+ new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
+ if len(punctuations): punctuations[-1] = 2
+ elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
+ new_mini_sentence_out = new_mini_sentence + "."
+ new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
+ if len(punctuations): punctuations[-1] = 2
+ # keep a punctuations array for punc segment
+ if punc_array is None:
+ punc_array = punctuations
+ else:
+ punc_array = torch.cat([punc_array, punctuations], dim=0)
+ result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
+ results.append(result_i)
+ return results, meta_data
+
--
Gitblit v1.9.1