From d79287c37e4e7ae2694a992cbbfb03a5ca4f7670 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 20 Feb 2024 14:05:58 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR merge
---
funasr/models/transducer/model.py | 155 ++++++++++++++++-----------------------------------
1 file changed, 48 insertions(+), 107 deletions(-)
diff --git a/funasr/models/transducer/model.py b/funasr/models/transducer/model.py
index 906aa60..fd8ad71 100644
--- a/funasr/models/transducer/model.py
+++ b/funasr/models/transducer/model.py
@@ -1,42 +1,26 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import time
+import torch
import logging
from contextlib import contextmanager
+from typing import Dict, Optional, Tuple
from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
-import torch
-import torch.nn as nn
-import random
-import numpy as np
-import time
-from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
-)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
-# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.frontends.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.paraformer.cif_predictor import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.metrics.compute_acc import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.paraformer.cif_predictor import CifPredictorV3
-from funasr.models.paraformer.search import Hypothesis
-from funasr.models.model_class_factory import *
+from funasr.register import tables
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.scorers.length_bonus import LengthBonus
+from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer
+
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -45,16 +29,10 @@
@contextmanager
def autocast(enabled=True):
yield
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
-from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
-class Transducer(nn.Module):
- """ESPnet2ASRTransducerModel module definition."""
-
-
+@tables.register("model_classes", "Transducer")
+class Transducer(torch.nn.Module):
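+    """Transducer speech recognition model, registered in FunASR's model_classes table."""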
def __init__(
self,
frontend: Optional[str] = None,
@@ -96,35 +74,30 @@
super().__init__()
- if frontend is not None:
- frontend_class = frontend_classes.get_class(frontend)
- frontend = frontend_class(**frontend_conf)
if specaug is not None:
- specaug_class = specaug_classes.get_class(specaug)
+ specaug_class = tables.specaug_classes.get(specaug)
specaug = specaug_class(**specaug_conf)
if normalize is not None:
- normalize_class = normalize_classes.get_class(normalize)
+ normalize_class = tables.normalize_classes.get(normalize)
normalize = normalize_class(**normalize_conf)
- encoder_class = encoder_classes.get_class(encoder)
+ encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(input_size=input_size, **encoder_conf)
encoder_output_size = encoder.output_size()
- decoder_class = decoder_classes.get_class(decoder)
+ decoder_class = tables.decoder_classes.get(decoder)
decoder = decoder_class(
vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
**decoder_conf,
)
decoder_output_size = decoder.output_size
- joint_network_class = joint_network_classes.get_class(decoder)
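+        # resolve the joint network class by its own registered name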
+ joint_network_class = tables.joint_network_classes.get(joint_network)
joint_network = joint_network_class(
vocab_size,
encoder_output_size,
decoder_output_size,
**joint_network_conf,
)
-
self.criterion_transducer = None
self.error_calculator = None
@@ -157,23 +130,17 @@
self.decoder = decoder
self.joint_network = joint_network
-
-
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
- #
- # if report_cer or report_wer:
- # self.error_calculator = ErrorCalculator(
- # token_list, sym_space, sym_blank, report_cer, report_wer
- # )
- #
self.length_normalized_loss = length_normalized_loss
self.beam_search = None
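+        # no auxiliary CTC branch is attached by default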
+ self.ctc = None
+ self.ctc_weight = 0.0
def forward(
self,
@@ -190,8 +157,6 @@
text: (Batch, Length)
text_lengths: (Batch,)
"""
- # import pdb;
- # pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
@@ -283,12 +248,7 @@
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
- if self.encoder.interctc_use_conditioning:
- encoder_out, encoder_out_lens, _ = self.encoder(
- speech, speech_lengths, ctc=self.ctc
- )
- else:
- encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+ encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
@@ -449,9 +409,6 @@
def init_beam_search(self,
**kwargs,
):
- from funasr.models.transformer.search import BeamSearch
- from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
- from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
@@ -466,28 +423,16 @@
length_bonus=LengthBonus(len(token_list)),
)
-
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
- weights = dict(
- decoder=1.0 - kwargs.get("decoding_ctc_weight"),
- ctc=kwargs.get("decoding_ctc_weight", 0.0),
- lm=kwargs.get("lm_weight", 0.0),
- ngram=kwargs.get("ngram_weight", 0.0),
- length_bonus=kwargs.get("penalty", 0.0),
- )
- beam_search = BeamSearch(
- beam_size=kwargs.get("beam_size", 2),
- weights=weights,
- scorers=scorers,
- sos=self.sos,
- eos=self.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
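+        # transducer decoding scores hypotheses with the prediction (decoder) network and the joint network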
+ beam_search = BeamSearchTransducer(
+ self.decoder,
+ self.joint_network,
+ kwargs.get("beam_size", 2),
+ nbest=1,
)
# beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
# for scorer in scorers.values():
@@ -495,13 +440,13 @@
# scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
self.beam_search = beam_search
- def generate(self,
- data_in: list,
- data_lengths: list=None,
- key: list=None,
- tokenizer=None,
- **kwargs,
- ):
+ def inference(self,
+ data_in: list,
+ data_lengths: list=None,
+ key: list=None,
+ tokenizer=None,
+ **kwargs,
+ ):
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
@@ -509,10 +454,10 @@
# init beamsearch
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
- if self.beam_search is None and (is_use_lm or is_use_ctc):
- logging.info("enable beam_search")
- self.init_beam_search(**kwargs)
- self.nbest = kwargs.get("nbest", 1)
+        # beam search is always initialized for transducer decoding, independent of CTC/LM fusion
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
meta_data = {}
# extract fbank feats
@@ -534,12 +479,8 @@
encoder_out = encoder_out[0]
# c. Passed the encoder result and the beam search
- nbest_hyps = self.beam_search(
- x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
- )
-
+ nbest_hyps = self.beam_search(encoder_out[0], is_final=True)
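+        # the transducer beam search returns a list of Hypothesis objects; keep the top nbest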
nbest_hyps = nbest_hyps[: self.nbest]
-
results = []
b, n, d = encoder_out.size()
@@ -553,9 +494,9 @@
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
+                token_int = hyp.yseq
else:
- token_int = hyp.yseq[1:last_pos].tolist()
+                token_int = hyp.yseq.tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
--
Gitblit v1.9.1