From 141a4737f779fcf435a0ece5434b9c73eda7d2a9 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 14 三月 2023 15:54:28 +0800
Subject: [PATCH] update

---
 funasr/tasks/diar.py                   |    2 ++
 funasr/models/frontend/wav_frontend.py |   16 +---------------
 2 files changed, 3 insertions(+), 15 deletions(-)

diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index 4e52b90..6af7074 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -6,6 +6,7 @@
 import torch
 import torchaudio.compliance.kaldi as kaldi
 from funasr.models.frontend.abs_frontend import AbsFrontend
+import funasr.models.frontend.eend_ola_feature as eend_ola_feature
 from torch.nn.utils.rnn import pad_sequence
 from typeguard import check_argument_types
 from typing import Tuple
@@ -213,33 +214,18 @@
     def __init__(
             self,
             fs: int = 16000,
-            window: str = 'hamming',
-            n_mels: int = 80,
             frame_length: int = 25,
             frame_shift: int = 10,
-            filter_length_min: int = -1,
-            filter_length_max: int = -1,
             lfr_m: int = 1,
             lfr_n: int = 1,
-            dither: float = 1.0,
-            snip_edges: bool = True,
-            upsacle_samples: bool = True,
     ):
         assert check_argument_types()
         super().__init__()
         self.fs = fs
-        self.window = window
-        self.n_mels = n_mels
         self.frame_length = frame_length
         self.frame_shift = frame_shift
-        self.filter_length_min = filter_length_min
-        self.filter_length_max = filter_length_max
         self.lfr_m = lfr_m
         self.lfr_n = lfr_n
-        self.cmvn_file = cmvn_file
-        self.dither = dither
-        self.snip_edges = snip_edges
-        self.upsacle_samples = upsacle_samples
 
     def output_size(self) -> int:
         return self.n_mels * self.lfr_m
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 953ab82..ae7ee9b 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -23,6 +23,7 @@
 from funasr.layers.label_aggregation import LabelAggregate
 from funasr.layers.utterance_mvn import UtteranceMVN
 from funasr.models.e2e_diar_sond import DiarSondModel
+from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
 from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.encoder.conformer_encoder import ConformerEncoder
 from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
@@ -103,6 +104,7 @@
     "model",
     classes=dict(
         sond=DiarSondModel,
+        eend_ola=DiarEENDOLAModel,
     ),
     type_check=AbsESPnetModel,
     default="sond",

--
Gitblit v1.9.1