shixian.shi
2024-01-15 97d648c255316ec1fff5d82e46749076faabdd2d
funasr/models/ct_transformer/model.py
@@ -1,22 +1,34 @@
from typing import Any
from typing import List
from typing import Tuple
from typing import Optional
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
import numpy as np
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.train_utils.device_funcs import force_gatherable
from funasr.train_utils.device_funcs import to_device
import torch
import torch.nn as nn
from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
from funasr.utils.load_utils import load_audio_text_image_video
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Any, List, Tuple, Optional
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
# Mixed-precision context manager used by this module.
# ``torch.cuda.amp.autocast`` only exists from torch 1.6.0 onward; on older
# versions fall back to a no-op context manager so ``with autocast(...):``
# blocks still run (in full precision). Detecting the feature directly avoids
# comparing version strings with ``distutils.version.LooseVersion``, which is
# deprecated (PEP 632) and no longer available in Python 3.12+.
try:
    from torch.cuda.amp import autocast
except ImportError:
    # Nothing to do if torch<1.6.0.
    @contextmanager
    def autocast(enabled=True, **kwargs):
        """No-op stand-in for ``torch.cuda.amp.autocast``.

        Args:
            enabled: Accepted for call-site compatibility; ignored.
            **kwargs: Any other keyword arguments a caller may pass to the
                real autocast (e.g. ``dtype``); accepted and ignored so
                such call sites do not raise ``TypeError``.
        """
        yield
@tables.register("model_classes", "CTTransformer")
class CTTransformer(nn.Module):
class CTTransformer(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
@@ -45,11 +57,11 @@
            punc_weight = [1] * punc_size
        
        
        self.embed = nn.Embedding(vocab_size, embed_unit)
        self.embed = torch.nn.Embedding(vocab_size, embed_unit)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(**encoder_conf)
        self.decoder = nn.Linear(att_unit, punc_size)
        self.decoder = torch.nn.Linear(att_unit, punc_size)
        self.encoder = encoder
        self.punc_list = punc_list
        self.punc_weight = punc_weight
@@ -211,7 +223,7 @@
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight
    
    def generate(self,
    def inference(self,
                 data_in,
                 data_lengths=None,
                 key: list = None,