From a1b0cd33d50cee3e4612d1e787399e508b453a4a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 21 十二月 2023 14:20:21 +0800
Subject: [PATCH] rename register tables
---
funasr/models/transformer/decoder.py | 12 ++++++------
1 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/funasr/models/transformer/decoder.py b/funasr/models/transformer/decoder.py
index 3e8d224..820de4a 100644
--- a/funasr/models/transformer/decoder.py
+++ b/funasr/models/transformer/decoder.py
@@ -26,7 +26,7 @@
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables
class DecoderLayer(nn.Module):
"""Single decoder layer module.
@@ -352,7 +352,7 @@
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list
-@register_class("decoder_classes", "TransformerDecoder")
+@tables.register("decoder_classes", "TransformerDecoder")
class TransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
@@ -401,7 +401,7 @@
)
-@register_class("decoder_classes", "LightweightConvolutionTransformerDecoder")
+@tables.register("decoder_classes", "LightweightConvolutionTransformerDecoder")
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
@@ -462,7 +462,7 @@
),
)
-@register_class("decoder_classes", "LightweightConvolution2DTransformerDecoder")
+@tables.register("decoder_classes", "LightweightConvolution2DTransformerDecoder")
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
@@ -524,7 +524,7 @@
)
-@register_class("decoder_classes", "DynamicConvolutionTransformerDecoder")
+@tables.register("decoder_classes", "DynamicConvolutionTransformerDecoder")
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
@@ -585,7 +585,7 @@
),
)
-@register_class("decoder_classes", "DynamicConvolution2DTransformerDecoder")
+@tables.register("decoder_classes", "DynamicConvolution2DTransformerDecoder")
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
--
Gitblit v1.9.1