From bbbf17e4d97ff155049c424af4e96bfded9089b1 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 16 一月 2024 15:51:57 +0800
Subject: [PATCH] fix win bug
---
funasr/models/ct_transformer/model.py | 49 ++++++++++++++++++++++++++++++-------------------
1 files changed, 30 insertions(+), 19 deletions(-)
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index d843686..8c3f043 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -1,22 +1,34 @@
-from typing import Any
-from typing import List
-from typing import Tuple
-from typing import Optional
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import torch
import numpy as np
import torch.nn.functional as F
-
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.train_utils.device_funcs import force_gatherable
-from funasr.train_utils.device_funcs import to_device
-import torch
-import torch.nn as nn
-from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
-from funasr.utils.load_utils import load_audio_text_image_video
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Any, List, Tuple, Optional
from funasr.register import tables
+from funasr.train_utils.device_funcs import to_device
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
@tables.register("model_classes", "CTTransformer")
-class CTTransformer(nn.Module):
+class CTTransformer(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
@@ -45,11 +57,11 @@
punc_weight = [1] * punc_size
- self.embed = nn.Embedding(vocab_size, embed_unit)
- encoder_class = tables.encoder_classes.get(encoder.lower())
+ self.embed = torch.nn.Embedding(vocab_size, embed_unit)
+ encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(**encoder_conf)
- self.decoder = nn.Linear(att_unit, punc_size)
+ self.decoder = torch.nn.Linear(att_unit, punc_size)
self.encoder = encoder
self.punc_list = punc_list
self.punc_weight = punc_weight
@@ -60,7 +72,7 @@
- def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
+ def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs):
"""Compute loss value from buffer sequences.
Args:
@@ -211,7 +223,7 @@
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
- def generate(self,
+ def inference(self,
data_in,
data_lengths=None,
key: list = None,
@@ -332,7 +344,6 @@
punc_array = punctuations
else:
punc_array = torch.cat([punc_array, punctuations], dim=0)
-
result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
results.append(result_i)
--
Gitblit v1.9.1