From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] data2vec_encoder: subclass nn.Module instead of AbsEncoder; reformat
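
Replace the AbsEncoder base class of Data2VecEncoder with torch.nn.Module and
drop the now-unused funasr.models.encoder.abs_encoder import; the remaining
hunks are formatting-only (indentation and line-wrapping reflow).

A minimal sketch of what the base-class change means for callers (assumed
usage, not part of this patch):

    import torch.nn as nn
    from funasr.models.data2vec.data2vec_encoder import Data2VecEncoder

    # After this patch the encoder derives from nn.Module directly.
    assert issubclass(Data2VecEncoder, nn.Module)
    # It is no longer an AbsEncoder subclass, so any hypothetical
    # isinstance/issubclass checks against AbsEncoder elsewhere in the
    # codebase would have to be updated.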
---
funasr/models/data2vec/data2vec_encoder.py | 204 ++++++++++++++++++++++++---------------------------
1 file changed, 96 insertions(+), 108 deletions(-)
diff --git a/funasr/models/data2vec/data2vec_encoder.py b/funasr/models/data2vec/data2vec_encoder.py
index 1bcb639..f591dd6 100644
--- a/funasr/models/data2vec/data2vec_encoder.py
+++ b/funasr/models/data2vec/data2vec_encoder.py
@@ -11,7 +11,6 @@
import torch.nn as nn
import torch.nn.functional as F
-from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.data2vec.data_utils import compute_mask_indices
from funasr.models.data2vec.ema_module import EMAModule
from funasr.models.data2vec.grad_multiply import GradMultiply
@@ -28,73 +27,73 @@
return end - r * pct_remaining
-class Data2VecEncoder(AbsEncoder):
+class Data2VecEncoder(nn.Module):
def __init__(
- self,
- # for ConvFeatureExtractionModel
- input_size: int = None,
- extractor_mode: str = None,
- conv_feature_layers: str = "[(512,2,2)] + [(512,2,2)]",
- # for Transformer Encoder
- ## model architecture
- layer_type: str = "transformer",
- layer_norm_first: bool = False,
- encoder_layers: int = 12,
- encoder_embed_dim: int = 768,
- encoder_ffn_embed_dim: int = 3072,
- encoder_attention_heads: int = 12,
- activation_fn: str = "gelu",
- ## dropouts
- dropout: float = 0.1,
- attention_dropout: float = 0.1,
- activation_dropout: float = 0.0,
- encoder_layerdrop: float = 0.0,
- dropout_input: float = 0.0,
- dropout_features: float = 0.0,
- ## grad settings
- feature_grad_mult: float = 1.0,
- ## masking
- mask_prob: float = 0.65,
- mask_length: int = 10,
- mask_selection: str = "static",
- mask_other: int = 0,
- no_mask_overlap: bool = False,
- mask_min_space: int = 1,
- require_same_masks: bool = True, # if set as True, collate_fn should be clipping
- mask_dropout: float = 0.0,
- ## channel masking
- mask_channel_length: int = 10,
- mask_channel_prob: float = 0.0,
- mask_channel_before: bool = False,
- mask_channel_selection: str = "static",
- mask_channel_other: int = 0,
- no_mask_channel_overlap: bool = False,
- mask_channel_min_space: int = 1,
- ## positional embeddings
- conv_pos: int = 128,
- conv_pos_groups: int = 16,
- pos_conv_depth: int = 1,
- max_positions: int = 100000,
- # EMA module
- average_top_k_layers: int = 8,
- layer_norm_target_layer: bool = False,
- instance_norm_target_layer: bool = False,
- instance_norm_targets: bool = False,
- layer_norm_targets: bool = False,
- batch_norm_target_layer: bool = False,
- group_norm_target_layer: bool = False,
- ema_decay: float = 0.999,
- ema_end_decay: float = 0.9999,
- ema_anneal_end_step: int = 100000,
- ema_transformer_only: bool = True,
- ema_layers_only: bool = True,
- min_target_var: float = 0.1,
- min_pred_var: float = 0.01,
- # Loss
- loss_beta: float = 0.0,
- loss_scale: float = None,
- # FP16 optimization
- required_seq_len_multiple: int = 2,
+ self,
+ # for ConvFeatureExtractionModel
+ input_size: int = None,
+ extractor_mode: str = None,
+ conv_feature_layers: str = "[(512,2,2)] + [(512,2,2)]",
+ # for Transformer Encoder
+ ## model architecture
+ layer_type: str = "transformer",
+ layer_norm_first: bool = False,
+ encoder_layers: int = 12,
+ encoder_embed_dim: int = 768,
+ encoder_ffn_embed_dim: int = 3072,
+ encoder_attention_heads: int = 12,
+ activation_fn: str = "gelu",
+ ## dropouts
+ dropout: float = 0.1,
+ attention_dropout: float = 0.1,
+ activation_dropout: float = 0.0,
+ encoder_layerdrop: float = 0.0,
+ dropout_input: float = 0.0,
+ dropout_features: float = 0.0,
+ ## grad settings
+ feature_grad_mult: float = 1.0,
+ ## masking
+ mask_prob: float = 0.65,
+ mask_length: int = 10,
+ mask_selection: str = "static",
+ mask_other: int = 0,
+ no_mask_overlap: bool = False,
+ mask_min_space: int = 1,
+ require_same_masks: bool = True, # if set as True, collate_fn should be clipping
+ mask_dropout: float = 0.0,
+ ## channel masking
+ mask_channel_length: int = 10,
+ mask_channel_prob: float = 0.0,
+ mask_channel_before: bool = False,
+ mask_channel_selection: str = "static",
+ mask_channel_other: int = 0,
+ no_mask_channel_overlap: bool = False,
+ mask_channel_min_space: int = 1,
+ ## positional embeddings
+ conv_pos: int = 128,
+ conv_pos_groups: int = 16,
+ pos_conv_depth: int = 1,
+ max_positions: int = 100000,
+ # EMA module
+ average_top_k_layers: int = 8,
+ layer_norm_target_layer: bool = False,
+ instance_norm_target_layer: bool = False,
+ instance_norm_targets: bool = False,
+ layer_norm_targets: bool = False,
+ batch_norm_target_layer: bool = False,
+ group_norm_target_layer: bool = False,
+ ema_decay: float = 0.999,
+ ema_end_decay: float = 0.9999,
+ ema_anneal_end_step: int = 100000,
+ ema_transformer_only: bool = True,
+ ema_layers_only: bool = True,
+ min_target_var: float = 0.1,
+ min_pred_var: float = 0.01,
+ # Loss
+ loss_beta: float = 0.0,
+ loss_scale: float = None,
+ # FP16 optimization
+ required_seq_len_multiple: int = 2,
):
super().__init__()
@@ -134,7 +133,9 @@
self.mask_other = mask_other
self.no_mask_overlap = no_mask_overlap
self.mask_min_space = mask_min_space
- self.require_same_masks = require_same_masks # if set as True, collate_fn should be clipping
+ self.require_same_masks = (
+ require_same_masks # if set as True, collate_fn should be clipping
+ )
self.mask_dropout = mask_dropout
## channel masking
self.mask_channel_length = mask_channel_length
@@ -240,11 +241,11 @@
self.num_updates = num_updates
def apply_mask(
- self,
- x,
- padding_mask,
- mask_indices=None,
- mask_channel_indices=None,
+ self,
+ x,
+ padding_mask,
+ mask_indices=None,
+ mask_channel_indices=None,
):
B, T, C = x.shape
@@ -260,10 +261,7 @@
min_space=self.mask_channel_min_space,
)
mask_channel_indices = (
- torch.from_numpy(mask_channel_indices)
- .to(x.device)
- .unsqueeze(1)
- .expand(-1, T, -1)
+ torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1)
)
x[mask_channel_indices] = 0
@@ -301,9 +299,9 @@
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices)
- .to(x.device)
- .unsqueeze(1)
- .expand(-1, T, -1)
+ .to(x.device)
+ .unsqueeze(1)
+ .expand(-1, T, -1)
)
x[mask_channel_indices] = 0
@@ -327,15 +325,15 @@
return input_lengths.to(torch.long)
def forward(
- self,
- xs_pad,
- ilens=None,
- mask=False,
- features_only=True,
- layer=None,
- mask_indices=None,
- mask_channel_indices=None,
- padding_count=None,
+ self,
+ xs_pad,
+ ilens=None,
+ mask=False,
+ features_only=True,
+ layer=None,
+ mask_indices=None,
+ mask_channel_indices=None,
+ padding_count=None,
):
# create padding_mask by ilens
if ilens is not None:
@@ -447,16 +445,12 @@
if self.batch_norm_target_layer:
target_layer_results = [
- F.batch_norm(
- tl.float(), running_mean=None, running_var=None, training=True
- )
+ F.batch_norm(tl.float(), running_mean=None, running_var=None, training=True)
for tl in target_layer_results
]
if self.instance_norm_target_layer:
- target_layer_results = [
- F.instance_norm(tl.float()) for tl in target_layer_results
- ]
+ target_layer_results = [F.instance_norm(tl.float()) for tl in target_layer_results]
if permuted:
target_layer_results = [
@@ -465,14 +459,12 @@
if self.group_norm_target_layer:
target_layer_results = [
- F.layer_norm(tl.float(), tl.shape[-2:])
- for tl in target_layer_results
+ F.layer_norm(tl.float(), tl.shape[-2:]) for tl in target_layer_results
]
if self.layer_norm_target_layer:
target_layer_results = [
- F.layer_norm(tl.float(), tl.shape[-1:])
- for tl in target_layer_results
+ F.layer_norm(tl.float(), tl.shape[-1:]) for tl in target_layer_results
]
y = sum(target_layer_results) / len(target_layer_results)
@@ -522,9 +514,7 @@
f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting"
)
if self.num_updates > 5000 and result["pred_var"] < self.min_pred_var:
- logging.error(
- f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
- )
+ logging.error(f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting")
raise Exception(
f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
)
@@ -540,20 +530,18 @@
if dist.is_initialized():
zc = torch.tensor(y.size(0)).cuda()
zs = y.sum(dim=0)
- zss = (y ** 2).sum(dim=0)
+ zss = (y**2).sum(dim=0)
dist.all_reduce(zc)
dist.all_reduce(zs)
dist.all_reduce(zss)
- var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
+ var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
return torch.sqrt(var + 1e-6).mean()
else:
return torch.sqrt(y.var(dim=0) + 1e-6).mean()
- def extract_features(
- self, xs_pad, ilens, mask=False, layer=None
- ):
+ def extract_features(self, xs_pad, ilens, mask=False, layer=None):
res = self.forward(
xs_pad,
ilens,
@@ -572,4 +560,4 @@
)
def output_size(self) -> int:
- return self.encoder_embed_dim
\ No newline at end of file
+ return self.encoder_embed_dim
--
Gitblit v1.9.1