shixian.shi
2024-01-15 97d648c255316ec1fff5d82e46749076faabdd2d
funasr/models/bicif_paraformer/cif_predictor.py
@@ -1,17 +1,15 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
from torch import nn
from torch import Tensor
import logging
import numpy as np
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.scama.utils import sequence_mask
from typing import Optional, Tuple
from funasr.register import tables
from funasr.models.transformer.utils.nets_utils import make_pad_mask
class mae_loss(nn.Module):
class mae_loss(torch.nn.Module):
    def __init__(self, normalize_length=False):
        super(mae_loss, self).__init__()
@@ -95,7 +93,7 @@
    return fires
@tables.register("predictor_classes", "CifPredictorV3")
class CifPredictorV3(nn.Module):
class CifPredictorV3(torch.nn.Module):
    def __init__(self,
                 idim,
                 l_order,
@@ -116,9 +114,9 @@
                 ):
        super(CifPredictorV3, self).__init__()
        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
        self.cif_output = nn.Linear(idim, 1)
        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
        self.cif_output = torch.nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
@@ -131,14 +129,14 @@
        self.upsample_type = upsample_type
        self.use_cif1_cnn = use_cif1_cnn
        if self.upsample_type == 'cnn':
            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
            self.cif_output2 = nn.Linear(idim, 1)
            self.upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
            self.cif_output2 = torch.nn.Linear(idim, 1)
        elif self.upsample_type == 'cnn_blstm':
            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
            self.blstm = nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
            self.cif_output2 = nn.Linear(idim*2, 1)
            self.upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
            self.blstm = torch.nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
            self.cif_output2 = torch.nn.Linear(idim*2, 1)
        elif self.upsample_type == 'cnn_attn':
            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
            self.upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
            from funasr.models.transformer.encoder import EncoderLayer as TransformerEncoderLayer
            from funasr.models.transformer.attention import MultiHeadedAttention
            from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
@@ -157,7 +155,7 @@
                True, #normalize_before,
                False, #concat_after,
            )
            self.cif_output2 = nn.Linear(idim, 1)
            self.cif_output2 = torch.nn.Linear(idim, 1)
        self.smooth_factor2 = smooth_factor2
        self.noise_threshold2 = noise_threshold2