| | |
| | | #!/usr/bin/env python3
|
| | | # -*- encoding: utf-8 -*-
|
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
| | | # MIT License (https://opensource.org/licenses/MIT)
|
| | |
|
| | | import torch
|
| | | from torch import nn
|
| | | from torch import Tensor
|
| | | import logging
|
| | | import numpy as np
|
| | | from funasr.train_utils.device_funcs import to_device
|
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
| | | from funasr.models.scama.utils import sequence_mask
|
| | | from typing import Optional, Tuple
|
| | |
|
| | | from funasr.register import tables
|
| | | from funasr.train_utils.device_funcs import to_device
|
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
| | |
|
| | |
|
| | | @tables.register("predictor_classes", "CifPredictor")
|
| | | class CifPredictor(nn.Module):
|
| | | class CifPredictor(torch.nn.Module):
|
| | | def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
|
| | | super().__init__()
|
| | |
|
| | | self.pad = nn.ConstantPad1d((l_order, r_order), 0)
|
| | | self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
|
| | | self.cif_output = nn.Linear(idim, 1)
|
| | | self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
|
| | | self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
|
| | | self.cif_output = torch.nn.Linear(idim, 1)
|
| | | self.dropout = torch.nn.Dropout(p=dropout)
|
| | | self.threshold = threshold
|
| | | self.smooth_factor = smooth_factor
|
| | |
| | | return predictor_alignments.detach(), predictor_alignments_length.detach()
|
| | |
|
| | | @tables.register("predictor_classes", "CifPredictorV2")
|
| | | class CifPredictorV2(nn.Module):
|
| | | class CifPredictorV2(torch.nn.Module):
|
| | | def __init__(self,
|
| | | idim,
|
| | | l_order,
|
| | |
| | | ):
|
| | | super(CifPredictorV2, self).__init__()
|
| | |
|
| | | self.pad = nn.ConstantPad1d((l_order, r_order), 0)
|
| | | self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
|
| | | self.cif_output = nn.Linear(idim, 1)
|
| | | self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
|
| | | self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
|
| | | self.cif_output = torch.nn.Linear(idim, 1)
|
| | | self.dropout = torch.nn.Dropout(p=dropout)
|
| | | self.threshold = threshold
|
| | | self.smooth_factor = smooth_factor
|
| | |
| | | alphas = alphas.squeeze(-1)
|
| | | mask = mask.squeeze(-1)
|
| | | if target_label_length is not None:
|
| | | target_length = target_label_length
|
| | | target_length = target_label_length.squeeze(-1)
|
| | | elif target_label is not None:
|
| | | target_length = (target_label != ignore_id).float().sum(-1)
|
| | | else:
|
| | |
| | | return var_dict_torch_update
|
| | |
|
| | |
|
| | | class mae_loss(nn.Module):
|
| | | class mae_loss(torch.nn.Module):
|
| | |
|
| | | def __init__(self, normalize_length=False):
|
| | | super(mae_loss, self).__init__()
|