From a4bd736b038a64fb14c3849e4a2bd26deb02517b Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 18 Apr 2023 14:44:59 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/modules/repeat.py |   91 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 91 insertions(+), 0 deletions(-)

diff --git a/funasr/modules/repeat.py b/funasr/modules/repeat.py
index a3d2676..2b2dac8 100644
--- a/funasr/modules/repeat.py
+++ b/funasr/modules/repeat.py
@@ -6,6 +6,8 @@
 
 """Repeat the same layer definition."""
 
+from typing import Dict, List, Optional
+
 import torch
 
 
@@ -31,3 +33,92 @@
 
     """
     return MultiSequential(*[fn(n) for n in range(N)])
+
+
+class MultiBlocks(torch.nn.Module):
+    """MultiBlocks definition.
+    Args:
+        block_list: Individual blocks of the encoder architecture.
+        output_size: Architecture output size.
+        norm_class: Normalization module class applied to the final
+            output. Defaults to torch.nn.LayerNorm.
+    """
+
+    def __init__(
+        self,
+        block_list: List[torch.nn.Module],
+        output_size: int,
+        norm_class: torch.nn.Module = torch.nn.LayerNorm,
+    ) -> None:
+        """Construct a MultiBlocks object."""
+        super().__init__()
+
+        self.blocks = torch.nn.ModuleList(block_list)
+        self.norm_blocks = norm_class(output_size)
+
+        self.num_blocks = len(block_list)
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset encoder streaming cache.
+        Args:
+            left_context: Number of left frames during chunk-by-chunk inference.
+            device: Device to use for cache tensor.
+        """
+        for idx in range(self.num_blocks):
+            self.blocks[idx].reset_streaming_cache(left_context, device)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Forward each block of the encoder architecture.
+        Args:
+            x: MultiBlocks input sequences. (B, T, D_block_1)
+            pos_enc: Positional embedding sequences.
+            mask: Source mask. (B, T)
+            chunk_mask: Chunk mask. (T_2, T_2)
+        Returns:
+            x: Output sequences. (B, T, D_block_N)
+        """
+        for block_index, block in enumerate(self.blocks):
+            x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask)
+
+        x = self.norm_blocks(x)
+
+        return x
+
+    def chunk_forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: int = 0,
+        left_context: int = 0,
+        right_context: int = 0,
+    ) -> torch.Tensor:
+        """Forward each block of the encoder architecture chunk by chunk.
+        Args:
+            x: MultiBlocks input sequences. (B, T, D_block_1)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
+            mask: Source mask. (B, T_2)
+            chunk_size: Number of frames in each processed chunk.
+            left_context / right_context: Number of frames in left/right context.
+        Returns:
+            x: MultiBlocks output sequences. (B, T, D_block_N)
+        """
+        for block_idx, block in enumerate(self.blocks):
+            x, pos_enc = block.chunk_forward(
+                x,
+                pos_enc,
+                mask,
+                chunk_size=chunk_size,
+                left_context=left_context,
+                right_context=right_context,
+            )
+
+        x = self.norm_blocks(x)
+
+        return x

--
Gitblit v1.9.1