From 0e622e694e6cb4459955f1e5942a7c53349ce640 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 19 Dec 2023 21:58:14 +0800
Subject: [PATCH] funasr2

---
 funasr/frontends/wav_frontend.py |   18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/frontends/wav_frontend.py
similarity index 97%
rename from funasr/models/frontend/wav_frontend.py
rename to funasr/frontends/wav_frontend.py
index ac16065..4866fa1 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/frontends/wav_frontend.py
@@ -4,11 +4,13 @@
 
 import numpy as np
 import torch
+import torch.nn as nn
 import torchaudio.compliance.kaldi as kaldi
 from torch.nn.utils.rnn import pad_sequence
 
-import funasr.models.frontend.eend_ola_feature as eend_ola_feature
-from funasr.models.frontend.abs_frontend import AbsFrontend
+import funasr.frontends.eend_ola_feature as eend_ola_feature
+from funasr.utils.register import register_class
+
 
 
 def load_cmvn(cmvn_file):
@@ -73,8 +75,8 @@
     LFR_outputs = torch.vstack(LFR_inputs)
     return LFR_outputs.type(torch.float32)
 
-
-class WavFrontend(AbsFrontend):
+@register_class("frontend_classes", "WavFrontend")
+class WavFrontend(nn.Module):
     """Conventional frontend structure for ASR.
     """
 
@@ -93,6 +95,7 @@
             dither: float = 1.0,
             snip_edges: bool = True,
             upsacle_samples: bool = True,
+            **kwargs,
     ):
         super().__init__()
         self.fs = fs
@@ -208,7 +211,8 @@
         return feats_pad, feats_lens
 
 
-class WavFrontendOnline(AbsFrontend):
+@register_class("frontend_classes", "WavFrontendOnline")
+class WavFrontendOnline(nn.Module):
     """Conventional frontend structure for streaming ASR/VAD.
     """
 
@@ -227,6 +231,7 @@
             dither: float = 1.0,
             snip_edges: bool = True,
             upsacle_samples: bool = True,
+            **kwargs,
     ):
         super().__init__()
         self.fs = fs
@@ -454,7 +459,7 @@
         self.lfr_splice_cache = []
 
 
-class WavFrontendMel23(AbsFrontend):
+class WavFrontendMel23(nn.Module):
     """Conventional frontend structure for ASR.
     """
 
@@ -465,6 +470,7 @@
             frame_shift: int = 10,
             lfr_m: int = 1,
             lfr_n: int = 1,
+            **kwargs,
     ):
         super().__init__()
         self.fs = fs

--
Gitblit v1.9.1