haoneng.lhn
2023-09-13 5f088a67cd1b18a8260746971f32a6569e0cf2c6
funasr/models/encoder/sanm_encoder.py
@@ -114,8 +114,44 @@
        if not self.normalize_before:
            x = self.norm2(x)
        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute encoded features.
        Args:
            x_input (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        if self.in_size == self.size:
            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
            x = residual + attn
        else:
            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
        if not self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.feed_forward(x)
        if not self.normalize_before:
            x = self.norm2(x)
        return x, cache
class SANMEncoder(AbsEncoder):
    """
@@ -837,11 +873,56 @@
        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
        return overlap_feats
    #def forward_chunk(self,
    #                  xs_pad: torch.Tensor,
    #                  ilens: torch.Tensor,
    #                  cache: dict = None,
    #                  ctc: CTC = None,
    #                  ):
    #    xs_pad *= self.output_size() ** 0.5
    #    if self.embed is None:
    #        xs_pad = xs_pad
    #    else:
    #        xs_pad = self.embed(xs_pad, cache)
    #    if cache["tail_chunk"]:
    #        xs_pad = to_device(cache["feats"], device=xs_pad.device)
    #    else:
    #        xs_pad = self._add_overlap_chunk(xs_pad, cache)
    #    encoder_outs = self.encoders0(xs_pad, None, None, None, None)
    #    xs_pad, masks = encoder_outs[0], encoder_outs[1]
    #    intermediate_outs = []
    #    if len(self.interctc_layer_idx) == 0:
    #        encoder_outs = self.encoders(xs_pad, None, None, None, None)
    #        xs_pad, masks = encoder_outs[0], encoder_outs[1]
    #    else:
    #        for layer_idx, encoder_layer in enumerate(self.encoders):
    #            encoder_outs = encoder_layer(xs_pad, None, None, None, None)
    #            xs_pad, masks = encoder_outs[0], encoder_outs[1]
    #            if layer_idx + 1 in self.interctc_layer_idx:
    #                encoder_out = xs_pad
    #                # intermediate outputs are also normalized
    #                if self.normalize_before:
    #                    encoder_out = self.after_norm(encoder_out)
    #                intermediate_outs.append((layer_idx + 1, encoder_out))
    #                if self.interctc_use_conditioning:
    #                    ctc_out = ctc.softmax(encoder_out)
    #                    xs_pad = xs_pad + self.conditioning_layer(ctc_out)
    #    if self.normalize_before:
    #        xs_pad = self.after_norm(xs_pad)
    #    if len(intermediate_outs) > 0:
    #        return (xs_pad, intermediate_outs), None, None
    #    return xs_pad, ilens, None
    def forward_chunk(self,
                      xs_pad: torch.Tensor,
                      ilens: torch.Tensor,
                      cache: dict = None,
                      ctc: CTC = None,
                      ):
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
@@ -852,34 +933,25 @@
            xs_pad = to_device(cache["feats"], device=xs_pad.device)
        else:
            xs_pad = self._add_overlap_chunk(xs_pad, cache)
        encoder_outs = self.encoders0(xs_pad, None, None, None, None)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, None, None, None, None)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        if cache["opt"] is None:
            cache_layer_num = len(self.encoders0) + len(self.encoders)
            new_cache = [None] * cache_layer_num
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, None, None, None, None)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
            new_cache = cache["opt"]
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
        for layer_idx, encoder_layer in enumerate(self.encoders0):
            encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"])
            xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx+len(self.encoders0)], cache["chunk_size"], cache["encoder_chunk_look_back"])
            xs_pad, new_cache[layer_idx+1] = encoder_outs[0], encoder_outs[1]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        if cache["encoder_chunk_look_back"] > 0:
            cache["opt"] = new_cache
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), None, None
        return xs_pad, ilens, None
    def gen_tf2torch_map_dict(self):