zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/models/data2vec/data2vec_encoder.py
@@ -133,7 +133,9 @@
        self.mask_other = mask_other
        self.no_mask_overlap = no_mask_overlap
        self.mask_min_space = mask_min_space
        self.require_same_masks = require_same_masks  # if set as True, collate_fn should be clipping
        self.require_same_masks = (
            require_same_masks  # if set as True, collate_fn should be clipping
        )
        self.mask_dropout = mask_dropout
        ## channel masking
        self.mask_channel_length = mask_channel_length
@@ -259,10 +261,7 @@
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0
@@ -446,16 +445,12 @@
            if self.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    F.batch_norm(tl.float(), running_mean=None, running_var=None, training=True)
                    for tl in target_layer_results
                ]
            if self.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]
                target_layer_results = [F.instance_norm(tl.float()) for tl in target_layer_results]
            if permuted:
                target_layer_results = [
@@ -464,14 +459,12 @@
            if self.group_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-2:])
                    for tl in target_layer_results
                    F.layer_norm(tl.float(), tl.shape[-2:]) for tl in target_layer_results
                ]
            if self.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                    F.layer_norm(tl.float(), tl.shape[-1:]) for tl in target_layer_results
                ]
            y = sum(target_layer_results) / len(target_layer_results)
@@ -521,9 +514,7 @@
                f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting"
            )
        if self.num_updates > 5000 and result["pred_var"] < self.min_pred_var:
            logging.error(
                f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
            )
            logging.error(f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting")
            raise Exception(
                f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
            )
@@ -550,9 +541,7 @@
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()
    def extract_features(
            self, xs_pad, ilens, mask=False, layer=None
    ):
    def extract_features(self, xs_pad, ilens, mask=False, layer=None):
        res = self.forward(
            xs_pad,
            ilens,