zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/frontends/default.py
@@ -6,6 +6,7 @@
import numpy as np
import torch
import torch.nn as nn
try:
    from torch_complex.tensor import ComplexTensor
except:
@@ -128,9 +129,7 @@
        return input_feats, feats_lens
    def _compute_stft(
            self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> torch.Tensor:
    def _compute_stft(self, input: torch.Tensor, input_lengths: torch.Tensor) -> torch.Tensor:
        input_stft, feats_lens = self.stft(input, input_lengths)
        assert input_stft.dim() >= 4, input_stft.shape
@@ -170,7 +169,7 @@
            lfr_m: int = 1,
            lfr_n: int = 1,
            cmvn_file: str = None,
            mc: bool = True
        mc: bool = True,
    ):
        super().__init__()
        # Deepcopy (In general, dict shouldn't be used as default arg)
@@ -183,8 +182,7 @@
            self.hop_length = self.hop_length
        else:
            logging.error(
                "Only one of (win_length, hop_length) and (frame_length, frame_shift)"
                "can be set."
                "Only one of (win_length, hop_length) and (frame_length, frame_shift)" "can be set."
            )
            exit(1)
@@ -277,7 +275,9 @@
            if input_feats.dim() ==4:
                bt = input_feats.size(0)
                channel_size = input_feats.size(2)
                input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
                input_feats = (
                    input_feats.transpose(1, 2).reshape(bt * channel_size, -1, 80).contiguous()
                )
                feats_lens = feats_lens.repeat(1,channel_size).squeeze()
            else:
                channel_size = 1
@@ -304,9 +304,7 @@
            return input_feats, feats_lens
    def _compute_stft(
            self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> torch.Tensor:
    def _compute_stft(self, input: torch.Tensor, input_lengths: torch.Tensor) -> torch.Tensor:
        input_stft, feats_lens = self.stft(input, input_lengths)
        assert input_stft.dim() >= 4, input_stft.shape
@@ -319,21 +317,21 @@
        return input_stft, feats_lens
    def _load_cmvn(self, cmvn_file):
        with open(cmvn_file, 'r', encoding='utf-8') as f:
        with open(cmvn_file, "r", encoding="utf-8") as f:
            lines = f.readlines()
        means_list = []
        vars_list = []
        for i in range(len(lines)):
            line_item = lines[i].split()
            if line_item[0] == '<AddShift>':
            if line_item[0] == "<AddShift>":
                line_item = lines[i + 1].split()
                if line_item[0] == '<LearnRateCoef>':
                if line_item[0] == "<LearnRateCoef>":
                    add_shift_line = line_item[3:(len(line_item) - 1)]
                    means_list = list(add_shift_line)
                    continue
            elif line_item[0] == '<Rescale>':
            elif line_item[0] == "<Rescale>":
                line_item = lines[i + 1].split()
                if line_item[0] == '<LearnRateCoef>':
                if line_item[0] == "<LearnRateCoef>":
                    rescale_line = line_item[3:(len(line_item) - 1)]
                    vars_list = list(rescale_line)
                    continue