Merge pull request #1247 from alibaba-damo-academy/funasr1.0
Funasr1.0
| | |
| | | ### Speech Recognition (Non-streaming) |
| | | ```python |
| | | from funasr import AutoModel |
| | | |
| | | model = AutoModel(model="paraformer-zh") |
| | | # for long-duration audio, you can add a VAD model |
| | | # model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc") |
| | | |
| | | res = model(input="asr_example_zh.wav", batch_size=64) |
| | | # paraformer-zh is a multi-functional ASR model; |
| | | # enable the VAD, punctuation, and speaker models as needed |
| | | model = AutoModel(model="paraformer-zh", model_revision="v2.0.2", \ |
| | | vad_model="fsmn-vad", vad_model_revision="v2.0.2", \ |
| | | punc_model="ct-punc-c", punc_model_revision="v2.0.2", \ |
| | | spk_model="cam++", spk_model_revision="v2.0.2") |
| | | res = model(input=f"{model.model_path}/example/asr_example.wav", |
| | | batch_size=16, |
| | | hotword='魔搭') |
| | | print(res) |
| | | ``` |
| | | Note: `model_hub` selects the model repository to download from: `ms` downloads from ModelScope, and `hf` downloads from Hugging Face. |
| | |
| | | encoder_chunk_look_back = 4 # number of chunks to look back for encoder self-attention |
| | | decoder_chunk_look_back = 1 # number of encoder chunks to look back for decoder cross-attention |
| | | |
| | | model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.0") |
| | | model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.2") |
| | | |
| | | import soundfile |
| | | import os |
| | |
| | | ```python |
| | | from funasr import AutoModel |
| | | |
| | | model = AutoModel(model="ct-punc", model_revision="v2.0.1") |
| | | model = AutoModel(model="ct-punc", model_revision="v2.0.2") |
| | | |
| | | res = model(input="那今天的会就到这里吧 happy new year 明年见") |
| | | print(res) |
| | |
| | | ```python |
| | | from funasr import AutoModel |
| | | |
| | | model = AutoModel(model="fa-zh", model_revision="v2.0.0") |
| | | model = AutoModel(model="fa-zh", model_revision="v2.0.2") |
| | | |
| | | wav_file = f"{model.model_path}/example/asr_example.wav" |
| | | text_file = f"{model.model_path}/example/asr_example.wav" |
| | |
| | | vad_model_revision="v2.0.2", |
| | | punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", |
| | | punc_model_revision="v2.0.2", |
| | | spk_model="damo/speech_campplus_sv_zh-cn_16k-common", |
| | | spk_model_revision="v2.0.2", |
| | | ) |
| | | |
| | | res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", |
| | | hotword='达摩院 磨搭') |
| | | res = model(input=f"{model.model_path}/example/asr_example.wav", |
| | | hotword='达摩院 魔搭') |
| | | print(res) |
| | |
| | | def generate_with_vad(self, input, input_len=None, **cfg): |
| | | |
| | | # step.1: compute the vad model |
| | | model = self.vad_model |
| | | kwargs = self.vad_kwargs |
| | | kwargs.update(cfg) |
| | | self.vad_kwargs.update(cfg) |
| | | beg_vad = time.time() |
| | | res = self.generate(input, input_len=input_len, model=model, kwargs=kwargs, **cfg) |
| | | vad_res = res |
| | | res = self.generate(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg) |
| | | end_vad = time.time() |
| | | print(f"time cost vad: {end_vad - beg_vad:0.3f}") |
| | | |
| | |
| | | logging.info("decoding, utt: {}, empty speech".format(key)) |
| | | continue |
| | | |
| | | |
| | | # if kwargs["device"] == "cpu": |
| | | # batch_size = 0 |
| | | if len(sorted_data) > 0 and len(sorted_data[0]) > 0: |
| | | batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]) |
| | | |
| | |
| | | import argparse |
| | | import logging |
| | | import os |
| | | import sys |
| | | from io import BytesIO |
| | | from collections.abc import Sequence |
| | | import torch |
| | | import hydra |
| | | import logging |
| | | import argparse |
| | | from io import BytesIO |
| | | import torch.distributed as dist |
| | | from collections.abc import Sequence |
| | | from omegaconf import DictConfig, OmegaConf |
| | | from funasr.train_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.models.lora.utils import mark_only_lora_as_trainable |
| | | from torch.nn.parallel import DistributedDataParallel as DDP |
| | | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| | | |
| | | from funasr.register import tables |
| | | from funasr.optimizers import optim_classes |
| | | from funasr.train_utils.trainer import Trainer |
| | | from funasr.schedulers import scheduler_classes |
| | | from funasr.train_utils.load_pretrained_model import load_pretrained_model |
| | | from funasr.train_utils.initialize import initialize |
| | | from funasr.download.download_from_hub import download_model |
| | | from funasr.models.lora.utils import mark_only_lora_as_trainable |
| | | from funasr.train_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.train_utils.load_pretrained_model import load_pretrained_model |
| | | # from funasr.tokenizer.build_tokenizer import build_tokenizer |
| | | # from funasr.tokenizer.token_id_converter import TokenIDConverter |
| | | # from funasr.tokenizer.funtoken import build_tokenizer |
| | | from funasr.train_utils.trainer import Trainer |
| | | import torch.distributed as dist |
| | | from torch.nn.parallel import DistributedDataParallel as DDP |
| | | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| | | from funasr.download.download_from_hub import download_model |
| | | from funasr.register import tables |
| | | |
| | | |
| | | @hydra.main(config_name=None, version_base=None) |
| | | def main_hydra(kwargs: DictConfig): |
| | |
| | | import torch |
| | | import json |
| | | import torch.distributed as dist |
| | | import numpy as np |
| | | import kaldiio |
| | | import librosa |
| | | import torchaudio |
| | | import time |
| | | import logging |
| | | |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | from funasr.register import tables |
| | | from funasr.utils.load_utils import extract_fbank |
| | | |
| | | |
| | | @tables.register("dataset_classes", "AudioDataset") |
| | | class AudioDataset(torch.utils.data.Dataset): |
| | |
| | | |
| | | |
| | | def collator(self, samples: list=None): |
| | | |
| | | |
| | | outputs = {} |
| | | for sample in samples: |
| | | for key in sample.keys(): |
| | |
| | | import torch |
| | | import json |
| | | import torch.distributed as dist |
| | | import time |
| | | import torch |
| | | import logging |
| | | import torch.distributed as dist |
| | | |
| | | from funasr.register import tables |
| | | |
| | | |
| | | @tables.register("index_ds_classes", "IndexDSJsonl") |
| | | class IndexDSJsonl(torch.utils.data.Dataset): |
| | | |
| | |
| | | import torch |
| | | |
| | | import numpy as np |
| | | |
| | | from funasr.register import tables |
| | |
| | | import json |
| | | import os |
| | | import json |
| | | from omegaconf import OmegaConf |
| | | import torch |
| | | |
| | | from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf |
| | | |
| | | |
| | | def download_model(**kwargs): |
| | | model_hub = kwargs.get("model_hub", "ms") |
| | | if model_hub == "ms": |
| | |
| | | |
| | | |
| | | name_maps_ms = { |
| | | "paraformer-zh": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch", |
| | | "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn", |
| | | "paraformer-zh": "damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", |
| | | "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020", |
| | | "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020", |
| | | "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", |
| | | "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", |
| | | "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large", |
| | | "ct-punc-c": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", |
| | | "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline", |
| | | "cam++": "damo/speech_campplus_sv_zh-cn_16k-common", |
| | | } |
| | | |
| | | name_maps_hf = { |
| | |
| | | from pathlib import Path |
| | | import os |
| | | import argparse |
| | | from pathlib import Path |
| | | |
| | | from funasr.utils.types import str2bool |
| | | |
| | | |
| | | def main(): |
| | | parser = argparse.ArgumentParser() |
| | | parser.add_argument('--model-name', type=str, required=True) |
| | |
| | | punc_array = punctuations |
| | | else: |
| | | punc_array = torch.cat([punc_array, punctuations], dim=0) |
| | | |
| | | result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array} |
| | | results.append(result_i) |
| | | |
| | |
| | | |
| | | import math |
| | | import torch |
| | | import numpy as np |
| | | import math |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | import logging |
| | | import torch.nn.functional as F |
| | | from funasr.models.scama.utils import sequence_mask |
| | | |
| | | from funasr.models.scama.utils import sequence_mask |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | |
| | | |
| | | class overlap_chunk(): |
| | |
| | | import os |
| | | import torch |
| | | from torch.nn import functional as F |
| | | import yaml |
| | | import torch |
| | | import numpy as np |
| | | from torch.nn import functional as F |
| | | |
| | | |
| | | def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None): |
| | | if maxlen is None: |
| | |
| | | |
| | | dha_output = self.hotword_output_layer(merged) # remove the last token in loss calculation |
| | | dha_pred = torch.log_softmax(dha_output, dim=-1) |
| | | # import pdb; pdb.set_trace() |
| | | def _merge_res(dec_output, dha_output): |
| | | lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0]) |
| | | dha_ids = dha_output.max(-1)[-1][0] |
| | | dha_ids = dha_output.max(-1)[-1]# [0] |
| | | dha_mask = (dha_ids == 8377).int().unsqueeze(-1) |
| | | a = (1 - lmbd) / lmbd |
| | | b = 1 / lmbd |
| | |
| | | logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask) |
| | | return logits |
| | | merged_pred = _merge_res(decoder_pred, dha_pred) |
| | | # import pdb; pdb.set_trace() |
| | | return merged_pred |
| | | else: |
| | | return decoder_pred |
| | |
| | | from abc import ABC |
| | | from abc import abstractmethod |
| | | from typing import Iterable |
| | | from typing import List |
| | | from pathlib import Path |
| | | from typing import Dict |
| | | from typing import Iterable |
| | | from typing import List |
| | | from typing import Union |
| | | import json |
| | | |
| | | import numpy as np |
| | | from abc import ABC |
| | | from pathlib import Path |
| | | from abc import abstractmethod |
| | | from typing import Union, Iterable, List, Dict |
| | | |
| | | |
| | | class AbsTokenizer(ABC): |
| | |
| | | import torch |
| | | import os |
| | | from funasr.train_utils.device_funcs import to_device |
| | | import logging |
| | | import time |
| | | import torch |
| | | import logging |
| | | from tqdm import tqdm |
| | | from contextlib import nullcontext |
| | | import torch.distributed as dist |
| | | from contextlib import nullcontext |
| | | |
| | | from funasr.train_utils.device_funcs import to_device |
| | | from funasr.train_utils.recursive_op import recursive_average |
| | | |
| | | |
| | | class Trainer: |
| | | """ |
| | | A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch, |
| | |
| | | import torch |
| | | from torch.nn.utils.rnn import pad_sequence |
| | | |
| | | |
| | | def slice_padding_fbank(speech, speech_lengths, vad_segments): |
| | | speech_list = [] |
| | | speech_lengths_list = [] |
| | |
| | | feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0) |
| | | speech_lengths_pad = torch.Tensor(speech_lengths_list).int() |
| | | return feats_pad, speech_lengths_pad |
| | | |
| | | |
| | | def slice_padding_audio_samples(speech, speech_lengths, vad_segments): |
| | | speech_list = [] |