kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/transformer/utils/nets_utils.py
@@ -25,9 +25,7 @@
    elif isinstance(m, torch.Tensor):
        device = m.device
    else:
        raise TypeError(
            "Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}"
        )
        raise TypeError("Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}")
    return x.to(device)
@@ -215,9 +213,7 @@
        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, , None, ..., None)
        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        ind = tuple(slice(None) if i in (0, length_dim) else None for i in range(xs.dim()))
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask
@@ -342,29 +338,6 @@
    return ret
def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.
    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.
    Returns:
        float: Accuracy value (0.0 - 1.0).
    """
    pad_pred = pad_outputs.view(
        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
    ).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)
def to_torch_tensor(x):
    """Change to torch.Tensor or ComplexTensor from numpy.ndarray.
@@ -455,9 +428,9 @@
        return subsample
    elif (
            (mode == "asr" and arch in ("rnn", "rnn-t"))
            or (mode == "mt" and arch == "rnn")
            or (mode == "st" and arch == "rnn")
        (mode == "asr" and arch in ("rnn", "rnn-t"))
        or (mode == "mt" and arch == "rnn")
        or (mode == "st" and arch == "rnn")
    ):
        subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
@@ -473,14 +446,10 @@
        return subsample
    elif mode == "asr" and arch == "rnn_mix":
        subsample = np.ones(
            train_args.elayers_sd + train_args.elayers + 1, dtype=np.int32
        )
        subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int32)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(
                    min(train_args.elayers_sd + train_args.elayers + 1, len(ss))
            ):
            for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
@@ -494,9 +463,7 @@
        subsample_list = []
        for idx in range(train_args.num_encs):
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int32)
            if train_args.etype[idx].endswith("p") and not train_args.etype[
                idx
            ].startswith("vgg"):
            if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
                ss = train_args.subsample[idx].split("_")
                for j in range(min(train_args.elayers[idx] + 1, len(ss))):
                    subsample[j] = int(ss[j])
@@ -514,9 +481,7 @@
        raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))
def rename_state_dict(
        old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
):
def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
    """Replace keys of old prefix with new prefix in state dict."""
    # need this list not to break the dict iterator
    old_keys = [k for k in state_dict if k.startswith(old_prefix)]
@@ -526,6 +491,7 @@
        v = state_dict.pop(k)
        new_k = k.replace(old_prefix, new_prefix)
        state_dict[new_k] = v
class Swish(torch.nn.Module):
    """Swish activation definition.
@@ -561,6 +527,7 @@
        """Forward computation."""
        return self.swish(x)
def get_activation(act):
    """Return activation function."""
@@ -573,6 +540,7 @@
    }
    return activation_funcs[act]()
class TooShortUttError(Exception):
    """Raised when the utt is too short for subsampling.
@@ -634,9 +602,7 @@
    elif sub_factor == 6:
        return 5, 3, (((input_size - 1) // 2 - 2) // 3)
    else:
        raise ValueError(
            "subsampling_factor parameter should be set to either 2, 4 or 6."
        )
        raise ValueError("subsampling_factor parameter should be set to either 2, 4 or 6.")
def make_chunk_mask(
@@ -671,6 +637,7 @@
        mask[i, start:end] = True
    return ~mask
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Create source mask for given lengths.
@@ -756,6 +723,7 @@
    return decoder_in, target, t_len, u_len
def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
    if t.size(dim) == pad_len:
@@ -763,6 +731,4 @@
    else:
        pad_size = list(t.shape)
        pad_size[dim] = pad_len - t.size(dim)
        return torch.cat(
            [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
        )
        return torch.cat([t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim)