游雁
2024-06-12 2ac79cd3f312e485f3fc4f0e63313cc8a3e0bfc6
funasr/models/emotion2vec/audio.py
@@ -3,29 +3,24 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Tuple
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.nn as nn
from functools import partial
import torch.nn.functional as F
from typing import Callable, Dict
from typing import Callable, Dict, Optional
from funasr.models.emotion2vec.fairseq_modules import (
    LayerNorm,
    SamePad,
    TransposeLast,
    ConvFeatureExtractionModel,
)
from funasr.models.emotion2vec.base import ModalitySpecificEncoder, get_alibi_bias
from funasr.models.emotion2vec.modules import Modality, BlockEncoder, Decoder1d
from funasr.models.emotion2vec.base import ModalitySpecificEncoder, get_alibi_bias
class AudioEncoder(ModalitySpecificEncoder):
    def __init__(
        self,
@@ -95,9 +90,7 @@
        )
        decoder = (
            Decoder1d(modality_cfg.decoder, embed_dim)
            if modality_cfg.decoder is not None
            else None
            Decoder1d(modality_cfg.decoder, embed_dim) if modality_cfg.decoder is not None else None
        )
        alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
@@ -148,13 +141,9 @@
                        output_lengths - 1,
                    )
                ] = 1
                padding_mask = (
                    1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
                ).bool()
                padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
            else:
                padding_mask = torch.zeros(
                    x.shape[:2], dtype=torch.bool, device=x.device
                )
                padding_mask = torch.zeros(x.shape[:2], dtype=torch.bool, device=x.device)
        return padding_mask