游雁
2023-12-27 f6b611de44c3a535befa96da552d07b0ed1b073c
funasr/models/ct_transformer/model.py
@@ -10,7 +10,7 @@
from funasr.train_utils.device_funcs import to_device
import torch
import torch.nn as nn
from funasr.models.ct_transformer.utils import split_to_mini_sentence
from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
from funasr.register import tables
@@ -34,6 +34,7 @@
        ignore_id: int = -1,
        sos: int = 1,
        eos: int = 2,
        sentence_end_id: int = 3,
        **kwargs,
    ):
        super().__init__()
@@ -54,10 +55,11 @@
        self.ignore_id = ignore_id
        self.sos = sos
        self.eos = eos
        self.sentence_end_id = sentence_end_id
        
        
    def punc_forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute loss value from buffer sequences.
        Args:
@@ -65,7 +67,7 @@
            hidden (torch.Tensor): Target ids. (batch, len)
        """
        x = self.embed(input)
        x = self.embed(text)
        # mask = self._target_mask(input)
        h, _, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
@@ -216,22 +218,26 @@
                 frontend=None,
                 **kwargs,
                 ):
        assert len(data_in) == 1
        vad_indexes = kwargs.get("vad_indexes", None)
        text = data_in
        text_lengths = data_lengths
        text = data_in[0]
        text_lengths = data_lengths[0] if data_lengths is not None else None
        split_size = kwargs.get("split_size", 20)
        
        data = {"text": text}
        result = self.preprocessor(data=data, uid="12938712838719")
        split_text = self.preprocessor.pop_split_text_data(result)
        mini_sentences = split_to_mini_sentence(split_text, split_size)
        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
        tokens = split_words(text)
        tokens_int = tokenizer.encode(tokens)
        mini_sentences = split_to_mini_sentence(tokens, split_size)
        mini_sentences_id = split_to_mini_sentence(tokens_int, split_size)
        assert len(mini_sentences) == len(mini_sentences_id)
        cache_sent = []
        cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
        new_mini_sentence = ""
        new_mini_sentence_punc = []
        cache_pop_trigger_limit = 200
        results = []
        meta_data = {}
        for mini_sentence_i in range(len(mini_sentences)):
            mini_sentence = mini_sentences[mini_sentence_i]
            mini_sentence_id = mini_sentences_id[mini_sentence_i]
@@ -241,9 +247,9 @@
                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
            }
            data = to_device(data, self.device)
            data = to_device(data, kwargs["device"])
            # y, _ = self.wrapped_model(**data)
            y, _ = self.punc_forward(text, text_lengths)
            y, _ = self.punc_forward(**data)
            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
            punctuations = indices
            if indices.size()[0] != 1:
@@ -264,7 +270,7 @@
                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                    # The sentence is too long, cut off at a comma.
                    sentenceEnd = last_comma_index
                    punctuations[sentenceEnd] = self.period
                    punctuations[sentenceEnd] = self.sentence_end_id
                cache_sent = mini_sentence[sentenceEnd + 1:]
                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
                mini_sentence = mini_sentence[0:sentenceEnd + 1]
@@ -303,21 +309,19 @@
            if mini_sentence_i == len(mini_sentences) - 1:
                if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
                    new_mini_sentence_out = new_mini_sentence[:-1] + "。"
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                elif new_mini_sentence[-1] == ",":
                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==0:
                    new_mini_sentence_out = new_mini_sentence + "。"
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
                    new_mini_sentence_out = new_mini_sentence + "."
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
        return new_mini_sentence_out, new_mini_sentence_punc_out
        # if self.with_vad():
        #     assert vad_indexes is not None
        #     return self.punc_forward(text, text_lengths, vad_indexes)
        # else:
        #     return self.punc_forward(text, text_lengths)
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
        result_i = {"key": key[0], "text": new_mini_sentence_out}
        results.append(result_i)
        return results, meta_data