aky15
2023-03-21 fc9595625855be5b63f86a38ac785e49c142c0ae
funasr/models_transducer/encoder/encoder.py
@@ -134,14 +134,11 @@
            )
        mask = make_source_mask(x_len)
        if self.unified_model_training:
            x, mask = self.embed(x, mask, self.default_chunk_size)
        else:
            x, mask = self.embed(x, mask)
        pos_enc = self.pos_enc(x)
        if self.unified_model_training:
            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
@@ -178,6 +175,9 @@
            else:
                chunk_size = (chunk_size % self.short_chunk_size) + 1
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
@@ -185,6 +185,8 @@
                device=x.device,
            )
        else:
            x, mask = self.embed(x, mask, None)
            pos_enc = self.pos_enc(x)
            chunk_mask = None
        x = self.encoders(
            x,