From 55c09aeaa25b4bb88a50e09ba68fa6ff00a6d676 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: Mon, 15 Jan 2024 20:10:39 +0800
Subject: [PATCH] update readme, fix seaco bug
---
funasr/models/bicif_paraformer/cif_predictor.py | 38 ++++++++++++++++++--------------------
 1 file changed, 18 insertions(+), 20 deletions(-)
diff --git a/funasr/models/bicif_paraformer/cif_predictor.py b/funasr/models/bicif_paraformer/cif_predictor.py
index 5a1488e..e7b3ba9 100644
--- a/funasr/models/bicif_paraformer/cif_predictor.py
+++ b/funasr/models/bicif_paraformer/cif_predictor.py
@@ -1,17 +1,15 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import torch
-from torch import nn
-from torch import Tensor
-import logging
-import numpy as np
-from funasr.train_utils.device_funcs import to_device
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.scama.utils import sequence_mask
-from typing import Optional, Tuple
from funasr.register import tables
+from funasr.models.transformer.utils.nets_utils import make_pad_mask

-class mae_loss(nn.Module):
+class mae_loss(torch.nn.Module):
def __init__(self, normalize_length=False):
super(mae_loss, self).__init__()
@@ -95,7 +93,7 @@
return fires

@tables.register("predictor_classes", "CifPredictorV3")
-class CifPredictorV3(nn.Module):
+class CifPredictorV3(torch.nn.Module):
def __init__(self,
idim,
l_order,
@@ -116,9 +114,9 @@
):
super(CifPredictorV3, self).__init__()
- self.pad = nn.ConstantPad1d((l_order, r_order), 0)
- self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
- self.cif_output = nn.Linear(idim, 1)
+ self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
+ self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
+ self.cif_output = torch.nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
@@ -131,14 +129,14 @@
self.upsample_type = upsample_type
self.use_cif1_cnn = use_cif1_cnn
if self.upsample_type == 'cnn':
- self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
- self.cif_output2 = nn.Linear(idim, 1)
+ self.upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
+ self.cif_output2 = torch.nn.Linear(idim, 1)
elif self.upsample_type == 'cnn_blstm':
- self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
- self.blstm = nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
- self.cif_output2 = nn.Linear(idim*2, 1)
+ self.upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
+ self.blstm = torch.nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
+ self.cif_output2 = torch.nn.Linear(idim*2, 1)
elif self.upsample_type == 'cnn_attn':
- self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
+ self.upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
from funasr.models.transformer.encoder import EncoderLayer as TransformerEncoderLayer
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
@@ -157,7 +155,7 @@
True, #normalize_before,
False, #concat_after,
)
- self.cif_output2 = nn.Linear(idim, 1)
+ self.cif_output2 = torch.nn.Linear(idim, 1)
self.smooth_factor2 = smooth_factor2
self.noise_threshold2 = noise_threshold2
--
Gitblit v1.9.1
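
Sanity-check sketch (not part of the patch): the rewrite above is purely mechanical, replacing `nn.*` aliases with fully qualified `torch.nn.*` references, so the layers should wire together exactly as before. The snippet below builds the layers touched by the patch and verifies their output shapes; the dimensions (idim=512, l_order=r_order=1, upsample_times=3, batch of 2, 100 frames) are made-up illustration values, not values from CifPredictorV3's real config.

import torch

# Hypothetical sizes; the real values come from the model config.
idim, l_order, r_order = 512, 1, 1

# The CIF front-end layers rewritten from `nn.*` to `torch.nn.*`.
pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
cif_output = torch.nn.Linear(idim, 1)

x = torch.randn(2, idim, 100)            # (batch, idim, frames)
h = cif_conv1d(pad(x))                   # padding preserves frame count: (2, idim, 100)
logits = cif_output(h.transpose(1, 2))   # per-frame weight logits: (2, 100, 1)
assert logits.shape == (2, 100, 1)

# The 'cnn_blstm' upsample branch touched by the patch, same check.
upsample_times = 3
upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, upsample_times, upsample_times)
blstm = torch.nn.LSTM(idim, idim, 1, bias=True, batch_first=True,
                      dropout=0.0, bidirectional=True)
cif_output2 = torch.nn.Linear(idim * 2, 1)

u = upsample_cnn(x)                      # stride == kernel: (2, idim, 100 * upsample_times)
o, _ = blstm(u.transpose(1, 2))          # bidirectional doubles features: (2, 300, idim * 2)
logits2 = cif_output2(o)                 # (2, 300, 1)
assert logits2.shape == (2, 100 * upsample_times, 1)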