彭震东
2024-05-30 a64b7d8d8aeb2bb543ca703045a45f42470e9a63
funasr/frontends/fused.py
@@ -7,14 +7,10 @@
class FusedFrontends(nn.Module):
    def __init__(
        self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
    ):
    def __init__(self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000):
        super().__init__()
        self.align_method = (
            align_method  # fusing method : linear_projection only for now
        )
        self.align_method = align_method  # fusing method : linear_projection only for now
        self.proj_dim = proj_dim  # dim of the projection done on each frontend
        self.frontends = []  # list of the frontends to combine
@@ -109,9 +105,7 @@
                input_feats, feats_lens = frontend.forward(input, input_lengths)
            self.feats.append([input_feats, feats_lens])
        if (
            self.align_method == "linear_projection"
        ):  # TODO(Dan): to add other align methods
        if self.align_method == "linear_projection":  # TODO(Dan): to add other align methods
            # first step : projections
            self.feats_proj = []
@@ -141,4 +135,4 @@
        else:
            raise NotImplementedError
        return input_feats, feats_lens
        return input_feats, feats_lens