From 8827e26b8d487f123f8d7d5cbd8d00b81dcefcff Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 23 二月 2024 00:58:18 +0800
Subject: [PATCH] fp16
---
funasr/models/paraformer/cif_predictor.py | 34 ++++++++++++++++++----------------
1 files changed, 18 insertions(+), 16 deletions(-)
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index b06fa43..60ddc24 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -1,23 +1,25 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import torch
-from torch import nn
-from torch import Tensor
import logging
import numpy as np
-from funasr.train_utils.device_funcs import to_device
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.scama.utils import sequence_mask
-from typing import Optional, Tuple
from funasr.register import tables
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+
@tables.register("predictor_classes", "CifPredictor")
-class CifPredictor(nn.Module):
+class CifPredictor(torch.nn.Module):
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
super().__init__()
- self.pad = nn.ConstantPad1d((l_order, r_order), 0)
- self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
- self.cif_output = nn.Linear(idim, 1)
+ self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
+ self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
+ self.cif_output = torch.nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
@@ -137,7 +139,7 @@
return predictor_alignments.detach(), predictor_alignments_length.detach()
@tables.register("predictor_classes", "CifPredictorV2")
-class CifPredictorV2(nn.Module):
+class CifPredictorV2(torch.nn.Module):
def __init__(self,
idim,
l_order,
@@ -153,9 +155,9 @@
):
super(CifPredictorV2, self).__init__()
- self.pad = nn.ConstantPad1d((l_order, r_order), 0)
- self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
- self.cif_output = nn.Linear(idim, 1)
+ self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
+ self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
+ self.cif_output = torch.nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
@@ -184,7 +186,7 @@
alphas = alphas.squeeze(-1)
mask = mask.squeeze(-1)
if target_label_length is not None:
- target_length = target_label_length
+ target_length = target_label_length.squeeze(-1)
elif target_label is not None:
target_length = (target_label != ignore_id).float().sum(-1)
else:
@@ -426,7 +428,7 @@
return var_dict_torch_update
-class mae_loss(nn.Module):
+class mae_loss(torch.nn.Module):
def __init__(self, normalize_length=False):
super(mae_loss, self).__init__()
--
Gitblit v1.9.1