From e67ed1d45d5a9d7fb7bb22d15fd2bfef17e9076f Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Wed, 17 Jan 2024 10:57:14 +0800
Subject: [PATCH] Update e2e_uni_asr.py
---
funasr/models/uniasr/e2e_uni_asr.py | 35 +++++++++++++++++++----------------
1 file changed, 19 insertions(+), 16 deletions(-)
diff --git a/funasr/models/uniasr/e2e_uni_asr.py b/funasr/models/uniasr/e2e_uni_asr.py
index 0fb4039..390d274 100644
--- a/funasr/models/uniasr/e2e_uni_asr.py
+++ b/funasr/models/uniasr/e2e_uni_asr.py
@@ -10,15 +10,15 @@
import torch
from funasr.models.e2e_asr_common import ErrorCalculator
-from funasr.models.transformer.utils.nets_utils import th_accuracy
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
@@ -26,7 +26,7 @@
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
from funasr.models.scama.chunk_utilis import sequence_mask
-from funasr.models.predictor.cif import mae_loss
+from funasr.models.paraformer.cif_predictor import mae_loss
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -443,7 +443,7 @@
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + 1).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
@@ -538,20 +541,20 @@
speech_lengths: (Batch, )
"""
# with autocast(False):
- # # 1. Extract feats
- # feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+ # # 1. Extract feats
+ # feats, feats_lengths = self._extract_feats(speech, speech_lengths)
#
- # # 2. Data augmentation
- # if self.specaug is not None and self.training:
- # feats, feats_lengths = self.specaug(feats, feats_lengths)
+ # # 2. Data augmentation
+ # if self.specaug is not None and self.training:
+ # feats, feats_lengths = self.specaug(feats, feats_lengths)
#
- # # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- # if self.normalize is not None:
- # feats, feats_lengths = self.normalize(feats, feats_lengths)
+ # # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ # if self.normalize is not None:
+ # feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
# if self.preencoder is not None:
- # feats, feats_lengths = self.preencoder(feats, feats_lengths)
+ # feats, feats_lengths = self.preencoder(feats, feats_lengths)
encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
encoder_out,
encoder_out_lens,
@@ -581,9 +584,9 @@
# # Post-encoder, e.g. NLU
# if self.postencoder is not None:
- # encoder_out, encoder_out_lens = self.postencoder(
- # encoder_out, encoder_out_lens
- # )
+ # encoder_out, encoder_out_lens = self.postencoder(
+ # encoder_out, encoder_out_lens
+ # )
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
--
Gitblit v1.9.1