From 811ebea5b0d4b112a494b3ee9a63c4e35098dbf5 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 28 二月 2024 19:12:02 +0800
Subject: [PATCH] init param

---
 funasr/models/ct_transformer_streaming/model.py |   38 ++++++++++++++++++++++----------------
 1 files changed, 22 insertions(+), 16 deletions(-)

diff --git a/funasr/models/ct_transformer_streaming/model.py b/funasr/models/ct_transformer_streaming/model.py
index 5254d15..217767a 100644
--- a/funasr/models/ct_transformer_streaming/model.py
+++ b/funasr/models/ct_transformer_streaming/model.py
@@ -1,20 +1,28 @@
-from typing import Any
-from typing import List
-from typing import Tuple
-from typing import Optional
-import numpy as np
-import torch.nn.functional as F
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.train_utils.device_funcs import force_gatherable
-from funasr.train_utils.device_funcs import to_device
 import torch
-import torch.nn as nn
-from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
-from funasr.utils.load_utils import load_audio_text_image_video
-from funasr.models.ct_transformer.model import CTTransformer
+import numpy as np
+from contextlib import contextmanager
+from distutils.version import LooseVersion
 
 from funasr.register import tables
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.ct_transformer.model import CTTransformer
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
 
 @tables.register("model_classes", "CTTransformerStreaming")
 class CTTransformerStreaming(CTTransformer):
@@ -47,10 +55,8 @@
 
     def with_vad(self):
         return True
-
-
     
-    def generate(self,
+    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,

--
Gitblit v1.9.1