From 7d1efe158eda74dc847c397db906f6cb77ac0f84 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 16:49:56 +0800
Subject: [PATCH] rnnt reorg

---
 funasr/modules/nets_utils.py |  195 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 194 insertions(+), 1 deletions(-)

diff --git a/funasr/modules/nets_utils.py b/funasr/modules/nets_utils.py
index 6d77d69..5d4fe1c 100644
--- a/funasr/modules/nets_utils.py
+++ b/funasr/modules/nets_utils.py
@@ -3,7 +3,7 @@
 """Network related utility tools."""
 
 import logging
-from typing import Dict
+from typing import Dict, List, Tuple
 
 import numpy as np
 import torch
@@ -506,3 +506,196 @@
     }
 
     return activation_funcs[act]()
+
+class TooShortUttError(Exception):
+    """Raised when the utt is too short for subsampling.
+
+    Args:
+        message: Error message to display.
+        actual_size: The size that cannot pass the subsampling.
+        limit: The size limit for subsampling.
+
+    """
+
+    def __init__(self, message: str, actual_size: int, limit: int) -> None:
+        """Construct a TooShortUttError module."""
+        super().__init__(message)
+
+        self.actual_size = actual_size
+        self.limit = limit
+
+
+def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
+    """Check if the input is too short for subsampling.
+
+    Args:
+        sub_factor: Subsampling factor for Conv2DSubsampling.
+        size: Input size.
+
+    Returns:
+        : Whether an error should be sent.
+        : Size limit for specified subsampling factor.
+
+    """
+    if sub_factor == 2 and size < 3:
+        return True, 7
+    elif sub_factor == 4 and size < 7:
+        return True, 7
+    elif sub_factor == 6 and size < 11:
+        return True, 11
+
+    return False, -1
+
+
+def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
+    """Get conv2D second layer parameters for given subsampling factor.
+
+    Args:
+        sub_factor: Subsampling factor (1/X).
+        input_size: Input size.
+
+    Returns:
+        : Kernel size for second convolution.
+        : Stride for second convolution.
+        : Conv2DSubsampling output size.
+
+    """
+    if sub_factor == 2:
+        return 3, 1, (((input_size - 1) // 2 - 2))
+    elif sub_factor == 4:
+        return 3, 2, (((input_size - 1) // 2 - 1) // 2)
+    elif sub_factor == 6:
+        return 5, 3, (((input_size - 1) // 2 - 2) // 3)
+    else:
+        raise ValueError(
+            "subsampling_factor parameter should be set to either 2, 4 or 6."
+        )
+
+
+def make_chunk_mask(
+    size: int,
+    chunk_size: int,
+    left_chunk_size: int = 0,
+    device: torch.device = None,
+) -> torch.Tensor:
+    """Create chunk mask for the subsequent steps (size, size).
+
+    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
+
+    Args:
+        size: Size of the source mask.
+        chunk_size: Number of frames in chunk.
+        left_chunk_size: Size of the left context in chunks (0 means full context).
+        device: Device for the mask tensor.
+
+    Returns:
+        mask: Chunk mask. (size, size)
+
+    """
+    mask = torch.zeros(size, size, device=device, dtype=torch.bool)
+
+    for i in range(size):
+        if left_chunk_size <= 0:
+            start = 0
+        else:
+            start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
+
+        end = min((i // chunk_size + 1) * chunk_size, size)
+        mask[i, start:end] = True
+
+    return ~mask
+
+def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
+    """Create source mask for given lengths.
+
+    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
+
+    Args:
+        lengths: Sequence lengths. (B,)
+
+    Returns:
+        : Mask for the sequence lengths. (B, max_len)
+
+    """
+    max_len = lengths.max()
+    batch_size = lengths.size(0)
+
+    expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)
+
+    return expanded_lengths >= lengths.unsqueeze(1)
+
+
+def get_transducer_task_io(
+    labels: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    ignore_id: int = -1,
+    blank_id: int = 0,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Get Transducer loss I/O.
+
+    Args:
+        labels: Label ID sequences. (B, L)
+        encoder_out_lens: Encoder output lengths. (B,)
+        ignore_id: Padding symbol ID.
+        blank_id: Blank symbol ID.
+
+    Returns:
+        decoder_in: Decoder inputs. (B, U)
+        target: Target label ID sequences. (B, U)
+        t_len: Time lengths. (B,)
+        u_len: Label lengths. (B,)
+
+    """
+
+    def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
+        """Create padded batch of labels from a list of labels sequences.
+
+        Args:
+            labels: Labels sequences. [B x (?)]
+            padding_value: Padding value.
+
+        Returns:
+            labels: Batch of padded labels sequences. (B,)
+
+        """
+        batch_size = len(labels)
+
+        padded = (
+            labels[0]
+            .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
+            .fill_(padding_value)
+        )
+
+        for i in range(batch_size):
+            padded[i, : labels[i].size(0)] = labels[i]
+
+        return padded
+
+    device = labels.device
+
+    labels_unpad = [y[y != ignore_id] for y in labels]
+    blank = labels[0].new([blank_id])
+
+    decoder_in = pad_list(
+        [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
+    ).to(device)
+
+    target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)
+
+    encoder_out_lens = list(map(int, encoder_out_lens))
+    t_len = torch.IntTensor(encoder_out_lens).to(device)
+
+    u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)
+
+    return decoder_in, target, t_len, u_len
+
+def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
+    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
+    if t.size(dim) == pad_len:
+        return t
+    else:
+        pad_size = list(t.shape)
+        pad_size[dim] = pad_len - t.size(dim)
+        return torch.cat(
+            [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
+        )

--
Gitblit v1.9.1