From 55c09aeaa25b4bb88a50e09ba68fa6ff00a6d676 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期一, 15 一月 2024 20:10:39 +0800
Subject: [PATCH] update readme, fix seaco bug

---
 funasr/models/bicif_paraformer/cif_predictor.py |   38 ++++++++++++++++++--------------------
 1 files changed, 18 insertions(+), 20 deletions(-)

diff --git a/funasr/models/bicif_paraformer/cif_predictor.py b/funasr/models/bicif_paraformer/cif_predictor.py
index 5a1488e..e7b3ba9 100644
--- a/funasr/models/bicif_paraformer/cif_predictor.py
+++ b/funasr/models/bicif_paraformer/cif_predictor.py
@@ -1,17 +1,15 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 import torch
-from torch import nn
-from torch import Tensor
-import logging
-import numpy as np
-from funasr.train_utils.device_funcs import to_device
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.scama.utils import sequence_mask
-from typing import Optional, Tuple
 
 from funasr.register import tables
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
 
 
-class mae_loss(nn.Module):
+class mae_loss(torch.nn.Module):
 
     def __init__(self, normalize_length=False):
         super(mae_loss, self).__init__()
@@ -95,7 +93,7 @@
     return fires
 
 @tables.register("predictor_classes", "CifPredictorV3")
-class CifPredictorV3(nn.Module):
+class CifPredictorV3(torch.nn.Module):
     def __init__(self,
                  idim,
                  l_order,
@@ -116,9 +114,9 @@
                  ):
         super(CifPredictorV3, self).__init__()
 
-        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
-        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
-        self.cif_output = nn.Linear(idim, 1)
+        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
+        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
+        self.cif_output = torch.nn.Linear(idim, 1)
         self.dropout = torch.nn.Dropout(p=dropout)
         self.threshold = threshold
         self.smooth_factor = smooth_factor
@@ -131,14 +129,14 @@
         self.upsample_type = upsample_type
         self.use_cif1_cnn = use_cif1_cnn
         if self.upsample_type == 'cnn':
-            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
-            self.cif_output2 = nn.Linear(idim, 1)
+            self.upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
+            self.cif_output2 = torch.nn.Linear(idim, 1)
         elif self.upsample_type == 'cnn_blstm':
-            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
-            self.blstm = nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
-            self.cif_output2 = nn.Linear(idim*2, 1)
+            self.upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
+            self.blstm = torch.nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
+            self.cif_output2 = torch.nn.Linear(idim*2, 1)
         elif self.upsample_type == 'cnn_attn':
-            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
+            self.upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
             from funasr.models.transformer.encoder import EncoderLayer as TransformerEncoderLayer
             from funasr.models.transformer.attention import MultiHeadedAttention
             from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
@@ -157,7 +155,7 @@
                 True, #normalize_before,
                 False, #concat_after,
             )
-            self.cif_output2 = nn.Linear(idim, 1)
+            self.cif_output2 = torch.nn.Linear(idim, 1)
         self.smooth_factor2 = smooth_factor2
         self.noise_threshold2 = noise_threshold2
 

--
Gitblit v1.9.1