游雁
2024-03-27 9b4e9cc8a0311e5243d69b73ed073e7ea441982e
funasr/models/ct_transformer/model.py
@@ -1,21 +1,37 @@
from typing import Any
from typing import List
from typing import Tuple
from typing import Optional
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import copy
import torch
import numpy as np
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.train_utils.device_funcs import force_gatherable
from funasr.train_utils.device_funcs import to_device
import torch
import torch.nn as nn
from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Any, List, Tuple, Optional
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
# jieba is an optional dependency: the visible code only uses it when a
# ``jieba_usr_dict`` is configured. Catch only ImportError so that unrelated
# failures (and KeyboardInterrupt/SystemExit) are not silently swallowed, and
# bind the name to None so availability can be tested with ``jieba is None``.
try:
    import jieba
except ImportError:
    jieba = None
if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
    # torch releases older than 1.6.0 have no torch.cuda.amp, so define a
    # stand-in context manager under the same name. It accepts ``enabled``
    # for call-site compatibility and simply does nothing around the body.
    @contextmanager
    def autocast(enabled=True):
        yield
else:
    from torch.cuda.amp import autocast
@tables.register("model_classes", "CTTransformer")
class CTTransformer(nn.Module):
class CTTransformer(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
@@ -44,11 +60,11 @@
            punc_weight = [1] * punc_size
        
        
        self.embed = nn.Embedding(vocab_size, embed_unit)
        encoder_class = tables.encoder_classes.get(encoder.lower())
        self.embed = torch.nn.Embedding(vocab_size, embed_unit)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(**encoder_conf)
        self.decoder = nn.Linear(att_unit, punc_size)
        self.decoder = torch.nn.Linear(att_unit, punc_size)
        self.encoder = encoder
        self.punc_list = punc_list
        self.punc_weight = punc_weight
@@ -56,10 +72,14 @@
        self.sos = sos
        self.eos = eos
        self.sentence_end_id = sentence_end_id
        self.jieba_usr_dict = None
        if kwargs.get("jieba_usr_dict", None) is not None:
            jieba.load_userdict(kwargs["jieba_usr_dict"])
            self.jieba_usr_dict = jieba
        
        
    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs):
        """Compute loss value from buffer sequences.
        Args:
@@ -210,7 +230,7 @@
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight
    
    def generate(self,
    def inference(self,
                 data_in,
                 data_lengths=None,
                 key: list = None,
@@ -219,13 +239,13 @@
                 **kwargs,
                 ):
        assert len(data_in) == 1
        text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
        vad_indexes = kwargs.get("vad_indexes", None)
        text = data_in[0]
        text_lengths = data_lengths[0] if data_lengths is not None else None
        # text = data_in[0]
        # text_lengths = data_lengths[0] if data_lengths is not None else None
        split_size = kwargs.get("split_size", 20)
        
        tokens = split_words(text)
        tokens = split_words(text, jieba_usr_dict=self.jieba_usr_dict)
        tokens_int = tokenizer.encode(tokens)
        mini_sentences = split_to_mini_sentence(tokens, split_size)
@@ -238,6 +258,7 @@
        cache_pop_trigger_limit = 200
        results = []
        meta_data = {}
        punc_array = None
        for mini_sentence_i in range(len(mini_sentences)):
            mini_sentence = mini_sentences[mini_sentence_i]
            mini_sentence_id = mini_sentences_id[mini_sentence_i]
@@ -313,15 +334,41 @@
                elif new_mini_sentence[-1] == ",":
                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==0:
                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())!=1:
                    new_mini_sentence_out = new_mini_sentence + "。"
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                    if len(punctuations): punctuations[-1] = 2
                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
                    new_mini_sentence_out = new_mini_sentence + "."
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
        result_i = {"key": key[0], "text": new_mini_sentence_out}
                    if len(punctuations): punctuations[-1] = 2
            # keep a punctuations array for punc segment
            if punc_array is None:
                punc_array = punctuations
            else:
                punc_array = torch.cat([punc_array, punctuations], dim=0)
        # post processing when using word level punc model
        if self.jieba_usr_dict is not None:
            len_tokens = len(tokens)
            new_punc_array = copy.copy(punc_array).tolist()
            # for i, (token, punc_id) in enumerate(zip(tokens[::-1], punc_array.tolist()[::-1])):
            for i, token in enumerate(tokens[::-1]):
                if '\u0e00' <= token[0] <= '\u9fa5': # ignore en words
                    if len(token) > 1:
                        num_append = len(token) - 1
                        ind_append = len_tokens - i - 1
                        for _ in range(num_append):
                            new_punc_array.insert(ind_append, 1)
            punc_array = torch.tensor(new_punc_array)
        result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
        results.append(result_i)
        return results, meta_data
    def export(self, **kwargs):
        from .export_meta import export_rebuild_model
        models = export_rebuild_model(model=self, **kwargs)
        return models