游雁
2024-02-19 94de39dde2e616a01683c518023d0fab72b4e103
funasr/models/ct_transformer/model.py
@@ -1,22 +1,34 @@
from typing import Any
from typing import List
from typing import Tuple
from typing import Optional
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
import numpy as np
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.train_utils.device_funcs import force_gatherable
from funasr.train_utils.device_funcs import to_device
import torch
import torch.nn as nn
from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
from funasr.utils.load_utils import load_audio_text_image_video
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Any, List, Tuple, Optional
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
try:
    # Feature detection instead of version-string parsing: the original gate
    # was "torch >= 1.6.0", which is exactly when ``torch.cuda.amp`` became
    # importable.  Probing the import directly gives the same outcome without
    # relying on ``distutils.version.LooseVersion`` (distutils is deprecated
    # by PEP 632 and absent from the standard library on Python >= 3.12).
    from torch.cuda.amp import autocast
except ImportError:
    # torch < 1.6.0 has no ``torch.cuda.amp``: provide a no-op stand-in so
    # that ``with autocast(...):`` call sites keep working unchanged.
    @contextmanager
    def autocast(enabled=True):
        """No-op replacement for ``torch.cuda.amp.autocast``.

        Args:
            enabled: Accepted for signature compatibility only; ignored.

        Yields:
            None. The wrapped block runs without any mixed-precision handling.
        """
        yield
@tables.register("model_classes", "CTTransformer")
class CTTransformer(nn.Module):
class CTTransformer(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
@@ -45,11 +57,11 @@
            punc_weight = [1] * punc_size
        
        
        self.embed = nn.Embedding(vocab_size, embed_unit)
        self.embed = torch.nn.Embedding(vocab_size, embed_unit)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(**encoder_conf)
        self.decoder = nn.Linear(att_unit, punc_size)
        self.decoder = torch.nn.Linear(att_unit, punc_size)
        self.encoder = encoder
        self.punc_list = punc_list
        self.punc_weight = punc_weight
@@ -211,7 +223,7 @@
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight
    
    def generate(self,
    def inference(self,
                 data_in,
                 data_lengths=None,
                 key: list = None,
@@ -321,20 +333,20 @@
                elif new_mini_sentence[-1] == ",":
                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==0:
                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())!=1:
                    new_mini_sentence_out = new_mini_sentence + "。"
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                    if len(punctuations): punctuations[-1] = 2
                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
                    new_mini_sentence_out = new_mini_sentence + "."
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                    if len(punctuations): punctuations[-1] = 2
            # keep a punctuations array for punc segment
            if punc_array is None:
                punc_array = punctuations
            else:
                punc_array = torch.cat([punc_array, punctuations], dim=0)
        result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
        results.append(result_i)
        return results, meta_data