kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/transducer/rnnt_decoder.py
@@ -42,7 +42,6 @@
        """Construct a RNNDecoder object."""
        super().__init__()
        if rnn_type not in ("lstm", "gru"):
            raise ValueError(f"Not supported: rnn_type={rnn_type}")
@@ -51,9 +50,7 @@
        rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
        self.rnn = torch.nn.ModuleList(
            [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
        )
        self.rnn = torch.nn.ModuleList([rnn_class(embed_size, hidden_size, 1, batch_first=True)])
        for _ in range(1, num_layers):
            self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]
@@ -77,9 +74,9 @@
                time_mask_width_range=3,
                num_time_mask=4,
                apply_freq_mask=False,
                apply_time_warp=False
                apply_time_warp=False,
            )
    def forward(
        self,
        labels: torch.Tensor,
@@ -128,13 +125,11 @@
        for layer in range(self.dlayers):
            if self.dtype == "lstm":
                x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
                    layer
                ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
            else:
                x, h_next[layer : layer + 1] = self.rnn[layer](
                    x, hx=h_prev[layer : layer + 1]
                x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[layer](
                    x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1])
                )
            else:
                x, h_next[layer : layer + 1] = self.rnn[layer](x, hx=h_prev[layer : layer + 1])
            x = self.dropout_rnn[layer](x)
@@ -203,9 +198,7 @@
        """
        self.device = device
    def init_state(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, Optional[torch.tensor]]:
    def init_state(self, batch_size: int) -> Tuple[torch.Tensor, Optional[torch.tensor]]:
        """Initialize decoder states.
        Args:
@@ -267,7 +260,5 @@
        """
        return (
            torch.cat([s[0] for s in new_states], dim=1),
            torch.cat([s[1] for s in new_states], dim=1)
            if self.dtype == "lstm"
            else None,
            torch.cat([s[1] for s in new_states], dim=1) if self.dtype == "lstm" else None,
        )