#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import time
import torch
import logging
from typing import Dict, Tuple
from contextlib import contextmanager
from distutils.version import LooseVersion

from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.sanm_kws.model import SanmKWS
from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank


if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


@tables.register("model_classes", "SanmKWSStreaming")
class SanmKWSStreaming(SanmKWS):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + CTC branch + loss calculation

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        decoding_ind = kwargs.get("decoding_ind")
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        if hasattr(self.encoder, "overlap_chunk_cls"):
            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # CTC branch: strip the overlapping-chunk padding before computing the loss
        if hasattr(self.encoder, "overlap_chunk_cls"):
            encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )
        else:
            encoder_out_ctc, encoder_out_lens_ctc = encoder_out, encoder_out_lens

        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
        )

        # Collect CTC branch stats
        stats = dict()
        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
        stats["cer_ctc"] = cer_ctc

        loss = loss_ctc

        stats["cer"] = cer_ctc
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode_chunk(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        cache: dict = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            cache: streaming cache created by init_cache()
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

            # Forward encoder
            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
                speech, speech_lengths, cache=cache["encoder"]
            )

        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        return encoder_out, torch.tensor([encoder_out.size(1)])

    def init_cache(self, cache: dict = None, **kwargs):
        # Avoid a mutable default argument: a shared dict would leak streaming
        # state across calls.
        if cache is None:
            cache = {}
        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
        batch_size = 1

        enc_output_size = kwargs["encoder_conf"]["output_size"]
        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
        cache_encoder = {
            "start_idx": 0,
            "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
            "cif_alphas": torch.zeros((batch_size, 1)),
            "encoder_out": None,
            "encoder_out_lens": None,
            "chunk_size": chunk_size,
            "encoder_chunk_look_back": encoder_chunk_look_back,
            "last_chunk": False,
            "opt": None,
            "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
            "tail_chunk": False,
        }
        cache["encoder"] = cache_encoder

        cache_decoder = {
            "decode_fsmn": None,
            "decoder_chunk_look_back": decoder_chunk_look_back,
            "opt": None,
            "chunk_size": chunk_size,
        }
        cache["decoder"] = cache_decoder
        cache["frontend"] = {}
        cache["prev_samples"] = torch.empty(0)

        return cache
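
    # Illustrative sketch of a typical init_cache call. The exact kwargs come
    # from the loaded model config; the encoder_conf/frontend_conf values below
    # are placeholders, not defaults shipped with any particular checkpoint:
    #
    #     cache = model.init_cache(
    #         chunk_size=[0, 10, 5],      # commonly read as [*, chunk, look-ahead] in 60 ms encoder frames
    #         encoder_chunk_look_back=4,
    #         decoder_chunk_look_back=1,
    #         encoder_conf={"output_size": 512},
    #         frontend_conf={"n_mels": 80, "lfr_m": 7},
    #     )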

    def generate_chunk(
        self,
        speech,
        speech_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        cache = kwargs.get("cache", {})
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        is_final = kwargs.get("is_final", False)
        encoder_out, encoder_out_lens = self.encode_chunk(
            speech, speech_lengths, cache=cache, is_final=is_final
        )
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        chunk_size = cache["encoder"]["chunk_size"]
        real_start_pos = chunk_size[0]

        # Keep only the "current" frames of this chunk: skip the first
        # chunk_size[0] frames and stop after at most chunk_size[1] frames.
        if encoder_out_lens[0] > chunk_size[0] + chunk_size[1] + chunk_size[2]:
            raise RuntimeError("impossible case 1: encoder output exceeds the full chunk length")
        if encoder_out_lens[0] >= chunk_size[0] + chunk_size[1]:
            real_end_pos = chunk_size[0] + chunk_size[1]
        elif encoder_out_lens[0] > chunk_size[0]:
            real_end_pos = encoder_out_lens[0]
        else:
            raise RuntimeError("impossible case 2: encoder output no longer than chunk_size[0]")

        encoder_out_accum = cache["encoder"]["encoder_out"]
        if encoder_out_accum is not None:
            encoder_out_accum = torch.cat(
                (encoder_out_accum, encoder_out[:, real_start_pos:real_end_pos, :]), dim=1
            )
        else:
            encoder_out_accum = encoder_out[:, real_start_pos:real_end_pos, :]
        cache["encoder"]["encoder_out"] = encoder_out_accum

        if cache["encoder"]["encoder_out_lens"] is not None:
            cache["encoder"]["encoder_out_lens"][0] += real_end_pos - real_start_pos
        else:
            cache["encoder"]["encoder_out_lens"] = encoder_out_lens
            cache["encoder"]["encoder_out_lens"][0] = real_end_pos - real_start_pos

        if is_final:
            if kwargs.get("output_dir") is not None and not hasattr(self, "writer"):
                self.writer = DatadirWriter(kwargs.get("output_dir"))

            results = []
            for i in range(encoder_out_accum.size(0)):
                x = encoder_out_accum[i, : cache["encoder"]["encoder_out_lens"][i], :]
                detect_result = self.kws_decoder.decode(x)
                is_deted, det_keyword, det_score = detect_result[0], detect_result[1], detect_result[2]

                if is_deted:
                    det_info = "detected " + det_keyword + " " + str(det_score)
                else:
                    det_info = "rejected"
                # Only write per-utterance results when an output_dir was given
                if hasattr(self, "writer"):
                    self.writer["detect"][key[i]] = det_info

                result_i = {"key": key[i], "text": det_info}
                results.append(result_i)

            return results
        else:
            return None
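
    # Worked example of the trimming above, assuming the default
    # chunk_size = [0, 10, 5]: a full chunk yields 0 + 10 + 5 = 15 encoder
    # frames, of which only frames [0, 10) are appended to
    # cache["encoder"]["encoder_out"], so the accumulated length grows by
    # chunk_size[1] frames per full chunk.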

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        cache: dict = None,
        **kwargs,
    ):
        # Avoid a mutable default argument for the streaming cache
        if cache is None:
            cache = {}
        keywords = kwargs.get("keywords")
        from funasr.utils.kws_utils import KwsCtcPrefixDecoder

        self.kws_decoder = KwsCtcPrefixDecoder(
            ctc=self.ctc,
            keywords=keywords,
            token_list=tokenizer.token_list,
            seg_dict=tokenizer.seg_dict,
        )

        meta_data = {}
        chunk_size = kwargs["chunk_size"]
        # One encoder frame covers 960 samples (60 ms at 16 kHz), so the default
        # chunk_size = [0, 10, 5] gives a 600 ms stride and 300 ms of first-chunk padding.
        chunk_stride_samples = int(chunk_size[1] * 960)
        first_chunk_padding_samples = int(chunk_size[2] * 960)

        if len(cache) == 0:
            self.init_cache(cache, **kwargs)

        time1 = time.perf_counter()
        cfg = {"is_final": kwargs.get("is_final", False)}
        audio_sample_list = load_audio_text_image_video(
            data_in,
            fs=frontend.fs,
            audio_fs=kwargs.get("fs", 16000),
            data_type=kwargs.get("data_type", "sound"),
            tokenizer=tokenizer,
            cache=cfg,
        )
        _is_final = cfg["is_final"]  # if data_in is a file or url, is_final is set to True

        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        assert len(audio_sample_list) == 1, "batch_size must be set to 1"

        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))

        if len(audio_sample) < first_chunk_padding_samples:
            logging.warning(
                "key: {}, audio is too short for inference: {} samples".format(key, len(audio_sample))
            )

        audio_sample_pre = audio_sample[0:first_chunk_padding_samples]
        feat_pre, feat_pre_lengths = extract_fbank(
            [audio_sample_pre],
            data_type=kwargs.get("data_type", "sound"),
            frontend=frontend,
            cache=cache["frontend"],
            is_final=False,
        )

        audio_sample = audio_sample[first_chunk_padding_samples:]
        audio_chunks = int(len(audio_sample) // chunk_stride_samples)

        for i in range(audio_chunks):
            if i == 0:
                # Seed the feature buffer with the first-chunk padding features
                cache["encoder"]["feats"][:, chunk_size[2]:, :] = feat_pre

            # Full chunks are never final; the tail handling below flushes the stream
            kwargs["is_final"] = False
            audio_sample_i = audio_sample[i * chunk_stride_samples : (i + 1) * chunk_stride_samples]

            if kwargs["is_final"] and len(audio_sample_i) < 960:
                cache["encoder"]["tail_chunk"] = True
                speech = cache["encoder"]["feats"]
                speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64).to(speech.device)
            else:
                # extract fbank feats
                speech, speech_lengths = extract_fbank(
                    [audio_sample_i],
                    data_type=kwargs.get("data_type", "sound"),
                    frontend=frontend,
                    cache=cache["frontend"],
                    is_final=kwargs["is_final"],
                )
                time3 = time.perf_counter()
                meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
                meta_data["batch_data_time"] = (
                    speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
                )

            results_chunk_i = self.generate_chunk(
                speech,
                speech_lengths,
                key=key,
                tokenizer=tokenizer,
                cache=cache,
                frontend=frontend,
                **kwargs,
            )

            # results_chunk_i must be None when is_final=False
            assert results_chunk_i is None
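
        # Worked example of the tail handling below, assuming 16 kHz input and
        # the default chunk_size = [0, 10, 5]: a 1.0 s utterance (16000 samples)
        # yields 4800 samples of first-chunk padding, one full 9600-sample chunk,
        # and a 1600-sample tail, which is handled by the
        # `<= first_chunk_padding_samples` branch.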

        # process tail samples
        tail_audio_sample = audio_sample[audio_chunks * chunk_stride_samples:]
        if len(tail_audio_sample) < 960:
            # less than one frame of audio left: flush the cached feature buffer
            kwargs["is_final"] = _is_final
            cache["encoder"]["tail_chunk"] = True
            speech = cache["encoder"]["feats"]
            speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64).to(speech.device)
            results_chunk_tail = self.generate_chunk(
                speech,
                speech_lengths,
                key=key,
                tokenizer=tokenizer,
                cache=cache,
                frontend=frontend,
                **kwargs,
            )
        elif len(tail_audio_sample) <= first_chunk_padding_samples:
            kwargs["is_final"] = _is_final
            # extract fbank feats
            # cache["encoder"]["tail_chunk"] = True  # cannot be true here
            speech, speech_lengths = extract_fbank(
                [tail_audio_sample],
                data_type=kwargs.get("data_type", "sound"),
                frontend=frontend,
                cache=cache["frontend"],
                is_final=kwargs["is_final"],
            )
            results_chunk_tail = self.generate_chunk(
                speech,
                speech_lengths,
                key=key,
                tokenizer=tokenizer,
                cache=cache,
                frontend=frontend,
                **kwargs,
            )
        else:
            # first_chunk_padding_samples < len(tail_audio_sample) < chunk_stride_samples:
            # run one more non-final chunk, then flush the tail
            kwargs["is_final"] = False
            # extract fbank feats
            speech, speech_lengths = extract_fbank(
                [tail_audio_sample],
                data_type=kwargs.get("data_type", "sound"),
                frontend=frontend,
                cache=cache["frontend"],
                is_final=kwargs["is_final"],
            )
            results_chunk = self.generate_chunk(
                speech,
                speech_lengths,
                key=key,
                tokenizer=tokenizer,
                cache=cache,
                frontend=frontend,
                **kwargs,
            )
            # results_chunk must be None when is_final=False
            assert results_chunk is None

            # push tail chunk
            kwargs["is_final"] = _is_final
            cache["encoder"]["tail_chunk"] = True
            speech = cache["encoder"]["feats"]
            speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64).to(speech.device)
            results_chunk_tail = self.generate_chunk(
                speech,
                speech_lengths,
                key=key,
                tokenizer=tokenizer,
                cache=cache,
                frontend=frontend,
                **kwargs,
            )

        result = results_chunk_tail

        if _is_final:
            self.init_cache(cache, **kwargs)

        if kwargs.get("output_dir"):
            if not hasattr(self, "writer"):
                self.writer = DatadirWriter(kwargs.get("output_dir"))

        return result, meta_data

    def export(self, **kwargs):
        from .export_meta import export_rebuild_model

        models = export_rebuild_model(model=self, **kwargs)
        return models
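

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API surface):
# driving streaming keyword spotting via FunASR's AutoModel wrapper. The model
# name, wav path, and keyword string below are placeholders.
#
#     from funasr import AutoModel
#
#     model = AutoModel(model="<path-or-name-of-a-SanmKWSStreaming-checkpoint>")
#     cache = {}
#     res = model.generate(
#         input="<utterance.wav>",
#         cache=cache,
#         keywords="<keyword phrase>",
#         chunk_size=[0, 10, 5],
#         is_final=True,
#     )
# ---------------------------------------------------------------------------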