kongdeqiang
5 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/transformer/embedding.py
@@ -11,6 +11,7 @@
import torch.nn.functional as F
from torch import einsum
def _pre_hook(
    state_dict,
    prefix,
@@ -64,9 +65,7 @@
                return
        pe = torch.zeros(x.size(1), self.d_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
            position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
@@ -170,9 +169,7 @@
        if self.gamma is None:
            self.gamma = self.d_model // 2
        assert (
            d_model % 2 == 0
        ), "d_model should be divisible by two in order to use this layer."
        assert d_model % 2 == 0, "d_model should be divisible by two in order to use this layer."
        self.w_r = torch.nn.Parameter(torch.empty(1, d_model // 2))
        self._reset()  # init the weights
@@ -185,9 +182,7 @@
            )
    def _reset(self):
        self.w_r.data = torch.normal(
            0, (1 / math.sqrt(self.gamma)), (1, self.d_model // 2)
        )
        self.w_r.data = torch.normal(0, (1 / math.sqrt(self.gamma)), (1, self.d_model // 2))
    def extend_pe(self, x):
        """Reset the positional encodings."""
@@ -384,45 +379,57 @@
        x = x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
        return self.dropout(x)
class SinusoidalPositionEncoder(torch.nn.Module):
    '''
    '''
class SinusoidalPositionEncoder(torch.nn.Module):
    """ """
    def __int__(self, d_model=80, dropout_rate=0.1):
        pass
    def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32):
    def encode(
        self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
    ):
        batch_size = positions.size(0)
        positions = positions.type(dtype)
        device = positions.device
        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (depth / 2 - 1)
        inv_timescales = torch.exp(torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment))
        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (
            depth / 2 - 1
        )
        inv_timescales = torch.exp(
            torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment)
        )
        inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
            inv_timescales, [1, 1, -1]
        )
        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        return encoding.type(dtype)
    def forward(self, x):
        batch_size, timesteps, input_dim = x.size()
        positions = torch.arange(1, timesteps+1, device=x.device)[None, :]
        positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding
class StreamSinusoidalPositionEncoder(torch.nn.Module):
    '''
    '''
class StreamSinusoidalPositionEncoder(torch.nn.Module):
    """ """
    def __int__(self, d_model=80, dropout_rate=0.1):
        pass
    def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32):
    def encode(
        self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
    ):
        batch_size = positions.size(0)
        positions = positions.type(dtype)
        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype)) / (depth / 2 - 1)
        inv_timescales = torch.exp(torch.arange(depth / 2).type(dtype) * (-log_timescale_increment))
        inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
            inv_timescales, [1, 1, -1]
        )
        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        return encoding.type(dtype)
@@ -432,9 +439,10 @@
        if cache is not None:
            start_idx = cache["start_idx"]
            cache["start_idx"] += timesteps
        positions = torch.arange(1, timesteps+start_idx+1)[None, :]
        positions = torch.arange(1, timesteps + start_idx + 1)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding[:, start_idx: start_idx + timesteps]
        return x + position_encoding[:, start_idx : start_idx + timesteps]
class StreamingRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding.
@@ -444,9 +452,7 @@
        dropout_rate: Dropout rate.
    """
    def __init__(
        self, size: int, dropout_rate: float = 0.0, max_len: int = 5000
    ) -> None:
    def __init__(self, size: int, dropout_rate: float = 0.0, max_len: int = 5000) -> None:
        """Construct a RelativePositionalEncoding object."""
        super().__init__()
@@ -477,8 +483,7 @@
        position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.size, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.size)
            torch.arange(0, self.size, 2, dtype=torch.float32) * -(math.log(10000.0) / self.size)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
@@ -489,9 +494,7 @@
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
        pe_negative = pe_negative[1:].unsqueeze(0)
        self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(
            dtype=x.dtype, device=x.device
        )
        self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(dtype=x.dtype, device=x.device)
    def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Compute positional encoding.
@@ -505,9 +508,7 @@
        time1 = x.size(1) + left_context
        pos_enc = self.pe[
            :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)
        ]
        pos_enc = self.pe[:, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)]
        pos_enc = self.dropout(pos_enc)
        return pos_enc
@@ -516,14 +517,17 @@
class ScaledSinuEmbedding(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(1,))
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.scale = torch.nn.Parameter(
            torch.ones(
                1,
            )
        )
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
    def forward(self, x):
        n, device = x.shape[1], x.device
        t = torch.arange(n, device = device).type_as(self.inv_freq)
        sinu = einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
        t = torch.arange(n, device=device).type_as(self.inv_freq)
        sinu = einsum("i , j -> i j", t, self.inv_freq)
        emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1)
        return emb * self.scale