speech_asr
2023-03-14 141a4737f779fcf435a0ece5434b9c73eda7d2a9
update
2 files changed
18 ■■■■ lines changed
funasr/models/frontend/wav_frontend.py 16 ●●●●● patch | view | raw | blame | history
funasr/tasks/diar.py 2 ●●●●● patch | view | raw | blame | history
funasr/models/frontend/wav_frontend.py
@@ -6,6 +6,7 @@
from typing import Optional, Tuple

import torch
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
from typeguard import check_argument_types

from funasr.models.frontend.abs_frontend import AbsFrontend
import funasr.models.frontend.eend_ola_feature as eend_ola_feature
@@ -213,33 +214,18 @@
    def __init__(
            self,
            fs: int = 16000,
            window: str = 'hamming',
            n_mels: int = 80,
            frame_length: int = 25,
            frame_shift: int = 10,
            filter_length_min: int = -1,
            filter_length_max: int = -1,
            lfr_m: int = 1,
            lfr_n: int = 1,
            dither: float = 1.0,
            snip_edges: bool = True,
            upsacle_samples: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        self.fs = fs
        self.window = window
        self.n_mels = n_mels
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.filter_length_min = filter_length_min
        self.filter_length_max = filter_length_max
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples
    def output_size(self) -> int:
        return self.n_mels * self.lfr_m
funasr/tasks/diar.py
@@ -23,6 +23,7 @@
from funasr.layers.label_aggregation import LabelAggregate
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.e2e_diar_sond import DiarSondModel
from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
@@ -103,6 +104,7 @@
    "model",
    classes=dict(
        sond=DiarSondModel,
        eend_ola=DiarEENDOLAModel,
    ),
    type_check=AbsESPnetModel,
    default="sond",