From 0e622e694e6cb4459955f1e5942a7c53349ce640 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 19 Dec 2023 21:58:14 +0800
Subject: [PATCH] funasr2
---
funasr/frontends/wav_frontend.py | 18 ++++++++++++------
1 file changed, 12 insertions(+), 6 deletions(-)
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/frontends/wav_frontend.py
similarity index 97%
rename from funasr/models/frontend/wav_frontend.py
rename to funasr/frontends/wav_frontend.py
index ac16065..4866fa1 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/frontends/wav_frontend.py
@@ -4,11 +4,13 @@
import numpy as np
import torch
+import torch.nn as nn
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
-import funasr.models.frontend.eend_ola_feature as eend_ola_feature
-from funasr.models.frontend.abs_frontend import AbsFrontend
+import funasr.frontends.eend_ola_feature as eend_ola_feature
+from funasr.utils.register import register_class
+
def load_cmvn(cmvn_file):
@@ -73,8 +75,8 @@
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
-
-class WavFrontend(AbsFrontend):
+@register_class("frontend_classes", "WavFrontend")
+class WavFrontend(nn.Module):
"""Conventional frontend structure for ASR.
"""
@@ -93,6 +95,7 @@
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
+ **kwargs,
):
super().__init__()
self.fs = fs
@@ -208,7 +211,8 @@
return feats_pad, feats_lens
-class WavFrontendOnline(AbsFrontend):
+@register_class("frontend_classes", "WavFrontendOnline")
+class WavFrontendOnline(nn.Module):
"""Conventional frontend structure for streaming ASR/VAD.
"""
@@ -227,6 +231,7 @@
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
+ **kwargs,
):
super().__init__()
self.fs = fs
@@ -454,7 +459,7 @@
self.lfr_splice_cache = []
-class WavFrontendMel23(AbsFrontend):
+class WavFrontendMel23(nn.Module):
"""Conventional frontend structure for ASR.
"""
@@ -465,6 +470,7 @@
frame_shift: int = 10,
lfr_m: int = 1,
lfr_n: int = 1,
+ **kwargs,
):
super().__init__()
self.fs = fs
--
Gitblit v1.9.1