AldarisX
2025-04-07 d43d0853dcf3a1db04302c7b527e92ace3ccfb55
funasr/frontends/fused.py
@@ -7,14 +7,10 @@
class FusedFrontends(nn.Module):
    def __init__(
        self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
    ):
    def __init__(self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000):
        super().__init__()
        self.align_method = (
            align_method  # fusing method : linear_projection only for now
        )
        self.align_method = align_method  # fusing method : linear_projection only for now
        self.proj_dim = proj_dim  # dim of the projection done on each frontend
        self.frontends = []  # list of the frontends to combine
@@ -82,6 +78,8 @@
        self.factors = [frontend.hop_length // self.gcd for frontend in self.frontends]
        if torch.cuda.is_available():
            dev = "cuda"
        elif torch.xpu.is_available():
            dev = "xpu"
        else:
            dev = "cpu"
        if self.align_method == "linear_projection":
@@ -109,9 +107,7 @@
                input_feats, feats_lens = frontend.forward(input, input_lengths)
            self.feats.append([input_feats, feats_lens])
        if (
            self.align_method == "linear_projection"
        ):  # TODO(Dan): to add other align methods
        if self.align_method == "linear_projection":  # TODO(Dan): to add other align methods
            # first step : projections
            self.feats_proj = []
@@ -141,4 +137,4 @@
        else:
            raise NotImplementedError
        return input_feats, feats_lens
        return input_feats, feats_lens