From 98c94ab3ab0266482117343a064beeb6bd6bcedc Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 28 Feb 2024 20:45:07 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR merge
---
funasr/models/llm_asr_nar/model.py | 338 ++++++++++++++
.gitignore | 1
funasr/models/llm_asr_nar/__init__.py | 0
funasr/train_utils/trainer.py | 46 +
runtime/html5/static/main.js | 10
setup.py | 2
runtime/docs/docker_online_cpu_zh_lists | 2
funasr/models/seaco_paraformer/model.py | 2
funasr/bin/train.py | 10
funasr/datasets/llm_datasets/samplers.py | 277 +++++++++++
funasr/models/llm_asr_nar/adaptor.py | 29 +
funasr/tokenizer/hf_tokenizer.py | 15
runtime/html5/static/index.html | 6
funasr/auto/auto_frontend.py | 10
funasr/datasets/llm_datasets/__init__.py | 0
runtime/python/websocket/funasr_wss_server.py | 4
runtime/html5/static/wsconnecter.js | 2
runtime/docs/docker_offline_cpu_zh_lists | 2
funasr/auto/auto_model.py | 96 ++-
examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py | 8
funasr/datasets/llm_datasets/preprocessor.py | 37 +
funasr/models/paraformer/cif_predictor.py | 200 +++-----
README_zh.md | 24
funasr/train_utils/load_pretrained_model.py | 123 +---
README.md | 25
funasr/datasets/llm_datasets/datasets.py | 131 +++++
runtime/docs/SDK_advanced_guide_online_zh.md | 2
funasr/metrics/compute_acc.py | 17
runtime/docs/docker_offline_cpu_en_lists | 1
29 files changed, 1123 insertions(+), 297 deletions(-)
diff --git a/.gitignore b/.gitignore
index adf2937..b0d4692 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,3 +25,4 @@
emotion2vec*
GPT-SoVITS*
modelscope_models
+examples/aishell/llm_asr_nar/*
diff --git a/README.md b/README.md
index 454adc9..04a3e68 100644
--- a/README.md
+++ b/README.md
@@ -105,10 +105,8 @@
from funasr import AutoModel
# paraformer-zh is a multi-functional asr model
# use vad, punc, spk or not as you need
-model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
- vad_model="fsmn-vad", vad_model_revision="v2.0.4",
- punc_model="ct-punc-c", punc_model_revision="v2.0.4",
- # spk_model="cam++", spk_model_revision="v2.0.2",
+model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc-c",
+ # spk_model="cam++",
)
res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
batch_size_s=300,
@@ -125,7 +123,7 @@
encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
-model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.4")
+model = AutoModel(model="paraformer-zh-streaming")
import soundfile
import os
@@ -148,17 +146,19 @@
```python
from funasr import AutoModel
-model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
+model = AutoModel(model="fsmn-vad")
wav_file = f"{model.model_path}/example/asr_example.wav"
res = model.generate(input=wav_file)
print(res)
```
+Note: The output format of the VAD model is: `[[beg1, end1], [beg2, end2], ..., [begN, endN]]`, where `begN/endN` indicates the starting/ending point of the `N-th` valid audio segment, measured in milliseconds.
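+
+A minimal sketch (reusing `wav_file` and `res` from the snippet above; the segment values shown in the comment are illustrative) of slicing the detected segments out of the waveform:
+```python
+import soundfile
+
+speech, sample_rate = soundfile.read(wav_file)
+# res[0]["value"] looks like [[70, 2340], [2620, 6200]]
+for beg_ms, end_ms in res[0]["value"]:
+    segment = speech[int(beg_ms / 1000 * sample_rate): int(end_ms / 1000 * sample_rate)]
+    print(f"segment of {len(segment) / sample_rate:.2f} s starting at {beg_ms} ms")
+```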
+
### Voice Activity Detection (Streaming)
```python
from funasr import AutoModel
chunk_size = 200 # ms
-model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
+model = AutoModel(model="fsmn-vad")
import soundfile
@@ -175,11 +175,18 @@
if len(res[0]["value"]):
print(res)
```
+Note: The output format for the streaming VAD model can be one of four scenarios:
+- `[[beg1, end1], [beg2, end2], .., [begN, endN]]`: The same as the offline VAD output result mentioned above.
+- `[[beg, -1]]`: Indicates that only a starting point has been detected.
+- `[[-1, end]]`: Indicates that only an ending point has been detected.
+- `[]`: Indicates that neither a starting point nor an ending point has been detected.
+
+The output is in milliseconds and represents absolute time measured from the beginning of the audio.
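+
+A minimal sketch of handling these cases inside the streaming loop above (reusing `res` from the snippet; the print statements are placeholders, and an empty list simply produces no iterations):
+```python
+for beg, end in res[0]["value"]:
+    if beg != -1 and end != -1:      # complete segment
+        print(f"segment: {beg} ms -> {end} ms")
+    elif end == -1:                  # only the start has been detected so far
+        print(f"speech start at {beg} ms")
+    else:                            # only the end has been detected
+        print(f"speech end at {end} ms")
+```
+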
### Punctuation Restoration
```python
from funasr import AutoModel
-model = AutoModel(model="ct-punc", model_revision="v2.0.4")
+model = AutoModel(model="ct-punc")
res = model.generate(input="那今天的会就到这里吧 happy new year 明年见")
print(res)
```
@@ -187,7 +194,7 @@
```python
from funasr import AutoModel
-model = AutoModel(model="fa-zh", model_revision="v2.0.4")
+model = AutoModel(model="fa-zh")
wav_file = f"{model.model_path}/example/asr_example.wav"
text_file = f"{model.model_path}/example/text.txt"
res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
diff --git a/README_zh.md b/README_zh.md
index 07cdd1f..63ad2e2 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -101,10 +101,8 @@
from funasr import AutoModel
# paraformer-zh is a multi-functional asr model
# use vad, punc, spk or not as you need
-model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
- vad_model="fsmn-vad", vad_model_revision="v2.0.4",
- punc_model="ct-punc-c", punc_model_revision="v2.0.4",
- # spk_model="cam++", spk_model_revision="v2.0.2",
+model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc-c",
+ # spk_model="cam++"
)
res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
batch_size_s=300,
@@ -122,7 +120,7 @@
encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
-model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.4")
+model = AutoModel(model="paraformer-zh-streaming")
import soundfile
import os
@@ -146,19 +144,21 @@
```python
from funasr import AutoModel
-model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
+model = AutoModel(model="fsmn-vad")
wav_file = f"{model.model_path}/example/asr_example.wav"
res = model.generate(input=wav_file)
print(res)
```
+注：VAD模型输出格式为：`[[beg1, end1], [beg2, end2], .., [begN, endN]]`，其中`begN/endN`表示第`N`个有效音频片段的起始点/结束点，
+单位为毫秒。
### 语音端点检测（实时）
```python
from funasr import AutoModel
chunk_size = 200 # ms
-model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
+model = AutoModel(model="fsmn-vad")
import soundfile
@@ -175,12 +175,18 @@
if len(res[0]["value"]):
print(res)
```
+注：流式VAD模型输出格式为4种情况：
+- `[[beg1, end1], [beg2, end2], .., [begN, endN]]`：同上离线VAD输出结果。
+- `[[beg, -1]]`：表示只检测到起始点。
+- `[[-1, end]]`：表示只检测到结束点。
+- `[]`：表示既没有检测到起始点，也没有检测到结束点。
+输出结果单位为毫秒，从起始点开始的绝对时间。
### 标点恢复
```python
from funasr import AutoModel
-model = AutoModel(model="ct-punc", model_revision="v2.0.4")
+model = AutoModel(model="ct-punc")
res = model.generate(input="那今天的会就到这里吧 happy new year 明年见")
print(res)
@@ -190,7 +196,7 @@
```python
from funasr import AutoModel
-model = AutoModel(model="fa-zh", model_revision="v2.0.0")
+model = AutoModel(model="fa-zh")
wav_file = f"{model.model_path}/example/asr_example.wav"
text_file = f"{model.model_path}/example/text.txt"
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
index f043123..c28db7a 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -10,6 +10,8 @@
res = model.generate(input=wav_file)
print(res)
+# [[beg1, end1], [beg2, end2], .., [begN, endN]]
+# beg/end: ms
@@ -37,3 +39,9 @@
# print(res)
if len(res[0]["value"]):
print(res)
+
+
+# 1. [[beg1, end1], [beg2, end2], .., [begN, endN]]; [[beg, end]]; [[beg1, end1], [beg2, end2]]
+# 2. [[beg, -1]]
+# 3. [[-1, end]]
+# beg/end: ms
\ No newline at end of file
diff --git a/funasr/auto/auto_frontend.py b/funasr/auto/auto_frontend.py
index 8f2f069..35ea23f 100644
--- a/funasr/auto/auto_frontend.py
+++ b/funasr/auto/auto_frontend.py
@@ -1,3 +1,8 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import json
import time
import torch
@@ -12,15 +17,14 @@
from funasr.register import tables
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
+from funasr.auto.auto_model import prepare_data_iterator
+from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.download.download_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
-from funasr.auto.auto_model import prepare_data_iterator
-
class AutoFrontend:
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index e5faa2a..a6be691 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -1,3 +1,8 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import json
import time
import copy
@@ -12,12 +17,12 @@
from funasr.register import tables
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
+from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.download.download_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
+from funasr.utils.load_utils import load_audio_text_image_video
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
-from funasr.utils.load_utils import load_audio_text_image_video
-from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
try:
from funasr.models.campplus.cluster_backend import ClusterBackend
@@ -90,7 +95,7 @@
class AutoModel:
def __init__(self, **kwargs):
- if not kwargs.get("disable_log", False):
+ if not kwargs.get("disable_log", True):
tables.print()
model, kwargs = self.build_model(**kwargs)
@@ -157,8 +162,10 @@
tokenizer_class = tables.tokenizer_classes.get(tokenizer)
tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
kwargs["tokenizer"] = tokenizer
- kwargs["token_list"] = tokenizer.token_list
- vocab_size = len(tokenizer.token_list)
+
+ kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
+ kwargs["token_list"] = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
+ vocab_size = len(kwargs["token_list"])
else:
vocab_size = -1
@@ -179,15 +186,18 @@
# init_param
init_param = kwargs.get("init_param", None)
if init_param is not None:
- logging.info(f"Loading pretrained params from {init_param}")
- load_pretrained_model(
- model=model,
- path=init_param,
- ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
- oss_bucket=kwargs.get("oss_bucket", None),
- scope_map=kwargs.get("scope_map", None),
- excludes=kwargs.get("excludes", None),
- )
+ if os.path.exists(init_param):
+ logging.info(f"Loading pretrained params from {init_param}")
+ load_pretrained_model(
+ model=model,
+ path=init_param,
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
+ oss_bucket=kwargs.get("oss_bucket", None),
+ scope_map=kwargs.get("scope_map", []),
+ excludes=kwargs.get("excludes", None),
+ )
+ else:
+                print(f"Error: init_param does not exist: {init_param}")
return model, kwargs
@@ -219,7 +229,7 @@
speed_stats = {}
asr_result_list = []
num_samples = len(data_list)
- disable_pbar = kwargs.get("disable_pbar", False)
+ disable_pbar = self.kwargs.get("disable_pbar", False)
pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True) if not disable_pbar else None
time_speech_total = 0.0
time_escape_total = 0.0
@@ -231,12 +241,12 @@
if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank": # fbank
batch["data_in"] = data_batch[0]
batch["data_lengths"] = input_len
-
+
time1 = time.perf_counter()
with torch.no_grad():
results, meta_data = model.inference(**batch, **kwargs)
time2 = time.perf_counter()
-
+
asr_result_list.extend(results)
# batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
@@ -261,31 +271,29 @@
pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
torch.cuda.empty_cache()
return asr_result_list
-
+
def inference_with_vad(self, input, input_len=None, **cfg):
-
+ kwargs = self.kwargs
# step.1: compute the vad model
self.vad_kwargs.update(cfg)
beg_vad = time.time()
res = self.inference(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
end_vad = time.time()
- print(f"time cost vad: {end_vad - beg_vad:0.3f}")
# step.2 compute asr model
model = self.model
- kwargs = self.kwargs
kwargs.update(cfg)
batch_size = int(kwargs.get("batch_size_s", 300))*1000
batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60))*1000
kwargs["batch_size"] = batch_size
-
+
key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None))
results_ret_list = []
time_speech_total_all_samples = 1e-6
beg_total = time.time()
- pbar_total = tqdm(colour="red", total=len(res), dynamic_ncols=True)
+ pbar_total = tqdm(colour="red", total=len(res), dynamic_ncols=True) if not kwargs.get("disable_pbar", False) else None
for i in range(len(res)):
key = res[i]["key"]
vadsegments = res[i]["value"]
@@ -296,14 +304,14 @@
data_with_index = [(vadsegments[i], i) for i in range(n)]
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
results_sorted = []
-
+
if not len(sorted_data):
logging.info("decoding, utt: {}, empty speech".format(key))
continue
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
-
+
batch_size_ms_cum = 0
beg_idx = 0
beg_asr_total = time.time()
@@ -322,8 +330,8 @@
continue
batch_size_ms_cum = 0
end_idx = j + 1
- speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
- results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, disable_pbar=True, **cfg)
+ speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
+ results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
if self.spk_model is not None:
# compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
for _b in range(len(speech_j)):
@@ -333,26 +341,26 @@
segments = sv_chunk(vad_segments)
all_segments.extend(segments)
speech_b = [i[2] for i in segments]
- spk_res = self.inference(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, disable_pbar=True, **cfg)
+ spk_res = self.inference(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
results[_b]['spk_embedding'] = spk_res[0]['spk_embedding']
beg_idx = end_idx
if len(results) < 1:
continue
results_sorted.extend(results)
-
+
# end_asr_total = time.time()
# time_escape_total_per_sample = end_asr_total - beg_asr_total
# pbar_sample.update(1)
# pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
# f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
# f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
-
+
restored_data = [0] * n
for j in range(n):
index = sorted_data[j][1]
restored_data[index] = results_sorted[j]
result = {}
-
+
# results combine for texts, timestamps, speaker embeddings and others
# TODO: rewrite for clean code
for j in range(n):
@@ -379,18 +387,21 @@
result[k] = restored_data[j][k]
else:
result[k] += restored_data[j][k]
-
- return_raw_text = kwargs.get('return_raw_text', False)
+
+ return_raw_text = kwargs.get('return_raw_text', False)
# step.3 compute punc model
if self.punc_model is not None:
- self.punc_kwargs.update(cfg)
- punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, disable_pbar=True, **cfg)
- raw_text = copy.copy(result["text"])
- if return_raw_text: result['raw_text'] = raw_text
- result["text"] = punc_res[0]["text"]
+ if not len(result["text"]):
+ result['raw_text'] = ''
+ else:
+ self.punc_kwargs.update(cfg)
+ punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
+ raw_text = copy.copy(result["text"])
+ if return_raw_text: result['raw_text'] = raw_text
+ result["text"] = punc_res[0]["text"]
else:
raw_text = None
-
+
# speaker embedding cluster after resorted
if self.spk_model is not None and kwargs.get('return_spk_res', True):
if raw_text is None:
@@ -429,13 +440,14 @@
return_raw_text=return_raw_text)
result['sentence_info'] = sentence_list
if "spk_embedding" in result: del result['spk_embedding']
-
+
result["key"] = key
results_ret_list.append(result)
end_asr_total = time.time()
time_escape_total_per_sample = end_asr_total - beg_asr_total
- pbar_total.update(1)
- pbar_total.set_description(f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+ if pbar_total:
+ pbar_total.update(1)
+ pbar_total.set_description(f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
f"time_speech: {time_speech_total_per_sample: 0.3f}, "
f"time_escape: {time_escape_total_per_sample:0.3f}")
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 4538224..569757a 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -85,7 +85,9 @@
# build model
model_class = tables.model_classes.get(kwargs["model"])
- model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+ vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
+ vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
+ model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
@@ -103,13 +105,15 @@
path=p,
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
oss_bucket=kwargs.get("oss_bucket", None),
- scope_map=kwargs.get("scope_map", None),
+ scope_map=kwargs.get("scope_map", []),
excludes=kwargs.get("excludes", None),
)
else:
logging.info(f"Checkpoint does not exist, init randomly: {p}")
- else:
+ elif kwargs.get("init", None):
initialize(model, kwargs.get("init", "kaiming_normal"))
+ else:
+            print("No initialization method specified")
# freeze_param
diff --git a/funasr/datasets/llm_datasets/__init__.py b/funasr/datasets/llm_datasets/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/llm_datasets/__init__.py
diff --git a/funasr/datasets/llm_datasets/datasets.py b/funasr/datasets/llm_datasets/datasets.py
new file mode 100644
index 0000000..9673d76
--- /dev/null
+++ b/funasr/datasets/llm_datasets/datasets.py
@@ -0,0 +1,131 @@
+import torch
+import copy
+
+from funasr.register import tables
+from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
+
+
+@tables.register("dataset_classes", "AudioLLMDataset")
+class AudioLLMDataset(torch.utils.data.Dataset):
+ """
+ AudioLLMDataset
+ """
+ def __init__(self,
+ path,
+ index_ds: str = None,
+ frontend=None,
+ tokenizer=None,
+ int_pad_value: int = -1,
+ float_pad_value: float = 0.0,
+ **kwargs):
+ super().__init__()
+ index_ds_class = tables.index_ds_classes.get(index_ds)
+ self.index_ds = index_ds_class(path, **kwargs)
+ preprocessor_speech = kwargs.get("preprocessor_speech", None)
+ if preprocessor_speech:
+ preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
+ preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
+ self.preprocessor_speech = preprocessor_speech
+ preprocessor_text = kwargs.get("preprocessor_text", None)
+ if preprocessor_text:
+ preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
+ preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
+ self.preprocessor_text = preprocessor_text
+
+ self.frontend = frontend
+ self.fs = 16000 if frontend is None else frontend.fs
+ self.data_type = "sound"
+ self.tokenizer = tokenizer
+
+ self.float_pad_value = float_pad_value
+ self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
+ self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
+            self.prompt) # "USER: \nINSTRUCTION: {}\nINPUT: {}\nASSISTANT: "
+ self.prompt_af = ""
+ self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
+ self.int_pad_value = self.IGNORE_INDEX
+
+ def get_source_len(self, index):
+ item = self.index_ds[index]
+ return self.index_ds.get_source_len(item)
+
+ def get_target_len(self, index):
+ item = self.index_ds[index]
+ return self.index_ds.get_target_len(item)
+
+ def __len__(self):
+ return len(self.index_ds)
+
+ def __getitem__(self, index):
+ item = self.index_ds[index]
+ # import pdb;
+ # pdb.set_trace()
+ source = item["source"]
+ data_src = load_audio_text_image_video(source, fs=self.fs)
+ if self.preprocessor_speech:
+ data_src = self.preprocessor_speech(data_src, fs=self.fs)
+ speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
+ speech = speech.squeeze(0)
+
+ target = item["target"]
+ if self.preprocessor_text:
+ target = self.preprocessor_text(target)
+
+
+ prompt_ids_pre = self.tokenizer.encode(self.prompt_pre) # [bos,prompt]
+ prompt_pre_length = len(prompt_ids_pre)
+
+ prompt_input = "{}{}".format(self.prompt_pre, target)
+ prompt_input_ids = self.tokenizer.encode(prompt_input)
+ audio_length = len(prompt_input_ids) - prompt_pre_length
+ input_ids = prompt_input_ids + [self.tokenizer.pad_token_id]
+ input_ids = torch.tensor(input_ids, dtype=torch.int64) #[bos, prompt, input, pad]
+ input_ids[prompt_pre_length:] = -1 # [bos, prompt,-1,-1]
+ attention_mask = input_ids.ge(-1) # [true, true, true, true], length mask
+
+ prompt_answer = "{}{}".format(self.prompt_pre, target)
+ prompt_answer_ids = self.tokenizer.encode(prompt_answer)
+ answer_length = len(prompt_answer_ids) - prompt_pre_length
+ labels_ids = copy.deepcopy(prompt_input_ids) + [self.tokenizer.eos_token_id]
+ labels_ids = torch.tensor(labels_ids, dtype=torch.int64) # [bos, prompt, input, eos]
+ labels_ids[:prompt_pre_length] = -1 # [-1, -1, input, eos]
+ label_mask = labels_ids.ge(0) # [False,False,True,True]
+ labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,-100,input,eos]
+
+ audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
+ audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
+
+        ids = self.tokenizer.encode(target) # token ids differ from labels_ids
+ text = torch.tensor(ids, dtype=torch.int64)
+ text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
+
+ return {"speech": speech,
+ "speech_lengths": speech_lengths,
+ "text": text,
+ "text_lengths": text_lengths,
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "labels_ids": labels_ids,
+ "label_mask": label_mask,
+ "audio_mask": audio_mask,
+ }
+
+
+ def collator(self, samples: list=None):
+ outputs = {}
+ for sample in samples:
+ for key in sample.keys():
+ if key not in outputs:
+ outputs[key] = []
+ outputs[key].append(sample[key])
+
+ for key, data_list in outputs.items():
+ if isinstance(data_list[0], torch.Tensor):
+ if data_list[0].dtype == torch.int64:
+
+ pad_value = self.int_pad_value
+ else:
+ pad_value = self.float_pad_value
+
+ outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+ return outputs
diff --git a/funasr/datasets/llm_datasets/preprocessor.py b/funasr/datasets/llm_datasets/preprocessor.py
new file mode 100644
index 0000000..9f20672
--- /dev/null
+++ b/funasr/datasets/llm_datasets/preprocessor.py
@@ -0,0 +1,37 @@
+import os
+import json
+import torch
+import logging
+import concurrent.futures
+import librosa
+import torch.distributed as dist
+from typing import Collection
+import torch
+import torchaudio
+from torch import nn
+import random
+import re
+import string
+from funasr.tokenizer.cleaner import TextCleaner
+from funasr.register import tables
+
+
+
+@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
+class TextPreprocessSegDict(nn.Module):
+ def __init__(self,
+ **kwargs):
+ super().__init__()
+
+
+ def forward(self, text, **kwargs):
+        # English punctuation characters
+        en_punct = string.punctuation
+        # common Chinese punctuation characters
+        cn_punct = '。？！，、；：“”‘’（）《》【】…—～·'
+        # merge the English and Chinese punctuation sets
+        all_punct = en_punct + cn_punct
+        # regex pattern matching any character in all_punct
+        punct_pattern = re.compile('[{}]'.format(re.escape(all_punct)))
+        # strip these characters from the text
+        return punct_pattern.sub('', text)
diff --git a/funasr/datasets/llm_datasets/samplers.py b/funasr/datasets/llm_datasets/samplers.py
new file mode 100644
index 0000000..914e776
--- /dev/null
+++ b/funasr/datasets/llm_datasets/samplers.py
@@ -0,0 +1,277 @@
+import torch
+import numpy as np
+import logging
+import torch.distributed as dist
+
+from funasr.register import tables
+
+
+@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
+class BatchSampler(torch.utils.data.BatchSampler):
+
+ def __init__(self, dataset,
+ batch_type: str = "example",
+ batch_size: int = 100,
+ buffer_size: int = 30,
+ drop_last: bool = False,
+ shuffle: bool = True,
+ is_training: bool = True,
+ **kwargs):
+
+ self.drop_last = drop_last
+ self.pre_idx = -1
+ self.dataset = dataset
+ self.total_samples = len(dataset)
+ self.batch_type = batch_type
+ self.batch_size = int(batch_size)
+ self.buffer_size = buffer_size
+ self.max_token_length = kwargs.get("max_token_length", 5000)
+ self.shuffle_idx = np.arange(self.total_samples)
+ self.shuffle = shuffle and is_training
+ self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+
+
+ def __len__(self):
+ return (self.total_samples-1) // self.batch_size + 1
+
+ def set_epoch(self, epoch):
+ np.random.seed(epoch)
+
+ def __iter__(self):
+
+ if self.shuffle:
+ np.random.shuffle(self.shuffle_idx)
+
+ batch = []
+ max_token = 0
+ num_sample = 0
+
+ iter_num = (self.total_samples - 1) // self.buffer_size + 1
+ # print("iter_num: ", iter_num)
+ for iter in range(self.pre_idx + 1, iter_num):
+ datalen_with_index = []
+ for i in range(self.buffer_size):
+ idx = iter * self.buffer_size + i
+ if idx >= self.total_samples:
+ continue
+
+ idx_map = self.shuffle_idx[idx]
+ # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
+ target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
+ source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
+ sample_len_cur = source_len + target_len
+
+
+ datalen_with_index.append([idx, sample_len_cur])
+
+ datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
+ for item in datalen_with_index_sort:
+ idx, sample_len_cur_raw = item
+ if sample_len_cur_raw > self.max_token_length:
+ continue
+
+ max_token_cur = max(max_token, sample_len_cur_raw)
+ max_token_padding = 1 + num_sample
+ if self.batch_type != 'example':
+ max_token_padding *= max_token_cur
+ if max_token_padding <= self.batch_size:
+ batch.append(idx)
+ max_token = max_token_cur
+ num_sample += 1
+ else:
+ yield batch
+ batch = [idx]
+ max_token = sample_len_cur_raw
+ num_sample = 1
+
+
+@tables.register("batch_sampler_classes", "BatchSampler")
+@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
+class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):
+
+ def __init__(self, dataset,
+ batch_type: str = "example",
+ batch_size: int = 100,
+ buffer_size: int = 30,
+ drop_last: bool = True,
+ shuffle: bool = True,
+ is_training: bool = True,
+ **kwargs):
+
+ self.drop_last = drop_last
+ self.pre_idx = -1
+ self.dataset = dataset
+ self.total_samples = len(dataset)
+ self.batch_type = batch_type
+ self.batch_size = int(batch_size)
+ self.buffer_size = buffer_size
+ self.max_token_length = kwargs.get("max_token_length", 1500)
+ self.shuffle_idx = np.arange(self.total_samples)
+ self.shuffle = shuffle and is_training
+ self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+
+ try:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ except:
+ rank = 0
+ world_size = 1
+ self.rank = rank
+ self.world_size = world_size
+
+ def __len__(self):
+ return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
+
+ def set_epoch(self, epoch):
+ np.random.seed(epoch)
+
+ def __iter__(self):
+
+ batch_size_total = self.batch_size * self.world_size
+
+ if self.shuffle:
+ np.random.shuffle(self.shuffle_idx)
+
+ batch = []
+ max_token = 0
+ num_sample = 0
+
+ iter_num = (self.total_samples - 1) // self.buffer_size + 1
+ # print("iter_num: ", iter_num)
+ for iter in range(self.pre_idx + 1, iter_num):
+ # if iter == iter_num -1 and self.drop_last:
+ # continue
+ datalen_with_index = []
+ for i in range(self.buffer_size):
+ idx = iter * self.buffer_size + i
+ if idx >= self.total_samples:
+ continue
+
+ idx_map = self.shuffle_idx[idx]
+ # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
+
+ source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
+ target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
+ sample_len_cur = source_len + target_len
+
+ datalen_with_index.append([idx, sample_len_cur])
+
+ datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
+ for item in datalen_with_index_sort:
+ idx, sample_len_cur_raw = item
+ if sample_len_cur_raw > self.max_token_length:
+ continue
+
+ max_token_cur = max(max_token, sample_len_cur_raw)
+ max_token_padding = 1 + num_sample
+ # if self.batch_type != 'example':
+ # max_token_padding *= max_token_cur
+ if max_token_padding <= batch_size_total:
+ batch.append(idx)
+ max_token = max_token_cur
+ num_sample += 1
+ else:
+ batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size]
+ yield batch_rank
+ batch = [idx]
+ max_token = sample_len_cur_raw
+ num_sample = 1
+
+
+@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
+class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):
+
+ def __init__(self, dataset,
+ batch_type: str = "example",
+ batch_size: int = 100,
+ buffer_size: int = 30,
+ drop_last: bool = True,
+ shuffle: bool = True,
+ is_training: bool = True,
+ **kwargs):
+
+ self.drop_last = drop_last
+ self.pre_idx = -1
+ self.dataset = dataset
+ self.total_samples = len(dataset)
+ self.batch_type = batch_type
+ self.batch_size = int(batch_size)
+ self.buffer_size = buffer_size
+ self.max_token_length = kwargs.get("max_token_length", 1500)
+ self.shuffle_idx = np.arange(self.total_samples)
+ self.shuffle = shuffle and is_training
+ self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+
+ try:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ except:
+ rank = 0
+ world_size = 1
+ self.rank = rank
+ self.world_size = world_size
+
+ def __len__(self):
+ return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
+
+ def set_epoch(self, epoch):
+ np.random.seed(epoch)
+
+ def __iter__(self):
+
+ batch_size_total = self.batch_size * self.world_size
+ if self.shuffle:
+ np.random.shuffle(self.shuffle_idx)
+
+ batch_list_all_rank = []
+ batch_list_cur = []
+ max_token = 0
+ num_sample = 0
+
+ iter_num = (self.total_samples - 1) // self.buffer_size + 1
+ # print("iter_num: ", iter_num)
+ for iter in range(self.pre_idx + 1, iter_num):
+ # if iter == iter_num - 1 and self.drop_last:
+ # continue
+ datalen_with_index = []
+ for i in range(self.buffer_size):
+ idx = iter * self.buffer_size + i
+ if idx >= self.total_samples:
+ continue
+
+ idx_map = self.shuffle_idx[idx]
+ # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
+
+ source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
+ target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
+ sample_len_cur = source_len + target_len
+
+ datalen_with_index.append([idx, sample_len_cur])
+
+ datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
+ for ii, item in enumerate(datalen_with_index_sort):
+                is_last_batch = iter == iter_num - 1 and ii == len(datalen_with_index_sort) - 1
+ idx, sample_len_cur_raw = item
+ if sample_len_cur_raw > self.max_token_length:
+ continue
+
+ max_token_cur = max(max_token, sample_len_cur_raw)
+ max_token_padding = 1 + num_sample
+
+ if self.batch_type != 'example':
+ max_token_padding *= max_token_cur
+ if len(batch_list_all_rank) < self.world_size:
+
+ if max_token_padding <= self.batch_size:
+ batch_list_cur.append(idx)
+ max_token = max_token_cur
+ num_sample += 1
+                    else:
+                        batch_list_all_rank.append(batch_list_cur)
+                        # start a new per-rank batch with the current sample
+                        batch_list_cur = [idx]
+                        max_token = sample_len_cur_raw
+                        num_sample = 1
+                else:
+                    batch_rank = batch_list_all_rank[self.rank]
+                    yield batch_rank
+                    # all ranks filled: reset and start the next round with the current sample
+                    batch_list_all_rank = []
+                    batch_list_cur = [idx]
+                    max_token = sample_len_cur_raw
+                    num_sample = 1
diff --git a/funasr/metrics/compute_acc.py b/funasr/metrics/compute_acc.py
index 9d16e1f..ec8067f 100644
--- a/funasr/metrics/compute_acc.py
+++ b/funasr/metrics/compute_acc.py
@@ -21,3 +21,20 @@
)
denominator = torch.sum(mask)
return float(numerator) / float(denominator)
+
+def compute_accuracy(pad_outputs, pad_targets, ignore_label):
+ """Calculate accuracy.
+
+ Args:
+ pad_outputs (LongTensor): Prediction tensors (B, Lmax).
+ pad_targets (LongTensor): Target label tensors (B, Lmax).
+ ignore_label (int): Ignore label id.
+
+ Returns:
+ float: Accuracy value (0.0 - 1.0).
+
+ """
+ mask = pad_targets != ignore_label
+ numerator = torch.sum(pad_outputs.masked_select(mask) == pad_targets.masked_select(mask))
+ denominator = torch.sum(mask)
+ return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type
\ No newline at end of file
diff --git a/funasr/models/llm_asr_nar/__init__.py b/funasr/models/llm_asr_nar/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/llm_asr_nar/__init__.py
diff --git a/funasr/models/llm_asr_nar/adaptor.py b/funasr/models/llm_asr_nar/adaptor.py
new file mode 100644
index 0000000..0676e7d
--- /dev/null
+++ b/funasr/models/llm_asr_nar/adaptor.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+
+from funasr.register import tables
+
+@tables.register("adaptor_classes", "Linear")
+class Linear(nn.Module):
+ def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
+ super().__init__()
+ self.k = downsample_rate
+ self.encoder_dim = encoder_dim
+ self.llm_dim = llm_dim
+ self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
+ self.relu = nn.ReLU()
+ self.linear2 = nn.Linear(ffn_dim, self.llm_dim)
+
+ def forward(self, x):
+ batch_size, seq_len, dim = x.size()
+ num_frames_to_discard = seq_len % self.k
+ if num_frames_to_discard > 0:
+ x = x[:, :-num_frames_to_discard, :]
+ seq_len = x.size(1)
+
+ x = x.contiguous()
+ x = x.view(batch_size, seq_len // self.k, dim * self.k)
+ x = self.linear1(x)
+ x = self.relu(x)
+ x = self.linear2(x)
+ return x
diff --git a/funasr/models/llm_asr_nar/model.py b/funasr/models/llm_asr_nar/model.py
new file mode 100644
index 0000000..6a4ecce
--- /dev/null
+++ b/funasr/models/llm_asr_nar/model.py
@@ -0,0 +1,338 @@
+import logging
+from typing import Union, Dict, List, Tuple, Optional
+
+import time
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.cuda.amp import autocast
+
+from funasr.models.scama.utils import sequence_mask
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.ctc.ctc import CTC
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy, compute_accuracy
+# from funasr.models.e2e_asr_common import ErrorCalculator
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.register import tables
+
+
+@tables.register("model_classes", "LLMASRNAR")
+class LLMASRNAR(nn.Module):
+ """ """
+
+ def __init__(
+ self,
+ specaug: str = None,
+ specaug_conf: dict = None,
+ normalize: str = None,
+ normalize_conf: dict = None,
+ encoder: str = None,
+ encoder_conf: dict = None,
+ decoder: str = None,
+ decoder_conf: dict = None,
+ ctc: str = None,
+ ctc_conf: dict = None,
+ ctc_weight: float = 0.5,
+ llm: str = None,
+ llm_conf: dict = None,
+ adaptor: str = None,
+ adaptor_conf: dict = None,
+ input_size: int = 80,
+ vocab_size: int = -1,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ # extract_feats_in_collect_stats: bool = True,
+ share_embedding: bool = False,
+ # preencoder: Optional[AbsPreEncoder] = None,
+ # postencoder: Optional[AbsPostEncoder] = None,
+ **kwargs,
+ ):
+
+ super().__init__()
+
+ if specaug is not None:
+ specaug_class = tables.specaug_classes.get(specaug)
+ specaug = specaug_class(**specaug_conf)
+ if normalize is not None:
+ normalize_class = tables.normalize_classes.get(normalize)
+ normalize = normalize_class(**normalize_conf)
+
+ # audio encoder
+ hub = encoder_conf.get("hub", None)
+ if hub == "funasr":
+ from funasr import AutoModel
+ init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+ model = AutoModel(model=init_param_path, model_revision="v2.0.4")
+ # frontend = model.kwargs.get("frontend")
+ model.model.decoder = None
+
+ self.audio_encoder = model.model
+ # self.frontend = frontend
+
+ elif hub == "hf":
+ pass
+ else:
+ encoder_class = tables.encoder_classes.get(encoder)
+ encoder = encoder_class(input_size=input_size, **encoder_conf)
+ encoder_output_size = encoder.output_size()
+
+ # llm
+ hub = llm_conf.get("hub", "hf")
+ self.llm = None
+ if hub == "hf":
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+
+ init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+ model = AutoModelForCausalLM.from_pretrained(
+ init_param_path,
+ load_in_8bit=None,
+ device_map=None,
+ use_cache=None,
+ )
+ freeze = llm_conf.get("freeze", True)
+ if freeze:
+ for name, param in model.named_parameters():
+ param.requires_grad = False
+ model.eval()
+ self.llm = model
+
+ # adaptor
+ adaptor_class = tables.adaptor_classes.get(adaptor)
+ adaptor = adaptor_class(**adaptor_conf)
+
+ self.adaptor = adaptor
+
+
+ self.blank_id = blank_id
+ self.sos = sos if sos is not None else vocab_size - 1
+ self.eos = eos if eos is not None else vocab_size - 1
+ self.vocab_size = vocab_size
+ self.ignore_id = ignore_id
+ self.specaug = specaug
+ self.normalize = normalize
+ self.encoder = encoder
+
+
+ self.criterion_att = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=ignore_id,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+ #
+ # if report_cer or report_wer:
+ # self.error_calculator = ErrorCalculator(
+ # token_list, sym_space, sym_blank, report_cer, report_wer
+ # )
+ #
+ self.error_calculator = None
+
+ self.length_normalized_loss = length_normalized_loss
+ self.beam_search = None
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ input_ids: torch.Tensor,
+ attention_mask:torch.Tensor,
+ labels_ids: torch.Tensor,
+ label_mask: torch.Tensor,
+ audio_mask: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ # import pdb;
+ # pdb.set_trace()
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
+
+ # audio encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)
+
+ # adaptor
+ encoder_out = self.adaptor(encoder_out)
+
+ if input_ids is not None:
+ input_ids[input_ids == -1] = 0
+ input_ids[input_ids == -100] = 0
+ if hasattr(self.llm.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.embed_tokens(input_ids)
+ elif hasattr(self.llm.model.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
+ else:
+ inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
+
+ if audio_mask is not None:
+ batch_size, token_num, dims = inputs_embeds.shape
+ _, l, _ = encoder_out.shape
+ encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
+ inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
+ inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
+
+ model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
+ loss = model_outputs.loss
+
+
+ stats = {}
+ with torch.no_grad():
+ preds = torch.argmax(model_outputs.logits, -1)
+ acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+ stats["acc"] = acc_att
+
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ if self.length_normalized_loss:
+ batch_size = int((text_lengths + 1).sum())
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+ audio_mask = kwargs.get("audio_mask", None)
+ audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
+
+ batch = {"speech": speech, "speech_lengths": speech_lengths}
+ enc, enc_lens = self.audio_encoder.encode(**batch)
+ with autocast(False):
+ enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
+ pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
+ mask=enc_mask,
+ target_label_length=audio_token_lengths,
+ )
+
+ return pre_acoustic_embeds, pre_token_length
+
+
+ def inference(self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+
+ prompt = kwargs.get("prompt", "Transcribe speech to text.")
+
+ if kwargs.get("batch_size", 1) > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+
+
+
+ meta_data = {}
+ if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
+ speech, speech_lengths = data_in, data_lengths
+ if len(speech.shape) < 3:
+ speech = speech[None, :, :]
+ if speech_lengths is None:
+ speech_lengths = speech.shape[1]
+ else:
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer)
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ # adaptor
+ encoder_out = self.adaptor(encoder_out)
+
+
+ prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
+ prompt_ids = tokenizer.encode(prompt_pre)
+ prompt_length = len(prompt_ids)
+ prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
+
+
+ if hasattr(self.llm.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+ elif hasattr(self.llm.model.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
+ else:
+ inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+
+ inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1) # [prompt, audio]
+ attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
+
+ # model_outputs = self.llm.generate(
+ # inputs_embeds=inputs_embeds,
+ # max_length=kwargs.get("max_length", 200),
+ # max_new_tokens=kwargs.get("max_new_tokens", 200),
+ # num_beams=kwargs.get("num_beams", 4),
+ # do_sample=kwargs.get("do_sample", False),
+ # min_length=kwargs.get("min_length", 1),
+ # top_p=kwargs.get("top_p", 1.0),
+ # repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+ # length_penalty=kwargs.get("length_penalty", 1.0),
+ # temperature=kwargs.get("temperature", 1.0),
+ # attention_mask=attention_mask,
+ # bos_token_id=tokenizer.bos_token_id,
+ # eos_token_id=tokenizer.eos_token_id,
+ # pad_token_id=tokenizer.pad_token_id
+ # )
+
+
+ model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
+ preds = torch.argmax(model_outputs.logits, -1)
+ text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
+ text = text[0].split(': \n')[-1]
+ # preds = torch.argmax(model_outputs.logits, -1)
+
+ ibest_writer = None
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = self.writer[f"{0 + 1}best_recog"]
+
+ results = []
+ result_i = {"key": key[0], "text": text}
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ ibest_writer["text"][key[0]] = text
+
+
+
+
+ return results, meta_data
+
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 60ddc24..4d9f5d8 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -10,7 +10,7 @@
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
-
+from torch.cuda.amp import autocast
@tables.register("predictor_classes", "CifPredictor")
class CifPredictor(torch.nn.Module):
@@ -28,42 +28,44 @@
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
target_label_length=None):
- h = hidden
- context = h.transpose(1, 2)
- queries = self.pad(context)
- memory = self.cif_conv1d(queries)
- output = memory + context
- output = self.dropout(output)
- output = output.transpose(1, 2)
- output = torch.relu(output)
- output = self.cif_output(output)
- alphas = torch.sigmoid(output)
- alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
- if mask is not None:
- mask = mask.transpose(-1, -2).float()
- alphas = alphas * mask
- if mask_chunk_predictor is not None:
- alphas = alphas * mask_chunk_predictor
- alphas = alphas.squeeze(-1)
- mask = mask.squeeze(-1)
- if target_label_length is not None:
- target_length = target_label_length
- elif target_label is not None:
- target_length = (target_label != ignore_id).float().sum(-1)
- else:
- target_length = None
- token_num = alphas.sum(-1)
- if target_length is not None:
- alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
- elif self.tail_threshold > 0.0:
- hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+
+ with autocast(False):
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ memory = self.cif_conv1d(queries)
+ output = memory + context
+ output = self.dropout(output)
+ output = output.transpose(1, 2)
+ output = torch.relu(output)
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+ if mask is not None:
+ mask = mask.transpose(-1, -2).float()
+ alphas = alphas * mask
+ if mask_chunk_predictor is not None:
+ alphas = alphas * mask_chunk_predictor
+ alphas = alphas.squeeze(-1)
+ mask = mask.squeeze(-1)
+ if target_label_length is not None:
+ target_length = target_label_length
+ elif target_label is not None:
+ target_length = (target_label != ignore_id).float().sum(-1)
+ else:
+ target_length = None
+ token_num = alphas.sum(-1)
+ if target_length is not None:
+ alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+ elif self.tail_threshold > 0.0:
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-
- if target_length is None and self.tail_threshold > 0.0:
- token_num_int = torch.max(token_num).type(torch.int32).item()
- acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
-
+ if target_length is None and self.tail_threshold > 0.0:
+ token_num_int = torch.max(token_num).type(torch.int32).item()
+ acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+
return acoustic_embeds, token_num, alphas, cif_peak
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
@@ -169,41 +171,43 @@
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
target_label_length=None):
- h = hidden
- context = h.transpose(1, 2)
- queries = self.pad(context)
- output = torch.relu(self.cif_conv1d(queries))
- output = output.transpose(1, 2)
-
- output = self.cif_output(output)
- alphas = torch.sigmoid(output)
- alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
- if mask is not None:
- mask = mask.transpose(-1, -2).float()
- alphas = alphas * mask
- if mask_chunk_predictor is not None:
- alphas = alphas * mask_chunk_predictor
- alphas = alphas.squeeze(-1)
- mask = mask.squeeze(-1)
- if target_label_length is not None:
- target_length = target_label_length.squeeze(-1)
- elif target_label is not None:
- target_length = (target_label != ignore_id).float().sum(-1)
- else:
- target_length = None
- token_num = alphas.sum(-1)
- if target_length is not None:
- alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
- elif self.tail_threshold > 0.0:
- if self.tail_mask:
- hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+
+ with autocast(False):
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ output = torch.relu(self.cif_conv1d(queries))
+ output = output.transpose(1, 2)
+
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+ if mask is not None:
+ mask = mask.transpose(-1, -2).float()
+ alphas = alphas * mask
+ if mask_chunk_predictor is not None:
+ alphas = alphas * mask_chunk_predictor
+ alphas = alphas.squeeze(-1)
+ mask = mask.squeeze(-1)
+ if target_label_length is not None:
+ target_length = target_label_length.squeeze(-1)
+ elif target_label is not None:
+ target_length = (target_label != ignore_id).float().sum(-1)
else:
- hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
-
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
- if target_length is None and self.tail_threshold > 0.0:
- token_num_int = torch.max(token_num).type(torch.int32).item()
- acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+ target_length = None
+ token_num = alphas.sum(-1)
+ if target_length is not None:
+ alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+ elif self.tail_threshold > 0.0:
+ if self.tail_mask:
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+ else:
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
+
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+ if target_length is None and self.tail_threshold > 0.0:
+ token_num_int = torch.max(token_num).type(torch.int32).item()
+ acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
return acoustic_embeds, token_num, alphas, cif_peak
@@ -370,62 +374,6 @@
predictor_alignments = index_div_bool_zeros_count_tile_out
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
return predictor_alignments.detach(), predictor_alignments_length.detach()
-
- def gen_tf2torch_map_dict(self):
-
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
- ## predictor
- "{}.cif_conv1d.weight".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (2, 1, 0),
- }, # (256,256,3),(3,256,256)
- "{}.cif_conv1d.bias".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.cif_output.weight".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1,256),(1,256,1)
- "{}.cif_output.bias".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1,),(1,)
- }
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
- map_dict = self.gen_tf2torch_map_dict()
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- return var_dict_torch_update
class mae_loss(torch.nn.Module):
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index cfdd26a..21ad874 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -25,8 +25,8 @@
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.bicif_paraformer.model import BiCifParaformer
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
diff --git a/funasr/tokenizer/hf_tokenizer.py b/funasr/tokenizer/hf_tokenizer.py
new file mode 100644
index 0000000..c856b3d
--- /dev/null
+++ b/funasr/tokenizer/hf_tokenizer.py
@@ -0,0 +1,15 @@
+
+try:
+ from transformers import AutoTokenizer
+except ImportError:
+    print("If you want to use the Hugging Face tokenizer, please run `pip install -U transformers`")
+
+from funasr.register import tables
+
+@tables.register("tokenizer_classes", "HuggingfaceTokenizer")
+def HuggingfaceTokenizer(init_param_path, **kwargs):
+
+ tokenizer = AutoTokenizer.from_pretrained(init_param_path)
+
+ return tokenizer
+
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index 8493bf5..84c6320 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -38,52 +38,17 @@
)
return match_state
-def assigment_scope_map(dst_state: dict, src_state: dict, scope_map: str=None):
- """Compute the union of the current variables and checkpoint variables."""
- import collections
- import re
-
- # current model variables
- name_to_variable = collections.OrderedDict()
- for name, var in dst_state.items():
- name_to_variable[name] = var
-
- scope_map_num = 0
- if scope_map is not None:
- scope_map = scope_map.split(",")
- scope_map_num = len(scope_map) // 2
- for scope_map_idx in range(scope_map_num):
- scope_map_id = scope_map_idx * 2
- logging.info('assignment_map from scope {} to {}'.format(scope_map[scope_map_id], scope_map[scope_map_id+1]))
-
- assignment_map = {}
- for name, var in src_state.items():
-
- if scope_map:
- for scope_map_idx in range(scope_map_num):
- scope_map_id = scope_map_idx * 2
- try:
- idx = name.index(scope_map[scope_map_id])
- new_name = scope_map[scope_map_id+1] + name[idx + len(scope_map[scope_map_id]):]
- if new_name in name_to_variable:
- assignment_map[name] = var
- except:
- continue
- else:
- if name in name_to_variable:
- assignment_map[name] = var
-
- return assignment_map
-
def load_pretrained_model(
path: str,
model: torch.nn.Module,
- ignore_init_mismatch: bool,
+ ignore_init_mismatch: bool=True,
map_location: str = "cpu",
oss_bucket=None,
- scope_map=None,
+ scope_map=[],
excludes=None,
+ ignore_mismatch=False,
+ **kwargs,
):
"""Load a model state and set it to the model.
@@ -108,57 +73,39 @@
src_state = src_state["model"] if "model" in src_state else src_state
+ if isinstance(scope_map, str):
+ scope_map = scope_map.split(",")
+    scope_map = list(scope_map or []) + ["module.", "None"]
+
for k in dst_state.keys():
- if not k.startswith("module.") and "module." + k in src_state.keys():
- k_ddp = "module." + k
+
+ k_src = k
+
+ if scope_map is not None:
+ src_prefix = ""
+ dst_prefix = ""
+ for i in range(0, len(scope_map), 2):
+ src_prefix = scope_map[i] if scope_map[i].lower() != "none" else ""
+ dst_prefix = scope_map[i+1] if scope_map[i+1].lower() != "none" else ""
+
+ if dst_prefix == "" and (src_prefix + k) in src_state.keys():
+ k_src = src_prefix + k
+ if not k_src.startswith("module."):
+ print(f"init param, map: {k} from {k_src} in ckpt")
+ elif k.startswith(dst_prefix) and k.replace(dst_prefix, src_prefix, 1) in src_state.keys():
+ k_src = k.replace(dst_prefix, src_prefix, 1)
+ if not k_src.startswith("module."):
+ print(f"init param, map: {k} from {k_src} in ckpt")
+
+ if k_src in src_state.keys():
+ if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape:
+                print(f"ignore_init_mismatch:{ignore_init_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}")
+ else:
+ dst_state[k] = src_state[k_src]
+
+
else:
- k_ddp = k
- if k_ddp in src_state:
- dst_state[k] = src_state[k_ddp]
- else:
- print(f"Warning, miss key in ckpt: {k}, mapped: {k_ddp}")
+            print(f"Warning: missing key in ckpt: {k}, mapped: {k_src}")
flag = obj.load_state_dict(dst_state, strict=True)
# print(flag)
-
-# def load_pretrained_model(
-# path: str,
-# model: torch.nn.Module,
-# ignore_init_mismatch: bool,
-# map_location: str = "cpu",
-# oss_bucket=None,
-# scope_map=None,
-# excludes=None,
-# ):
-# """Load a model state and set it to the model.
-#
-# Args:
-# init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys>
-#
-# Examples:
-#
-# """
-#
-# obj = model
-#
-# if oss_bucket is None:
-# src_state = torch.load(path, map_location=map_location)
-# else:
-# buffer = BytesIO(oss_bucket.get_object(path).read())
-# src_state = torch.load(buffer, map_location=map_location)
-# src_state = src_state["model"] if "model" in src_state else src_state
-#
-# if excludes is not None:
-# for e in excludes.split(","):
-# src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
-#
-# dst_state = obj.state_dict()
-# src_state = assigment_scope_map(dst_state, src_state, scope_map)
-#
-# if ignore_init_mismatch:
-# src_state = filter_state_dict(dst_state, src_state)
-#
-# logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
-# logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
-# dst_state.update(src_state)
-# obj.load_state_dict(dst_state, strict=True)
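
The rewritten `load_pretrained_model` above replaces the old `assigment_scope_map` helper with an inline prefix remapping: `scope_map` lists `src_prefix,dst_prefix` pairs (with `None` meaning an empty prefix), and a `module.` pair is always appended so checkpoints saved under DDP load into a plain model. A standalone sketch of that idea, using a simplified hypothetical helper rather than the FunASR function itself:

```python
# Simplified, hypothetical illustration of the scope_map prefix remapping
# (not FunASR code): map destination state-dict keys to checkpoint keys.
def remap_keys(dst_keys, src_state, scope_map):
    pairs = list(zip(scope_map[0::2], scope_map[1::2]))
    mapping = {}
    for k in dst_keys:
        k_src = k
        for src_prefix, dst_prefix in pairs:
            src_prefix = "" if src_prefix.lower() == "none" else src_prefix
            dst_prefix = "" if dst_prefix.lower() == "none" else dst_prefix
            if dst_prefix == "" and (src_prefix + k) in src_state:
                k_src = src_prefix + k                        # e.g. add a "module." prefix
            elif k.startswith(dst_prefix) and k.replace(dst_prefix, src_prefix, 1) in src_state:
                k_src = k.replace(dst_prefix, src_prefix, 1)  # swap dst prefix for src prefix
        if k_src in src_state:
            mapping[k] = k_src
    return mapping

# A DDP checkpoint stores keys under "module."; the plain model does not.
src = {"module.encoder.w": 1, "decoder.w": 2}
print(remap_keys(["encoder.w", "decoder.w"], src, ["module.", "None"]))
# -> {'encoder.w': 'module.encoder.w', 'decoder.w': 'decoder.w'}
```
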
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index d175fbe..3b20596 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -5,7 +5,8 @@
from tqdm import tqdm
from datetime import datetime
import torch.distributed as dist
-from contextlib import nullcontext
+from torch.cuda.amp import autocast, GradScaler
+from contextlib import nullcontext, contextmanager
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
from pathlib import Path
@@ -13,6 +14,15 @@
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+@contextmanager
+def maybe_autocast(enabled):
+ if enabled:
+ with autocast():
+ yield
+ else:
+ yield
class Trainer:
"""
@@ -36,8 +46,9 @@
dataloader_train,
dataloader_val,
local_rank,
- use_ddp=False,
- use_fsdp=False,
+ use_ddp: bool = False,
+ use_fsdp: bool = False,
+ use_fp16: bool = False,
output_dir: str="./",
**kwargs):
"""
@@ -72,6 +83,11 @@
self.kwargs = kwargs
self.log_interval = kwargs.get("log_interval", 50)
self.batch_total = 0
+ self.use_fp16 = use_fp16
+ self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
+ scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
+        scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
+ self.scaler = scaler
try:
@@ -103,6 +119,8 @@
'optimizer': self.optim.state_dict(),
'scheduler': self.scheduler.state_dict(),
}
+ if self.scaler:
+ state["scaler_state"] = self.scaler.state_dict()
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
@@ -141,6 +159,8 @@
self.model.load_state_dict(dst_state)
self.optim.load_state_dict(checkpoint['optimizer'])
self.scheduler.load_state_dict(checkpoint['scheduler'])
+ if self.scaler and 'scaler_state' in checkpoint:
+ self.scaler.load_state_dict(checkpoint['scaler_state'])
print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
print(f"No checkpoint found at '{ckpt}', starting from scratch")
@@ -221,9 +241,10 @@
my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
with my_context():
time2 = time.perf_counter()
-
- retval = self.model(**batch)
- torch.cuda.empty_cache()
+ with maybe_autocast(self.use_fp16):
+ retval = self.model(**batch)
+
+ if self.disable_gpu_cache: torch.cuda.empty_cache()
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
@@ -241,7 +262,10 @@
loss *= self.world_size
# Scale the loss since we're not updating for every mini-batch
loss = loss / accum_grad
- loss.backward()
+ if self.use_fp16:
+ self.scaler.scale(loss).backward()
+ else:
+ loss.backward()
time4 = time.perf_counter()
speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
@@ -264,10 +288,14 @@
# Execute an optimization step (update model parameters)
if self.use_ddp or self.use_fsdp:
dist.barrier()
- self.optim.step()
+ if self.use_fp16:
+ self.scaler.step(self.optim)
+ self.scaler.update()
+ else:
+ self.optim.step()
self.scheduler.step()
# Clear gradients for the next accumulation stage
- self.optim.zero_grad()
+ self.optim.zero_grad(set_to_none=True)
total_time = f"{time.perf_counter() - time5:0.3f}"
time5 = time.perf_counter()
speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
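
With `use_fp16` enabled, the trainer above wraps the forward pass in `autocast` and routes backward/step through a `GradScaler` (or `ShardedGradScaler` under sharded training). A minimal standalone sketch of that AMP step, assuming a toy model and optimizer rather than the FunASR `Trainer`:

```python
# Minimal sketch of the fp16 pattern the trainer now follows (toy example, not FunASR code).
import torch
from torch.cuda.amp import autocast, GradScaler

use_fp16 = torch.cuda.is_available()
device = "cuda" if use_fp16 else "cpu"
model = torch.nn.Linear(16, 4).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=use_fp16)

x = torch.randn(8, 16, device=device)
y = torch.randn(8, 4, device=device)

with autocast(enabled=use_fp16):              # mixed-precision forward
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()                 # scale the loss to avoid fp16 gradient underflow
scaler.step(optim)                            # unscale gradients, then optimizer step
scaler.update()                               # adapt the scale factor for the next step
optim.zero_grad(set_to_none=True)
```
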
diff --git a/runtime/docs/SDK_advanced_guide_online_zh.md b/runtime/docs/SDK_advanced_guide_online_zh.md
index ac711a8..713f9bd 100644
--- a/runtime/docs/SDK_advanced_guide_online_zh.md
+++ b/runtime/docs/SDK_advanced_guide_online_zh.md
@@ -3,7 +3,7 @@
[//]: # (FunASR提供可便捷本地或者云端服务器部署的实时语音听写服务，内核为FunASR已开源runtime-SDK。)
[//]: # (集成了达摩院语音实验室在Modelscope社区开源的语音端点检测(VAD)、Paraformer-large非流式语音识别(ASR)、Paraformer-large流式语音识别(ASR)、标点(PUNC) 等相关能力。软件包既可以实时地进行语音转文字，而且能够在说话句尾用高精度的转写文字修正输出，输出文字带有标点，支持高并发多路请求)
-FunASR实时语音听写软件包，集成了实时版本的语音端点检测模型、语音识别、语音识别、标点预测模型等。采用多模型协同，既可以实时的进行语音转文字，也可以在说话句尾用高精度转写文字修正输出，输出文字带有标点，支持多路请求。依据使用者场景不同，支持实时语音听写服务（online）、非实时一句话转写（offline）与实时与非实时一体化协同（2pass）3种服务模式。软件包提供有html、python、c++、java与c#等多种编程语言客户端，用户可以直接使用与进一步开发。
+FunASR实时语音听写软件包，集成了实时版本的语音端点检测模型、语音识别、标点预测模型等。采用多模型协同，既可以实时的进行语音转文字，也可以在说话句尾用高精度转写文字修正输出，输出文字带有标点，支持多路请求。依据使用者场景不同，支持实时语音听写服务（online）、非实时一句话转写（offline）与实时与非实时一体化协同（2pass）3种服务模式。软件包提供有html、python、c++、java与c#等多种编程语言客户端，用户可以直接使用与进一步开发。
本文档为FunASR实时转写服务开发指南。如果您想快速体验实时语音听写服务，可参考[快速上手](#快速上手)。
diff --git a/runtime/docs/docker_offline_cpu_en_lists b/runtime/docs/docker_offline_cpu_en_lists
index 8361fce..9212110 100644
--- a/runtime/docs/docker_offline_cpu_en_lists
+++ b/runtime/docs/docker_offline_cpu_en_lists
@@ -1,4 +1,5 @@
DOCKER:
+ funasr-runtime-sdk-en-cpu-0.1.4
funasr-runtime-sdk-en-cpu-0.1.3
funasr-runtime-sdk-en-cpu-0.1.2
DEFAULT_ASR_MODEL:
diff --git a/runtime/docs/docker_offline_cpu_zh_lists b/runtime/docs/docker_offline_cpu_zh_lists
index 520da51..5c0578f 100644
--- a/runtime/docs/docker_offline_cpu_zh_lists
+++ b/runtime/docs/docker_offline_cpu_zh_lists
@@ -1,5 +1,5 @@
DOCKER:
- funasr-runtime-sdk-cpu-0.4.2
+ funasr-runtime-sdk-cpu-0.4.3
funasr-runtime-sdk-cpu-0.3.0
funasr-runtime-sdk-cpu-0.2.2
DEFAULT_ASR_MODEL:
diff --git a/runtime/docs/docker_online_cpu_zh_lists b/runtime/docs/docker_online_cpu_zh_lists
index eb6f1d3..49743ea 100644
--- a/runtime/docs/docker_online_cpu_zh_lists
+++ b/runtime/docs/docker_online_cpu_zh_lists
@@ -1,7 +1,7 @@
DOCKER:
+ funasr-runtime-sdk-online-cpu-0.1.8
funasr-runtime-sdk-online-cpu-0.1.7
funasr-runtime-sdk-online-cpu-0.1.6
- funasr-runtime-sdk-online-cpu-0.1.5
DEFAULT_ASR_MODEL:
damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx
diff --git a/runtime/html5/static/index.html b/runtime/html5/static/index.html
index d98c62b..de8139e 100644
--- a/runtime/html5/static/index.html
+++ b/runtime/html5/static/index.html
@@ -51,6 +51,12 @@
</div>
<br>
+ <div id="use_itn_div" style="border:2px solid #ccc;display:block;">
+        逆文本标准化(ITN):<br/>
+        <label><input name="use_itn" type="radio" value="false" checked="true"/>否 </label>
+        <label><input name="use_itn" type="radio" value="true" />是 </label>
+ </div>
+ <br>
<div style="border:2px solid #ccc;">
热词设置(一行一个关键字，空格隔开权重,如"阿里巴巴 20")：
<br>
diff --git a/runtime/html5/static/main.js b/runtime/html5/static/main.js
index b3661cd..9a5a875 100644
--- a/runtime/html5/static/main.js
+++ b/runtime/html5/static/main.js
@@ -563,4 +563,14 @@
}
+}
+
+function getUseITN() {
+ var obj = document.getElementsByName("use_itn");
+ for (var i = 0; i < obj.length; i++) {
+ if (obj[i].checked) {
+ return obj[i].value === "true";
+ }
+ }
+ return false;
}
\ No newline at end of file
diff --git a/runtime/html5/static/wsconnecter.js b/runtime/html5/static/wsconnecter.js
index 30b99d4..db140ef 100644
--- a/runtime/html5/static/wsconnecter.js
+++ b/runtime/html5/static/wsconnecter.js
@@ -71,7 +71,7 @@
"wav_name": "h5",
"is_speaking": true,
"chunk_interval":10,
- "itn":false,
+ "itn":getUseITN(),
"mode":getAsrMode(),
};
diff --git a/runtime/python/websocket/funasr_wss_server.py b/runtime/python/websocket/funasr_wss_server.py
index 37ca6a9..015d87b 100644
--- a/runtime/python/websocket/funasr_wss_server.py
+++ b/runtime/python/websocket/funasr_wss_server.py
@@ -180,8 +180,8 @@
websocket.wav_name = messagejson.get("wav_name")
if "chunk_size" in messagejson:
chunk_size = messagejson["chunk_size"]
- if isinstance(chunk_size, str):
- chunk_size = chunk_size.split(',')
+ if isinstance(chunk_size, str):
+ chunk_size = chunk_size.split(',')
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
if "encoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
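
The indentation fix above only adjusts where the string check sits; functionally, a `chunk_size` sent as a comma-separated string is normalized to a list of ints before being stored. A small standalone illustration of that normalization (not the server code itself):

```python
# Normalizing the "chunk_size" field from a websocket config message (illustrative only).
messagejson = {"chunk_size": "5,10,5"}   # clients may also send a list such as [5, 10, 5]
chunk_size = messagejson["chunk_size"]
if isinstance(chunk_size, str):
    chunk_size = chunk_size.split(",")
chunk_size = [int(x) for x in chunk_size]
print(chunk_size)  # -> [5, 10, 5]
```
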
diff --git a/setup.py b/setup.py
index f703bb4..4e76c80 100644
--- a/setup.py
+++ b/setup.py
@@ -40,11 +40,11 @@
"umap_learn",
"jaconv",
"hydra-core>=1.3.2",
+ "tensorboardX",
],
# train: The modules invoked when training only.
"train": [
"editdistance",
- "tensorboardX",
],
# all: The modules should be optionally installled due to some reason.
# Please consider moving them to "install" occasionally
--
Gitblit v1.9.1