From 141a4737f779fcf435a0ece5434b9c73eda7d2a9 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 14 三月 2023 15:54:28 +0800
Subject: [PATCH] update
---
funasr/tasks/diar.py | 2 ++
funasr/models/frontend/wav_frontend.py | 16 +---------------
2 files changed, 3 insertions(+), 15 deletions(-)
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index 4e52b90..6af7074 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -6,6 +6,7 @@
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr.models.frontend.abs_frontend import AbsFrontend
+import funasr.models.frontend.eend_ola_feature as eend_ola_feature
from torch.nn.utils.rnn import pad_sequence
from typeguard import check_argument_types
from typing import Tuple
@@ -213,33 +214,18 @@
def __init__(
self,
fs: int = 16000,
- window: str = 'hamming',
- n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
- filter_length_min: int = -1,
- filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
- dither: float = 1.0,
- snip_edges: bool = True,
- upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
self.fs = fs
- self.window = window
- self.n_mels = n_mels
self.frame_length = frame_length
self.frame_shift = frame_shift
- self.filter_length_min = filter_length_min
- self.filter_length_max = filter_length_max
self.lfr_m = lfr_m
self.lfr_n = lfr_n
- self.cmvn_file = cmvn_file
- self.dither = dither
- self.snip_edges = snip_edges
- self.upsacle_samples = upsacle_samples
def output_size(self) -> int:
return self.n_mels * self.lfr_m
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 953ab82..ae7ee9b 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -23,6 +23,7 @@
from funasr.layers.label_aggregation import LabelAggregate
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.e2e_diar_sond import DiarSondModel
+from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
@@ -103,6 +104,7 @@
"model",
classes=dict(
sond=DiarSondModel,
+ eend_ola=DiarEENDOLAModel,
),
type_check=AbsESPnetModel,
default="sond",
--
Gitblit v1.9.1