From 97d648c255316ec1fff5d82e46749076faabdd2d Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期一, 15 一月 2024 15:41:25 +0800
Subject: [PATCH] code optimize, model update, scripts
---
funasr/models/ct_transformer/model.py | 44 ++++++++++++++++++++++++++++----------------
1 files changed, 28 insertions(+), 16 deletions(-)
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index 5fb3ed4..285f5cc 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -1,22 +1,34 @@
-from typing import Any
-from typing import List
-from typing import Tuple
-from typing import Optional
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import torch
import numpy as np
import torch.nn.functional as F
-
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.train_utils.device_funcs import force_gatherable
-from funasr.train_utils.device_funcs import to_device
-import torch
-import torch.nn as nn
-from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
-from funasr.utils.load_utils import load_audio_text_image_video
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Any, List, Tuple, Optional
from funasr.register import tables
+from funasr.train_utils.device_funcs import to_device
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
@tables.register("model_classes", "CTTransformer")
-class CTTransformer(nn.Module):
+class CTTransformer(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
@@ -45,11 +57,11 @@
punc_weight = [1] * punc_size
- self.embed = nn.Embedding(vocab_size, embed_unit)
+ self.embed = torch.nn.Embedding(vocab_size, embed_unit)
encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(**encoder_conf)
- self.decoder = nn.Linear(att_unit, punc_size)
+ self.decoder = torch.nn.Linear(att_unit, punc_size)
self.encoder = encoder
self.punc_list = punc_list
self.punc_weight = punc_weight
@@ -211,7 +223,7 @@
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
- def generate(self,
+ def inference(self,
data_in,
data_lengths=None,
key: list = None,
--
Gitblit v1.9.1