嘉渊
2023-05-17 334dec5d184b34358e5703da6bda87ed3af1fea6
funasr/models/encoder/sanm_encoder.py
@@ -355,18 +355,9 @@
    def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
        if len(cache) == 0:
            return feats
        # process last chunk
        cache["feats"] = to_device(cache["feats"], device=feats.device)
        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
        if cache["is_final"]:
            cache["feats"] = overlap_feats[:, -cache["chunk_size"][0]:, :]
            if not cache["last_chunk"]:
               padding_length = sum(cache["chunk_size"]) - overlap_feats.shape[1]
               overlap_feats = overlap_feats.transpose(1, 2)
               overlap_feats = F.pad(overlap_feats, (0, padding_length))
               overlap_feats = overlap_feats.transpose(1, 2)
        else:
            cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
        return overlap_feats
    def forward_chunk(self,