From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
funasr/models/transducer/rnnt_decoder.py | 40 ++++++++++++++++++----------------------
1 files changed, 18 insertions(+), 22 deletions(-)
diff --git a/funasr/models/transducer/rnnt_decoder.py b/funasr/models/transducer/rnnt_decoder.py
index 6d35b71..bc1787a 100644
--- a/funasr/models/transducer/rnnt_decoder.py
+++ b/funasr/models/transducer/rnnt_decoder.py
@@ -1,12 +1,17 @@
-"""RNN decoder definition for Transducer models."""
-
-from typing import List, Optional, Tuple
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import torch
+from typing import List, Optional, Tuple
-from funasr.models.transducer.beam_search_transducer import Hypothesis
+from funasr.register import tables
from funasr.models.specaug.specaug import SpecAug
+from funasr.models.transducer.beam_search_transducer import Hypothesis
+
+@tables.register("decoder_classes", "rnnt_decoder")
class RNNTDecoder(torch.nn.Module):
"""RNN decoder module.
@@ -37,7 +42,6 @@
"""Construct a RNNDecoder object."""
super().__init__()
-
if rnn_type not in ("lstm", "gru"):
raise ValueError(f"Not supported: rnn_type={rnn_type}")
@@ -46,9 +50,7 @@
rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
- self.rnn = torch.nn.ModuleList(
- [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
- )
+ self.rnn = torch.nn.ModuleList([rnn_class(embed_size, hidden_size, 1, batch_first=True)])
for _ in range(1, num_layers):
self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]
@@ -72,9 +74,9 @@
time_mask_width_range=3,
num_time_mask=4,
apply_freq_mask=False,
- apply_time_warp=False
+ apply_time_warp=False,
)
-
+
def forward(
self,
labels: torch.Tensor,
@@ -123,13 +125,11 @@
for layer in range(self.dlayers):
if self.dtype == "lstm":
- x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
- layer
- ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
- else:
- x, h_next[layer : layer + 1] = self.rnn[layer](
- x, hx=h_prev[layer : layer + 1]
+ x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[layer](
+ x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1])
)
+ else:
+ x, h_next[layer : layer + 1] = self.rnn[layer](x, hx=h_prev[layer : layer + 1])
x = self.dropout_rnn[layer](x)
@@ -198,9 +198,7 @@
"""
self.device = device
- def init_state(
- self, batch_size: int
- ) -> Tuple[torch.Tensor, Optional[torch.tensor]]:
+ def init_state(self, batch_size: int) -> Tuple[torch.Tensor, Optional[torch.tensor]]:
"""Initialize decoder states.
Args:
@@ -262,7 +260,5 @@
"""
return (
torch.cat([s[0] for s in new_states], dim=1),
- torch.cat([s[1] for s in new_states], dim=1)
- if self.dtype == "lstm"
- else None,
+ torch.cat([s[1] for s in new_states], dim=1) if self.dtype == "lstm" else None,
)
--
Gitblit v1.9.1