游雁
2024-02-19 94de39dde2e616a01683c518023d0fab72b4e103
funasr/models/ct_transformer_streaming/encoder.py
@@ -1,39 +1,29 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.scama.chunk_utilis import overlap_chunk
import numpy as np
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.sanm.attention import MultiHeadedAttention
from funasr.models.ct_transformer.attention import MultiHeadedAttentionSANMwithMask
from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
from funasr.models.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
from funasr.models.transformer.utils.subsampling import TooShortUttError
from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
from funasr.models.ctc.ctc import CTC
import torch
from typing import List, Optional, Tuple
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.sanm.attention import MultiHeadedAttention
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.transformer.utils.subsampling import TooShortUttError
from funasr.models.transformer.embedding import SinusoidalPositionEncoder
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.ct_transformer_streaming.attention import MultiHeadedAttentionSANMwithMask
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8
class EncoderLayerSANM(nn.Module):
class EncoderLayerSANM(torch.nn.Module):
    def __init__(
        self,
        in_size,
@@ -51,13 +41,13 @@
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
            self.concat_linear = torch.nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate
@@ -156,7 +146,7 @@
@tables.register("encoder_classes", "SANMVadEncoder")
class SANMVadEncoder(nn.Module):
class SANMVadEncoder(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -306,7 +296,7 @@
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)
        self.dropout = torch.nn.Dropout(dropout_rate)
    def output_size(self) -> int:
        return self._output_size