funasr/models/llm_asr_nar/adaptor.py
@@ -3,6 +3,7 @@ from funasr.register import tables @tables.register("adaptor_classes", "Linear") class Linear(nn.Module): def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs): @@ -20,7 +21,7 @@ if num_frames_to_discard > 0: x = x[:, :-num_frames_to_discard, :] seq_len = x.size(1) x = x.contiguous() x = x.view(batch_size, seq_len // self.k, dim * self.k) x = self.linear1(x)