游雁
2024-02-19 94de39dde2e616a01683c518023d0fab72b4e103
funasr/models/ct_transformer_streaming/model.py
@@ -1,20 +1,28 @@
from typing import Any
from typing import List
from typing import Tuple
from typing import Optional
import numpy as np
import torch.nn.functional as F
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.train_utils.device_funcs import force_gatherable
from funasr.train_utils.device_funcs import to_device
import torch
import torch.nn as nn
from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.models.ct_transformer.model import CTTransformer
import numpy as np
from contextlib import contextmanager
from distutils.version import LooseVersion
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.models.ct_transformer.model import CTTransformer
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
@tables.register("model_classes", "CTTransformerStreaming")
class CTTransformerStreaming(CTTransformer):
@@ -47,10 +55,8 @@
    def with_vad(self):
        return True
    
    def generate(self,
    def inference(self,
                 data_in,
                 data_lengths=None,
                 key: list = None,