From 811ebea5b0d4b112a494b3ee9a63c4e35098dbf5 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 28 二月 2024 19:12:02 +0800
Subject: [PATCH] init param
---
funasr/models/ct_transformer_streaming/model.py | 38 ++++++++++++++++++++++----------------
1 files changed, 22 insertions(+), 16 deletions(-)
diff --git a/funasr/models/ct_transformer_streaming/model.py b/funasr/models/ct_transformer_streaming/model.py
index 5254d15..217767a 100644
--- a/funasr/models/ct_transformer_streaming/model.py
+++ b/funasr/models/ct_transformer_streaming/model.py
@@ -1,20 +1,28 @@
-from typing import Any
-from typing import List
-from typing import Tuple
-from typing import Optional
-import numpy as np
-import torch.nn.functional as F
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.train_utils.device_funcs import force_gatherable
-from funasr.train_utils.device_funcs import to_device
import torch
-import torch.nn as nn
-from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
-from funasr.utils.load_utils import load_audio_text_image_video
-from funasr.models.ct_transformer.model import CTTransformer
+import numpy as np
+from contextlib import contextmanager
+from distutils.version import LooseVersion
from funasr.register import tables
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.ct_transformer.model import CTTransformer
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
@tables.register("model_classes", "CTTransformerStreaming")
class CTTransformerStreaming(CTTransformer):
@@ -47,10 +55,8 @@
def with_vad(self):
return True
-
-
- def generate(self,
+ def inference(self,
data_in,
data_lengths=None,
key: list = None,
--
Gitblit v1.9.1