From 98c94ab3ab0266482117343a064beeb6bd6bcedc Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 28 二月 2024 20:45:07 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR merge

---
 funasr/models/llm_asr_nar/model.py                              |  338 ++++++++++++++
 .gitignore                                                      |    1 
 funasr/models/llm_asr_nar/__init__.py                           |    0 
 funasr/train_utils/trainer.py                                   |   46 +
 runtime/html5/static/main.js                                    |   10 
 setup.py                                                        |    2 
 runtime/docs/docker_online_cpu_zh_lists                         |    2 
 funasr/models/seaco_paraformer/model.py                         |    2 
 funasr/bin/train.py                                             |   10 
 funasr/datasets/llm_datasets/samplers.py                        |  277 +++++++++++
 funasr/models/llm_asr_nar/adaptor.py                            |   29 +
 funasr/tokenizer/hf_tokenizer.py                                |   15 
 runtime/html5/static/index.html                                 |    6 
 funasr/auto/auto_frontend.py                                    |   10 
 funasr/datasets/llm_datasets/__init__.py                        |    0 
 runtime/python/websocket/funasr_wss_server.py                   |    4 
 runtime/html5/static/wsconnecter.js                             |    2 
 runtime/docs/docker_offline_cpu_zh_lists                        |    2 
 funasr/auto/auto_model.py                                       |   96 ++-
 examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py |    8 
 funasr/datasets/llm_datasets/preprocessor.py                    |   37 +
 funasr/models/paraformer/cif_predictor.py                       |  200 +++-----
 README_zh.md                                                    |   24 
 funasr/train_utils/load_pretrained_model.py                     |  123 +---
 README.md                                                       |   25 
 funasr/datasets/llm_datasets/datasets.py                        |  131 +++++
 runtime/docs/SDK_advanced_guide_online_zh.md                    |    2 
 funasr/metrics/compute_acc.py                                   |   17 
 runtime/docs/docker_offline_cpu_en_lists                        |    1 
 29 files changed, 1,123 insertions(+), 297 deletions(-)

diff --git a/.gitignore b/.gitignore
index adf2937..b0d4692 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,3 +25,4 @@
 emotion2vec*
 GPT-SoVITS*
 modelscope_models
+examples/aishell/llm_asr_nar/*
diff --git a/README.md b/README.md
index 454adc9..04a3e68 100644
--- a/README.md
+++ b/README.md
@@ -105,10 +105,8 @@
 from funasr import AutoModel
 # paraformer-zh is a multi-functional asr model
 # use vad, punc, spk or not as you need
-model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
-                  vad_model="fsmn-vad", vad_model_revision="v2.0.4",
-                  punc_model="ct-punc-c", punc_model_revision="v2.0.4",
-                  # spk_model="cam++", spk_model_revision="v2.0.2",
+model = AutoModel(model="paraformer-zh",  vad_model="fsmn-vad",  punc_model="ct-punc-c", 
+                  # spk_model="cam++", 
                   )
 res = model.generate(input=f"{model.model_path}/example/asr_example.wav", 
                      batch_size_s=300, 
@@ -125,7 +123,7 @@
 encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
 decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
 
-model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.4")
+model = AutoModel(model="paraformer-zh-streaming")
 
 import soundfile
 import os
@@ -148,17 +146,19 @@
 ```python
 from funasr import AutoModel
 
-model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
+model = AutoModel(model="fsmn-vad")
 wav_file = f"{model.model_path}/example/asr_example.wav"
 res = model.generate(input=wav_file)
 print(res)
 ```
+Note: The output format of the VAD model is: `[[beg1, end1], [beg2, end2], ..., [begN, endN]]`, where `begN/endN` indicates the starting/ending point of the `N-th` valid audio segment, measured in milliseconds.
+
 ### Voice Activity Detection (Streaming)
 ```python
 from funasr import AutoModel
 
 chunk_size = 200 # ms
-model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
+model = AutoModel(model="fsmn-vad")
 
 import soundfile
 
@@ -175,11 +175,18 @@
     if len(res[0]["value"]):
         print(res)
 ```
+Note: The output format for the streaming VAD model can be one of four scenarios:
+- `[[beg1, end1], [beg2, end2], .., [begN, endN]]`锛歍he same as the offline VAD output result mentioned above.
+- `[[beg, -1]]`锛欼ndicates that only a starting point has been detected.
+- `[[-1, end]]`锛欼ndicates that only an ending point has been detected.
+- `[]`锛欼ndicates that neither a starting point nor an ending point has been detected. 
+
+The output is measured in milliseconds and represents the absolute time from the starting point.
 ### Punctuation Restoration
 ```python
 from funasr import AutoModel
 
-model = AutoModel(model="ct-punc", model_revision="v2.0.4")
+model = AutoModel(model="ct-punc")
 res = model.generate(input="閭d粖澶╃殑浼氬氨鍒拌繖閲屽惂 happy new year 鏄庡勾瑙�")
 print(res)
 ```
@@ -187,7 +194,7 @@
 ```python
 from funasr import AutoModel
 
-model = AutoModel(model="fa-zh", model_revision="v2.0.4")
+model = AutoModel(model="fa-zh")
 wav_file = f"{model.model_path}/example/asr_example.wav"
 text_file = f"{model.model_path}/example/text.txt"
 res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
diff --git a/README_zh.md b/README_zh.md
index 07cdd1f..63ad2e2 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -101,10 +101,8 @@
 from funasr import AutoModel
 # paraformer-zh is a multi-functional asr model
 # use vad, punc, spk or not as you need
-model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
-                  vad_model="fsmn-vad", vad_model_revision="v2.0.4",
-                  punc_model="ct-punc-c", punc_model_revision="v2.0.4",
-                  # spk_model="cam++", spk_model_revision="v2.0.2",
+model = AutoModel(model="paraformer-zh",  vad_model="fsmn-vad", punc_model="ct-punc-c", 
+                  # spk_model="cam++"
                   )
 res = model.generate(input=f"{model.model_path}/example/asr_example.wav", 
             batch_size_s=300, 
@@ -122,7 +120,7 @@
 encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
 decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
 
-model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.4")
+model = AutoModel(model="paraformer-zh-streaming")
 
 import soundfile
 import os
@@ -146,19 +144,21 @@
 ```python
 from funasr import AutoModel
 
-model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
+model = AutoModel(model="fsmn-vad")
 
 wav_file = f"{model.model_path}/example/asr_example.wav"
 res = model.generate(input=wav_file)
 print(res)
 ```
+娉細VAD妯″瀷杈撳嚭鏍煎紡涓猴細`[[beg1, end1], [beg2, end2], .., [begN, endN]]`锛屽叾涓璥begN/endN`琛ㄧず绗琡N`涓湁鏁堥煶棰戠墖娈电殑璧峰鐐�/缁撴潫鐐癸紝
+鍗曚綅涓烘绉掋��
 
 ### 璇煶绔偣妫�娴嬶紙瀹炴椂锛�
 ```python
 from funasr import AutoModel
 
 chunk_size = 200 # ms
-model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
+model = AutoModel(model="fsmn-vad")
 
 import soundfile
 
@@ -175,12 +175,18 @@
     if len(res[0]["value"]):
         print(res)
 ```
+娉細娴佸紡VAD妯″瀷杈撳嚭鏍煎紡涓�4绉嶆儏鍐碉細
+- `[[beg1, end1], [beg2, end2], .., [begN, endN]]`锛氬悓涓婄绾縑AD杈撳嚭缁撴灉銆�
+- `[[beg, -1]]`锛氳〃绀哄彧妫�娴嬪埌璧峰鐐广��
+- `[[-1, end]]`锛氳〃绀哄彧妫�娴嬪埌缁撴潫鐐广��
+- `[]`锛氳〃绀烘棦娌℃湁妫�娴嬪埌璧峰鐐癸紝涔熸病鏈夋娴嬪埌缁撴潫鐐�
+杈撳嚭缁撴灉鍗曚綅涓烘绉掞紝浠庤捣濮嬬偣寮�濮嬬殑缁濆鏃堕棿銆�
 
 ### 鏍囩偣鎭㈠
 ```python
 from funasr import AutoModel
 
-model = AutoModel(model="ct-punc", model_revision="v2.0.4")
+model = AutoModel(model="ct-punc")
 
 res = model.generate(input="閭d粖澶╃殑浼氬氨鍒拌繖閲屽惂 happy new year 鏄庡勾瑙�")
 print(res)
@@ -190,7 +196,7 @@
 ```python
 from funasr import AutoModel
 
-model = AutoModel(model="fa-zh", model_revision="v2.0.0")
+model = AutoModel(model="fa-zh")
 
 wav_file = f"{model.model_path}/example/asr_example.wav"
 text_file = f"{model.model_path}/example/text.txt"
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
index f043123..c28db7a 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -10,6 +10,8 @@
 
 res = model.generate(input=wav_file)
 print(res)
+# [[beg1, end1], [beg2, end2], .., [begN, endN]]
+# beg/end: ms
 
 
 
@@ -37,3 +39,9 @@
     # print(res)
     if len(res[0]["value"]):
         print(res)
+
+
+# 1. [[beg1, end1], [beg2, end2], .., [begN, endN]]; [[beg, end]]; [[beg1, end1], [beg2, end2]]
+# 2. [[beg, -1]]
+# 3. [[-1, end]]
+# beg/end: ms
\ No newline at end of file
diff --git a/funasr/auto/auto_frontend.py b/funasr/auto/auto_frontend.py
index 8f2f069..35ea23f 100644
--- a/funasr/auto/auto_frontend.py
+++ b/funasr/auto/auto_frontend.py
@@ -1,3 +1,8 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 import json
 import time
 import torch
@@ -12,15 +17,14 @@
 from funasr.register import tables
 from funasr.utils.load_utils import load_bytes
 from funasr.download.file import download_from_url
+from funasr.auto.auto_model import prepare_data_iterator
+from funasr.utils.timestamp_tools import timestamp_sentence
 from funasr.download.download_from_hub import download_model
 from funasr.utils.vad_utils import slice_padding_audio_samples
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils.timestamp_tools import timestamp_sentence
 from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
-from funasr.auto.auto_model import prepare_data_iterator
-
 
 
 class AutoFrontend:
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index e5faa2a..a6be691 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -1,3 +1,8 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 import json
 import time
 import copy
@@ -12,12 +17,12 @@
 from funasr.register import tables
 from funasr.utils.load_utils import load_bytes
 from funasr.download.file import download_from_url
+from funasr.utils.timestamp_tools import timestamp_sentence
 from funasr.download.download_from_hub import download_model
 from funasr.utils.vad_utils import slice_padding_audio_samples
+from funasr.utils.load_utils import load_audio_text_image_video
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
-from funasr.utils.load_utils import load_audio_text_image_video
-from funasr.utils.timestamp_tools import timestamp_sentence
 from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
 try:
     from funasr.models.campplus.cluster_backend import ClusterBackend
@@ -90,7 +95,7 @@
 class AutoModel:
     
     def __init__(self, **kwargs):
-        if not kwargs.get("disable_log", False):
+        if not kwargs.get("disable_log", True):
             tables.print()
         
         model, kwargs = self.build_model(**kwargs)
@@ -157,8 +162,10 @@
             tokenizer_class = tables.tokenizer_classes.get(tokenizer)
             tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
             kwargs["tokenizer"] = tokenizer
-            kwargs["token_list"] = tokenizer.token_list
-            vocab_size = len(tokenizer.token_list)
+
+            kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
+            kwargs["token_list"] = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
+            vocab_size = len(kwargs["token_list"])
         else:
             vocab_size = -1
         
@@ -179,15 +186,18 @@
         # init_param
         init_param = kwargs.get("init_param", None)
         if init_param is not None:
-            logging.info(f"Loading pretrained params from {init_param}")
-            load_pretrained_model(
-                model=model,
-                path=init_param,
-                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
-                oss_bucket=kwargs.get("oss_bucket", None),
-                scope_map=kwargs.get("scope_map", None),
-                excludes=kwargs.get("excludes", None),
-            )
+            if os.path.exists(init_param):
+                logging.info(f"Loading pretrained params from {init_param}")
+                load_pretrained_model(
+                    model=model,
+                    path=init_param,
+                    ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
+                    oss_bucket=kwargs.get("oss_bucket", None),
+                    scope_map=kwargs.get("scope_map", []),
+                    excludes=kwargs.get("excludes", None),
+                )
+            else:
+                print(f"error, init_param does not exist!: {init_param}")
         
         return model, kwargs
     
@@ -219,7 +229,7 @@
         speed_stats = {}
         asr_result_list = []
         num_samples = len(data_list)
-        disable_pbar = kwargs.get("disable_pbar", False)
+        disable_pbar = self.kwargs.get("disable_pbar", False)
         pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True) if not disable_pbar else None
         time_speech_total = 0.0
         time_escape_total = 0.0
@@ -231,12 +241,12 @@
             if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank": # fbank
                 batch["data_in"] = data_batch[0]
                 batch["data_lengths"] = input_len
-        
+
             time1 = time.perf_counter()
             with torch.no_grad():
                 results, meta_data = model.inference(**batch, **kwargs)
             time2 = time.perf_counter()
-            
+
             asr_result_list.extend(results)
 
             # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
@@ -261,31 +271,29 @@
             pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
         torch.cuda.empty_cache()
         return asr_result_list
-    
+
     def inference_with_vad(self, input, input_len=None, **cfg):
-        
+        kwargs = self.kwargs
         # step.1: compute the vad model
         self.vad_kwargs.update(cfg)
         beg_vad = time.time()
         res = self.inference(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
         end_vad = time.time()
-        print(f"time cost vad: {end_vad - beg_vad:0.3f}")
 
 
         # step.2 compute asr model
         model = self.model
-        kwargs = self.kwargs
         kwargs.update(cfg)
         batch_size = int(kwargs.get("batch_size_s", 300))*1000
         batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60))*1000
         kwargs["batch_size"] = batch_size
-        
+
         key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None))
         results_ret_list = []
         time_speech_total_all_samples = 1e-6
 
         beg_total = time.time()
-        pbar_total = tqdm(colour="red", total=len(res), dynamic_ncols=True)
+        pbar_total = tqdm(colour="red", total=len(res), dynamic_ncols=True) if not kwargs.get("disable_pbar", False) else None
         for i in range(len(res)):
             key = res[i]["key"]
             vadsegments = res[i]["value"]
@@ -296,14 +304,14 @@
             data_with_index = [(vadsegments[i], i) for i in range(n)]
             sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
             results_sorted = []
-            
+
             if not len(sorted_data):
                 logging.info("decoding, utt: {}, empty speech".format(key))
                 continue
 
             if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
                 batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
-            
+
             batch_size_ms_cum = 0
             beg_idx = 0
             beg_asr_total = time.time()
@@ -322,8 +330,8 @@
                     continue
                 batch_size_ms_cum = 0
                 end_idx = j + 1
-                speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])       
-                results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, disable_pbar=True, **cfg)
+                speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
+                results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
                 if self.spk_model is not None:
                     # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
                     for _b in range(len(speech_j)):
@@ -333,26 +341,26 @@
                         segments = sv_chunk(vad_segments)
                         all_segments.extend(segments)
                         speech_b = [i[2] for i in segments]
-                        spk_res = self.inference(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, disable_pbar=True, **cfg)
+                        spk_res = self.inference(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
                         results[_b]['spk_embedding'] = spk_res[0]['spk_embedding']
                 beg_idx = end_idx
                 if len(results) < 1:
                     continue
                 results_sorted.extend(results)
-            
+
             # end_asr_total = time.time()
             # time_escape_total_per_sample = end_asr_total - beg_asr_total
             # pbar_sample.update(1)
             # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
             #                      f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
             #                      f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
-            
+
             restored_data = [0] * n
             for j in range(n):
                 index = sorted_data[j][1]
                 restored_data[index] = results_sorted[j]
             result = {}
-            
+
             # results combine for texts, timestamps, speaker embeddings and others
             # TODO: rewrite for clean code
             for j in range(n):
@@ -379,18 +387,21 @@
                             result[k] = restored_data[j][k]
                         else:
                             result[k] += restored_data[j][k]
-            
-            return_raw_text = kwargs.get('return_raw_text', False)            
+
+            return_raw_text = kwargs.get('return_raw_text', False)
             # step.3 compute punc model
             if self.punc_model is not None:
-                self.punc_kwargs.update(cfg)
-                punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, disable_pbar=True, **cfg)
-                raw_text = copy.copy(result["text"])
-                if return_raw_text: result['raw_text'] = raw_text
-                result["text"] = punc_res[0]["text"]
+                if not len(result["text"]):
+                    result['raw_text'] = ''
+                else:
+                    self.punc_kwargs.update(cfg)
+                    punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
+                    raw_text = copy.copy(result["text"])
+                    if return_raw_text: result['raw_text'] = raw_text
+                    result["text"] = punc_res[0]["text"]
             else:
                 raw_text = None
-                
+
             # speaker embedding cluster after resorted
             if self.spk_model is not None and kwargs.get('return_spk_res', True):
                 if raw_text is None:
@@ -429,13 +440,14 @@
                                                    return_raw_text=return_raw_text)
                 result['sentence_info'] = sentence_list
             if "spk_embedding" in result: del result['spk_embedding']
-                    
+
             result["key"] = key
             results_ret_list.append(result)
             end_asr_total = time.time()
             time_escape_total_per_sample = end_asr_total - beg_asr_total
-            pbar_total.update(1)
-            pbar_total.set_description(f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+            if pbar_total:
+                pbar_total.update(1)
+                pbar_total.set_description(f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
                                  f"time_speech: {time_speech_total_per_sample: 0.3f}, "
                                  f"time_escape: {time_escape_total_per_sample:0.3f}")
 
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 4538224..569757a 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -85,7 +85,9 @@
 
     # build model
     model_class = tables.model_classes.get(kwargs["model"])
-    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+    vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
+    vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
+    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
 
 
 
@@ -103,13 +105,15 @@
                     path=p,
                     ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                     oss_bucket=kwargs.get("oss_bucket", None),
-                    scope_map=kwargs.get("scope_map", None),
+                    scope_map=kwargs.get("scope_map", []),
                     excludes=kwargs.get("excludes", None),
                 )
             else:
                 logging.info(f"Checkpoint does not exist, init randomly: {p}")
-    else:
+    elif kwargs.get("init", None):
         initialize(model, kwargs.get("init", "kaiming_normal"))
+    else:
+        print("No initialize method")
 
 
     # freeze_param
diff --git a/funasr/datasets/llm_datasets/__init__.py b/funasr/datasets/llm_datasets/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/llm_datasets/__init__.py
diff --git a/funasr/datasets/llm_datasets/datasets.py b/funasr/datasets/llm_datasets/datasets.py
new file mode 100644
index 0000000..9673d76
--- /dev/null
+++ b/funasr/datasets/llm_datasets/datasets.py
@@ -0,0 +1,131 @@
+import torch
+import copy
+
+from funasr.register import tables
+from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
+
+
+@tables.register("dataset_classes", "AudioLLMDataset")
+class AudioLLMDataset(torch.utils.data.Dataset):
+    """
+    AudioLLMDataset
+    """
+    def __init__(self,
+                 path,
+                 index_ds: str = None,
+                 frontend=None,
+                 tokenizer=None,
+                 int_pad_value: int = -1,
+                 float_pad_value: float = 0.0,
+                  **kwargs):
+        super().__init__()
+        index_ds_class = tables.index_ds_classes.get(index_ds)
+        self.index_ds = index_ds_class(path, **kwargs)
+        preprocessor_speech = kwargs.get("preprocessor_speech", None)
+        if preprocessor_speech:
+            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
+            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
+        self.preprocessor_speech = preprocessor_speech
+        preprocessor_text = kwargs.get("preprocessor_text", None)
+        if preprocessor_text:
+            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
+            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
+        self.preprocessor_text = preprocessor_text
+        
+        self.frontend = frontend
+        self.fs = 16000 if frontend is None else frontend.fs
+        self.data_type = "sound"
+        self.tokenizer = tokenizer
+
+        self.float_pad_value = float_pad_value
+        self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
+        self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
+            self.prompt)  # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
+        self.prompt_af = ""
+        self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
+        self.int_pad_value = self.IGNORE_INDEX
+    
+    def get_source_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_source_len(item)
+    
+    def get_target_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_target_len(item)
+    
+    def __len__(self):
+        return len(self.index_ds)
+    
+    def __getitem__(self, index):
+        item = self.index_ds[index]
+        # import pdb;
+        # pdb.set_trace()
+        source = item["source"]
+        data_src = load_audio_text_image_video(source, fs=self.fs)
+        if self.preprocessor_speech:
+            data_src = self.preprocessor_speech(data_src, fs=self.fs)
+        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
+        speech = speech.squeeze(0)
+
+        target = item["target"]
+        if self.preprocessor_text:
+            target = self.preprocessor_text(target)
+        
+        
+        prompt_ids_pre = self.tokenizer.encode(self.prompt_pre) # [bos,prompt]
+        prompt_pre_length = len(prompt_ids_pre)
+        
+        prompt_input = "{}{}".format(self.prompt_pre, target)
+        prompt_input_ids = self.tokenizer.encode(prompt_input)
+        audio_length = len(prompt_input_ids) - prompt_pre_length
+        input_ids = prompt_input_ids + [self.tokenizer.pad_token_id]
+        input_ids = torch.tensor(input_ids, dtype=torch.int64) #[bos, prompt, input, pad]
+        input_ids[prompt_pre_length:] = -1  # [bos, prompt,-1,-1]
+        attention_mask = input_ids.ge(-1) # [true, true, true, true], length mask
+
+        prompt_answer = "{}{}".format(self.prompt_pre, target)
+        prompt_answer_ids = self.tokenizer.encode(prompt_answer)
+        answer_length = len(prompt_answer_ids) - prompt_pre_length
+        labels_ids = copy.deepcopy(prompt_input_ids) + [self.tokenizer.eos_token_id]
+        labels_ids = torch.tensor(labels_ids, dtype=torch.int64)  # [bos, prompt, input, eos]
+        labels_ids[:prompt_pre_length] = -1  # [-1, -1, input, eos]
+        label_mask = labels_ids.ge(0)  # [False,False,True,True]
+        labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100,-100,input,eos]
+        
+        audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
+        audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
+        
+        ids = self.tokenizer.encode(target) # token ids is different from labels_ids
+        text = torch.tensor(ids, dtype=torch.int64)
+        text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
+        
+        return {"speech": speech,
+                "speech_lengths": speech_lengths,
+                "text": text,
+                "text_lengths": text_lengths,
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "labels_ids": labels_ids,
+                "label_mask": label_mask,
+                "audio_mask": audio_mask,
+                }
+    
+    
+    def collator(self, samples: list=None):
+        outputs = {}
+        for sample in samples:
+            for key in sample.keys():
+                if key not in outputs:
+                    outputs[key] = []
+                outputs[key].append(sample[key])
+
+        for key, data_list in outputs.items():
+            if isinstance(data_list[0], torch.Tensor):
+                if data_list[0].dtype == torch.int64:
+    
+                    pad_value = self.int_pad_value
+                else:
+                    pad_value = self.float_pad_value
+                
+                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+        return outputs
diff --git a/funasr/datasets/llm_datasets/preprocessor.py b/funasr/datasets/llm_datasets/preprocessor.py
new file mode 100644
index 0000000..9f20672
--- /dev/null
+++ b/funasr/datasets/llm_datasets/preprocessor.py
@@ -0,0 +1,37 @@
+import os
+import json
+import torch
+import logging
+import concurrent.futures
+import librosa
+import torch.distributed as dist
+from typing import Collection
+import torch
+import torchaudio
+from torch import nn
+import random
+import re
+import string
+from funasr.tokenizer.cleaner import TextCleaner
+from funasr.register import tables
+
+
+
+@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
+class TextPreprocessSegDict(nn.Module):
+	def __init__(self,
+	             **kwargs):
+		super().__init__()
+		
+	
+	def forward(self, text, **kwargs):
+		# 瀹氫箟鑻辨枃鏍囩偣绗﹀彿
+		en_punct = string.punctuation
+		# 瀹氫箟涓枃鏍囩偣绗﹀彿锛堥儴鍒嗗父鐢ㄧ殑锛�
+		cn_punct = '銆傦紵锛侊紝銆侊紱锛氣�溾�濃�樷�欙紙锛夈�娿�嬨�愩�戔�︹�旓綖路'
+		# 鍚堝苟鑻辨枃鍜屼腑鏂囨爣鐐圭鍙�
+		all_punct = en_punct + cn_punct
+		# 鍒涘缓姝e垯琛ㄨ揪寮忔ā寮忥紝鍖归厤浠讳綍鍦╝ll_punct涓殑瀛楃
+		punct_pattern = re.compile('[{}]'.format(re.escape(all_punct)))
+		# 浣跨敤姝e垯琛ㄨ揪寮忕殑sub鏂规硶鏇挎崲鎺夎繖浜涘瓧绗�
+		return punct_pattern.sub('', text)
diff --git a/funasr/datasets/llm_datasets/samplers.py b/funasr/datasets/llm_datasets/samplers.py
new file mode 100644
index 0000000..914e776
--- /dev/null
+++ b/funasr/datasets/llm_datasets/samplers.py
@@ -0,0 +1,277 @@
+import torch
+import numpy as np
+import logging
+import torch.distributed as dist
+
+from funasr.register import tables
+
+
+@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
+class BatchSampler(torch.utils.data.BatchSampler):
+    
+    def __init__(self, dataset,
+                 batch_type: str = "example",
+                 batch_size: int = 100,
+                 buffer_size: int = 30,
+                 drop_last: bool = False,
+                 shuffle: bool = True,
+                 is_training: bool = True,
+                 **kwargs):
+        
+        self.drop_last = drop_last
+        self.pre_idx = -1
+        self.dataset = dataset
+        self.total_samples = len(dataset)
+        self.batch_type = batch_type
+        self.batch_size = int(batch_size)
+        self.buffer_size = buffer_size
+        self.max_token_length = kwargs.get("max_token_length", 5000)
+        self.shuffle_idx = np.arange(self.total_samples)
+        self.shuffle = shuffle and is_training
+        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+        
+    
+    def __len__(self):
+        return (self.total_samples-1) // self.batch_size + 1
+    
+    def set_epoch(self, epoch):
+        np.random.seed(epoch)
+    
+    def __iter__(self):
+        
+        if self.shuffle:
+            np.random.shuffle(self.shuffle_idx)
+        
+        batch = []
+        max_token = 0
+        num_sample = 0
+        
+        iter_num = (self.total_samples - 1) // self.buffer_size + 1
+        # print("iter_num: ", iter_num)
+        for iter in range(self.pre_idx + 1, iter_num):
+            datalen_with_index = []
+            for i in range(self.buffer_size):
+                idx = iter * self.buffer_size + i
+                if idx >= self.total_samples:
+                    continue
+                
+                idx_map = self.shuffle_idx[idx]
+                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
+                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
+                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
+                sample_len_cur = source_len + target_len
+                
+                
+                datalen_with_index.append([idx, sample_len_cur])
+            
+            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
+            for item in datalen_with_index_sort:
+                idx, sample_len_cur_raw = item
+                if sample_len_cur_raw > self.max_token_length:
+                    continue
+                
+                max_token_cur = max(max_token, sample_len_cur_raw)
+                max_token_padding = 1 + num_sample
+                if self.batch_type != 'example':
+                    max_token_padding *= max_token_cur
+                if max_token_padding <= self.batch_size:
+                    batch.append(idx)
+                    max_token = max_token_cur
+                    num_sample += 1
+                else:
+                    yield batch
+                    batch = [idx]
+                    max_token = sample_len_cur_raw
+                    num_sample = 1
+
+
+@tables.register("batch_sampler_classes", "BatchSampler")
+@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
+class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):
+    
+    def __init__(self, dataset,
+                 batch_type: str = "example",
+                 batch_size: int = 100,
+                 buffer_size: int = 30,
+                 drop_last: bool = True,
+                 shuffle: bool = True,
+                 is_training: bool = True,
+                 **kwargs):
+        
+        self.drop_last = drop_last
+        self.pre_idx = -1
+        self.dataset = dataset
+        self.total_samples = len(dataset)
+        self.batch_type = batch_type
+        self.batch_size = int(batch_size)
+        self.buffer_size = buffer_size
+        self.max_token_length = kwargs.get("max_token_length", 1500)
+        self.shuffle_idx = np.arange(self.total_samples)
+        self.shuffle = shuffle and is_training
+        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+        
+        try:
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+        except:
+            rank = 0
+            world_size = 1
+        self.rank = rank
+        self.world_size = world_size
+        
+    def __len__(self):
+        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
+    
+    def set_epoch(self, epoch):
+        np.random.seed(epoch)
+    
+    def __iter__(self):
+    
+        batch_size_total = self.batch_size * self.world_size
+        
+        if self.shuffle:
+            np.random.shuffle(self.shuffle_idx)
+        
+        batch = []
+        max_token = 0
+        num_sample = 0
+        
+        iter_num = (self.total_samples - 1) // self.buffer_size + 1
+        # print("iter_num: ", iter_num)
+        for iter in range(self.pre_idx + 1, iter_num):
+            # if iter == iter_num -1 and self.drop_last:
+            #     continue
+            datalen_with_index = []
+            for i in range(self.buffer_size):
+                idx = iter * self.buffer_size + i
+                if idx >= self.total_samples:
+                    continue
+                
+                idx_map = self.shuffle_idx[idx]
+                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
+                
+                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
+                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
+                sample_len_cur = source_len + target_len
+                
+                datalen_with_index.append([idx, sample_len_cur])
+            
+            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
+            for item in datalen_with_index_sort:
+                idx, sample_len_cur_raw = item
+                if sample_len_cur_raw > self.max_token_length:
+                    continue
+
+                max_token_cur = max(max_token, sample_len_cur_raw)
+                max_token_padding = 1 + num_sample
+                # if self.batch_type != 'example':
+                #     max_token_padding *= max_token_cur
+                if max_token_padding <= batch_size_total:
+                    batch.append(idx)
+                    max_token = max_token_cur
+                    num_sample += 1
+                else:
+                    batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size]
+                    yield batch_rank
+                    batch = [idx]
+                    max_token = sample_len_cur_raw
+                    num_sample = 1
+
+
+@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
+class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):
+    
+    def __init__(self, dataset,
+                 batch_type: str = "example",
+                 batch_size: int = 100,
+                 buffer_size: int = 30,
+                 drop_last: bool = True,
+                 shuffle: bool = True,
+                 is_training: bool = True,
+                 **kwargs):
+        
+        self.drop_last = drop_last
+        self.pre_idx = -1
+        self.dataset = dataset
+        self.total_samples = len(dataset)
+        self.batch_type = batch_type
+        self.batch_size = int(batch_size)
+        self.buffer_size = buffer_size
+        self.max_token_length = kwargs.get("max_token_length", 1500)
+        self.shuffle_idx = np.arange(self.total_samples)
+        self.shuffle = shuffle and is_training
+        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+        
+        try:
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+        except:
+            rank = 0
+            world_size = 1
+        self.rank = rank
+        self.world_size = world_size
+    
+    def __len__(self):
+        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
+    
+    def set_epoch(self, epoch):
+        np.random.seed(epoch)
+    
+    def __iter__(self):
+        
+        batch_size_total = self.batch_size * self.world_size
+        if self.shuffle:
+            np.random.shuffle(self.shuffle_idx)
+        
+        batch_list_all_rank = []
+        batch_list_cur = []
+        max_token = 0
+        num_sample = 0
+        
+        iter_num = (self.total_samples - 1) // self.buffer_size + 1
+        # print("iter_num: ", iter_num)
+        for iter in range(self.pre_idx + 1, iter_num):
+            # if iter == iter_num - 1 and self.drop_last:
+            #     continue
+            datalen_with_index = []
+            for i in range(self.buffer_size):
+                idx = iter * self.buffer_size + i
+                if idx >= self.total_samples:
+                    continue
+                
+                idx_map = self.shuffle_idx[idx]
+                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
+                
+                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
+                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
+                sample_len_cur = source_len + target_len
+                
+                datalen_with_index.append([idx, sample_len_cur])
+            
+            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
+            for ii, item in enumerate(datalen_with_index_sort):
+                is_last_batch = iter == iter_num - 1 and ii == len(datalen_with_index_sort)
+                idx, sample_len_cur_raw = item
+                if sample_len_cur_raw > self.max_token_length:
+                    continue
+                
+                max_token_cur = max(max_token, sample_len_cur_raw)
+                max_token_padding = 1 + num_sample
+                
+                if self.batch_type != 'example':
+                    max_token_padding *= max_token_cur
+                if len(batch_list_all_rank) < self.world_size:
+                    
+                    if max_token_padding <= self.batch_size:
+                        batch_list_cur.append(idx)
+                        max_token = max_token_cur
+                        num_sample += 1
+                    else:
+                        batch_list_all_rank.append(batch_list_cur)
+                        batch_list_cur = []
+                else:
+                    batch_rank = batch_list_all_rank[self.rank]
+                    yield batch_rank
+                    batch_list_all_rank = [idx]
+                    max_token = sample_len_cur_raw
+                    num_sample = 1
diff --git a/funasr/metrics/compute_acc.py b/funasr/metrics/compute_acc.py
index 9d16e1f..ec8067f 100644
--- a/funasr/metrics/compute_acc.py
+++ b/funasr/metrics/compute_acc.py
@@ -21,3 +21,20 @@
     )
     denominator = torch.sum(mask)
     return float(numerator) / float(denominator)
+
+def compute_accuracy(pad_outputs, pad_targets, ignore_label):
+    """Calculate accuracy.
+
+    Args:
+        pad_outputs (LongTensor): Prediction tensors (B, Lmax).
+        pad_targets (LongTensor): Target label tensors (B, Lmax).
+        ignore_label (int): Ignore label id.
+
+    Returns:
+        float: Accuracy value (0.0 - 1.0).
+
+    """
+    mask = pad_targets != ignore_label
+    numerator = torch.sum(pad_outputs.masked_select(mask) == pad_targets.masked_select(mask))
+    denominator = torch.sum(mask)
+    return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type
\ No newline at end of file
diff --git a/funasr/models/llm_asr_nar/__init__.py b/funasr/models/llm_asr_nar/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/llm_asr_nar/__init__.py
diff --git a/funasr/models/llm_asr_nar/adaptor.py b/funasr/models/llm_asr_nar/adaptor.py
new file mode 100644
index 0000000..0676e7d
--- /dev/null
+++ b/funasr/models/llm_asr_nar/adaptor.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+
+from funasr.register import tables
+
+@tables.register("adaptor_classes", "Linear")
+class Linear(nn.Module):
+    def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
+        super().__init__()
+        self.k = downsample_rate
+        self.encoder_dim = encoder_dim
+        self.llm_dim = llm_dim
+        self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
+        self.relu = nn.ReLU()
+        self.linear2 = nn.Linear(ffn_dim, self.llm_dim)
+
+    def forward(self, x):
+        batch_size, seq_len, dim = x.size()
+        num_frames_to_discard = seq_len % self.k
+        if num_frames_to_discard > 0:
+            x = x[:, :-num_frames_to_discard, :]
+        seq_len = x.size(1)
+        
+        x = x.contiguous()
+        x = x.view(batch_size, seq_len // self.k, dim * self.k)
+        x = self.linear1(x)
+        x = self.relu(x)
+        x = self.linear2(x)
+        return x
diff --git a/funasr/models/llm_asr_nar/model.py b/funasr/models/llm_asr_nar/model.py
new file mode 100644
index 0000000..6a4ecce
--- /dev/null
+++ b/funasr/models/llm_asr_nar/model.py
@@ -0,0 +1,338 @@
+import logging
+from typing import Union, Dict, List, Tuple, Optional
+
+import time
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.cuda.amp import autocast
+
+from funasr.models.scama.utils import sequence_mask
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.ctc.ctc import CTC
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy, compute_accuracy
+# from funasr.models.e2e_asr_common import ErrorCalculator
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.register import tables
+
+
+@tables.register("model_classes", "LLMASRNAR")
+class LLMASRNAR(nn.Module):
+    """ """
+    
+    def __init__(
+        self,
+        specaug: str = None,
+        specaug_conf: dict = None,
+        normalize: str = None,
+        normalize_conf: dict = None,
+        encoder: str = None,
+        encoder_conf: dict = None,
+        decoder: str = None,
+        decoder_conf: dict = None,
+        ctc: str = None,
+        ctc_conf: dict = None,
+        ctc_weight: float = 0.5,
+        llm: str = None,
+        llm_conf: dict = None,
+        adaptor: str = None,
+        adaptor_conf: dict = None,
+        input_size: int = 80,
+        vocab_size: int = -1,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        sos: int = 1,
+        eos: int = 2,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        # extract_feats_in_collect_stats: bool = True,
+        share_embedding: bool = False,
+        # preencoder: Optional[AbsPreEncoder] = None,
+        # postencoder: Optional[AbsPostEncoder] = None,
+        **kwargs,
+    ):
+        
+        super().__init__()
+        
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize)
+            normalize = normalize_class(**normalize_conf)
+        
+        # audio encoder
+        hub = encoder_conf.get("hub", None)
+        if hub == "funasr":
+            from funasr import AutoModel
+            init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+            model = AutoModel(model=init_param_path, model_revision="v2.0.4")
+            # frontend = model.kwargs.get("frontend")
+            model.model.decoder = None
+            
+            self.audio_encoder = model.model
+            # self.frontend = frontend
+            
+        elif hub == "hf":
+            pass
+        else:
+            encoder_class = tables.encoder_classes.get(encoder)
+            encoder = encoder_class(input_size=input_size, **encoder_conf)
+            encoder_output_size = encoder.output_size()
+
+        # llm
+        hub = llm_conf.get("hub", "hf")
+        self.llm = None
+        if hub == "hf":
+            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+
+            init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+            model = AutoModelForCausalLM.from_pretrained(
+                init_param_path,
+                load_in_8bit=None,
+                device_map=None,
+                use_cache=None,
+            )
+            freeze = llm_conf.get("freeze", True)
+            if freeze:
+                for name, param in model.named_parameters():
+                    param.requires_grad = False
+                model.eval()
+            self.llm = model
+        
+        # adaptor
+        adaptor_class = tables.adaptor_classes.get(adaptor)
+        adaptor = adaptor_class(**adaptor_conf)
+        
+        self.adaptor = adaptor
+        
+        
+        self.blank_id = blank_id
+        self.sos = sos if sos is not None else vocab_size - 1
+        self.eos = eos if eos is not None else vocab_size - 1
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.specaug = specaug
+        self.normalize = normalize
+        self.encoder = encoder
+
+
+        self.criterion_att = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+        #
+        # if report_cer or report_wer:
+        #     self.error_calculator = ErrorCalculator(
+        #         token_list, sym_space, sym_blank, report_cer, report_wer
+        #     )
+        #
+        self.error_calculator = None
+
+        self.length_normalized_loss = length_normalized_loss
+        self.beam_search = None
+    
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        input_ids: torch.Tensor,
+        attention_mask:torch.Tensor,
+        labels_ids: torch.Tensor,
+        label_mask: torch.Tensor,
+        audio_mask: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+        # import pdb;
+        # pdb.set_trace()
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+        
+        batch_size = speech.shape[0]
+        
+        # audio encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)
+        
+        # adaptor
+        encoder_out = self.adaptor(encoder_out)
+
+        if input_ids is not None:
+            input_ids[input_ids == -1] = 0
+            input_ids[input_ids == -100] = 0
+            if hasattr(self.llm.model, "embed_tokens"):
+                inputs_embeds = self.llm.model.embed_tokens(input_ids)
+            elif hasattr(self.llm.model.model, "embed_tokens"):
+                inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
+            else:
+                inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
+
+            if audio_mask is not None:
+                batch_size, token_num, dims = inputs_embeds.shape
+                _, l, _ = encoder_out.shape
+                encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
+                inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
+                inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
+
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
+        loss = model_outputs.loss
+
+
+        stats = {}
+        with torch.no_grad():
+            preds = torch.argmax(model_outputs.logits, -1)
+            acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+            stats["acc"] = acc_att
+
+        stats["loss"] = torch.clone(loss.detach())
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + 1).sum())
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+    
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    
+        audio_mask = kwargs.get("audio_mask", None)
+        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
+
+        batch = {"speech": speech, "speech_lengths": speech_lengths}
+        enc, enc_lens = self.audio_encoder.encode(**batch)
+        with autocast(False):
+            enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
+            pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
+                                                                               mask=enc_mask,
+                                                                               target_label_length=audio_token_lengths,
+                                                                               )
+
+        return pre_acoustic_embeds, pre_token_length
+
+
+    def inference(self,
+                  data_in,
+                  data_lengths=None,
+                  key: list = None,
+                  tokenizer=None,
+                  frontend=None,
+                  **kwargs,
+                  ):
+        
+        prompt = kwargs.get("prompt", "Transcribe speech to text.")
+        
+        if kwargs.get("batch_size", 1) > 1:
+            raise NotImplementedError("batch decoding is not implemented")
+
+
+        
+        meta_data = {}
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+                                                            data_type=kwargs.get("data_type", "sound"),
+                                                            tokenizer=tokenizer)
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+        
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+        
+        # Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        # adaptor
+        encoder_out = self.adaptor(encoder_out)
+        
+    
+        prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
+        prompt_ids = tokenizer.encode(prompt_pre)
+        prompt_length = len(prompt_ids)
+        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
+
+
+        if hasattr(self.llm.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+        elif hasattr(self.llm.model.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
+        else:
+            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+
+        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1)  # [prompt, audio]
+        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
+        
+        # model_outputs = self.llm.generate(
+        #     inputs_embeds=inputs_embeds,
+        #     max_length=kwargs.get("max_length", 200),
+        #     max_new_tokens=kwargs.get("max_new_tokens", 200),
+        #     num_beams=kwargs.get("num_beams", 4),
+        #     do_sample=kwargs.get("do_sample", False),
+        #     min_length=kwargs.get("min_length", 1),
+        #     top_p=kwargs.get("top_p", 1.0),
+        #     repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+        #     length_penalty=kwargs.get("length_penalty", 1.0),
+        #     temperature=kwargs.get("temperature", 1.0),
+        #     attention_mask=attention_mask,
+        #     bos_token_id=tokenizer.bos_token_id,
+        #     eos_token_id=tokenizer.eos_token_id,
+        #     pad_token_id=tokenizer.pad_token_id
+        # )
+
+
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
+        preds = torch.argmax(model_outputs.logits, -1)
+        text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
+        text = text[0].split(': \n')[-1]
+        # preds = torch.argmax(model_outputs.logits, -1)
+        
+        ibest_writer = None
+        if kwargs.get("output_dir") is not None:
+            if not hasattr(self, "writer"):
+                self.writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = self.writer[f"{0 + 1}best_recog"]
+
+        results = []
+        result_i = {"key": key[0], "text": text}
+        results.append(result_i)
+
+        if ibest_writer is not None:
+            ibest_writer["text"][key[0]] = text
+        
+        
+        
+        
+        return results, meta_data
+
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 60ddc24..4d9f5d8 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -10,7 +10,7 @@
 from funasr.register import tables
 from funasr.train_utils.device_funcs import to_device
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
-
+from torch.cuda.amp import autocast
 
 @tables.register("predictor_classes", "CifPredictor")
 class CifPredictor(torch.nn.Module):
@@ -28,42 +28,44 @@
 
     def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                 target_label_length=None):
-        h = hidden
-        context = h.transpose(1, 2)
-        queries = self.pad(context)
-        memory = self.cif_conv1d(queries)
-        output = memory + context
-        output = self.dropout(output)
-        output = output.transpose(1, 2)
-        output = torch.relu(output)
-        output = self.cif_output(output)
-        alphas = torch.sigmoid(output)
-        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
-        if mask is not None:
-            mask = mask.transpose(-1, -2).float()
-            alphas = alphas * mask
-        if mask_chunk_predictor is not None:
-            alphas = alphas * mask_chunk_predictor
-        alphas = alphas.squeeze(-1)
-        mask = mask.squeeze(-1)
-        if target_label_length is not None:
-            target_length = target_label_length
-        elif target_label is not None:
-            target_length = (target_label != ignore_id).float().sum(-1)
-        else:
-            target_length = None
-        token_num = alphas.sum(-1)
-        if target_length is not None:
-            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
-        elif self.tail_threshold > 0.0:
-            hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+    
+        with autocast(False):
+            h = hidden
+            context = h.transpose(1, 2)
+            queries = self.pad(context)
+            memory = self.cif_conv1d(queries)
+            output = memory + context
+            output = self.dropout(output)
+            output = output.transpose(1, 2)
+            output = torch.relu(output)
+            output = self.cif_output(output)
+            alphas = torch.sigmoid(output)
+            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+            if mask is not None:
+                mask = mask.transpose(-1, -2).float()
+                alphas = alphas * mask
+            if mask_chunk_predictor is not None:
+                alphas = alphas * mask_chunk_predictor
+            alphas = alphas.squeeze(-1)
+            mask = mask.squeeze(-1)
+            if target_label_length is not None:
+                target_length = target_label_length
+            elif target_label is not None:
+                target_length = (target_label != ignore_id).float().sum(-1)
+            else:
+                target_length = None
+            token_num = alphas.sum(-1)
+            if target_length is not None:
+                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+            elif self.tail_threshold > 0.0:
+                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+                
+            acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
             
-        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-        
-        if target_length is None and self.tail_threshold > 0.0:
-            token_num_int = torch.max(token_num).type(torch.int32).item()
-            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
-            
+            if target_length is None and self.tail_threshold > 0.0:
+                token_num_int = torch.max(token_num).type(torch.int32).item()
+                acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+                
         return acoustic_embeds, token_num, alphas, cif_peak
 
     def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
@@ -169,41 +171,43 @@
 
     def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                 target_label_length=None):
-        h = hidden
-        context = h.transpose(1, 2)
-        queries = self.pad(context)
-        output = torch.relu(self.cif_conv1d(queries))
-        output = output.transpose(1, 2)
-
-        output = self.cif_output(output)
-        alphas = torch.sigmoid(output)
-        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
-        if mask is not None:
-            mask = mask.transpose(-1, -2).float()
-            alphas = alphas * mask
-        if mask_chunk_predictor is not None:
-            alphas = alphas * mask_chunk_predictor
-        alphas = alphas.squeeze(-1)
-        mask = mask.squeeze(-1)
-        if target_label_length is not None:
-            target_length = target_label_length.squeeze(-1)
-        elif target_label is not None:
-            target_length = (target_label != ignore_id).float().sum(-1)
-        else:
-            target_length = None
-        token_num = alphas.sum(-1)
-        if target_length is not None:
-            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
-        elif self.tail_threshold > 0.0:
-            if self.tail_mask:
-                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+        
+        with autocast(False):
+            h = hidden
+            context = h.transpose(1, 2)
+            queries = self.pad(context)
+            output = torch.relu(self.cif_conv1d(queries))
+            output = output.transpose(1, 2)
+    
+            output = self.cif_output(output)
+            alphas = torch.sigmoid(output)
+            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+            if mask is not None:
+                mask = mask.transpose(-1, -2).float()
+                alphas = alphas * mask
+            if mask_chunk_predictor is not None:
+                alphas = alphas * mask_chunk_predictor
+            alphas = alphas.squeeze(-1)
+            mask = mask.squeeze(-1)
+            if target_label_length is not None:
+                target_length = target_label_length.squeeze(-1)
+            elif target_label is not None:
+                target_length = (target_label != ignore_id).float().sum(-1)
             else:
-                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
-
-        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-        if target_length is None and self.tail_threshold > 0.0:
-            token_num_int = torch.max(token_num).type(torch.int32).item()
-            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+                target_length = None
+            token_num = alphas.sum(-1)
+            if target_length is not None:
+                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+            elif self.tail_threshold > 0.0:
+                if self.tail_mask:
+                    hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+                else:
+                    hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
+    
+            acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+            if target_length is None and self.tail_threshold > 0.0:
+                token_num_int = torch.max(token_num).type(torch.int32).item()
+                acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
 
         return acoustic_embeds, token_num, alphas, cif_peak
 
@@ -370,62 +374,6 @@
         predictor_alignments = index_div_bool_zeros_count_tile_out
         predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
         return predictor_alignments.detach(), predictor_alignments_length.detach()
-
-    def gen_tf2torch_map_dict(self):
-    
-        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
-        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
-        map_dict_local = {
-            ## predictor
-            "{}.cif_conv1d.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": (2, 1, 0),
-                 },  # (256,256,3),(3,256,256)
-            "{}.cif_conv1d.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.cif_output.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1,256),(1,256,1)
-            "{}.cif_output.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1,),(1,)
-        }
-        return map_dict_local
-
-    def convert_tf2torch(self,
-                         var_dict_tf,
-                         var_dict_torch,
-                         ):
-        map_dict = self.gen_tf2torch_map_dict()
-        var_dict_torch_update = dict()
-        for name in sorted(var_dict_torch.keys(), reverse=False):
-            names = name.split('.')
-            if names[0] == self.tf2torch_tensor_name_prefix_torch:
-                name_tf = map_dict[name]["name"]
-                data_tf = var_dict_tf[name_tf]
-                if map_dict[name]["squeeze"] is not None:
-                    data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
-                if map_dict[name]["transpose"] is not None:
-                    data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
-                data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                var_dict_torch[
-                                                                                                    name].size(),
-                                                                                                data_tf.size())
-                var_dict_torch_update[name] = data_tf
-                logging.info(
-                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
-                                                                                  var_dict_tf[name_tf].shape))
-    
-        return var_dict_torch_update
 
 
 class mae_loss(torch.nn.Module):
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index cfdd26a..21ad874 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -25,8 +25,8 @@
 from funasr.train_utils.device_funcs import force_gatherable
 from funasr.models.bicif_paraformer.model import BiCifParaformer
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
diff --git a/funasr/tokenizer/hf_tokenizer.py b/funasr/tokenizer/hf_tokenizer.py
new file mode 100644
index 0000000..c856b3d
--- /dev/null
+++ b/funasr/tokenizer/hf_tokenizer.py
@@ -0,0 +1,15 @@
+
+try:
+	from transformers import AutoTokenizer
+except:
+	print("If you want to use hugging, please `pip install -U transformers`")
+
+from funasr.register import tables
+
+@tables.register("tokenizer_classes", "HuggingfaceTokenizer")
+def HuggingfaceTokenizer(init_param_path, **kwargs):
+
+	tokenizer = AutoTokenizer.from_pretrained(init_param_path)
+	
+	return tokenizer
+
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index 8493bf5..84c6320 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -38,52 +38,17 @@
 				)
 	return match_state
 
-def assigment_scope_map(dst_state: dict, src_state: dict, scope_map: str=None):
-	"""Compute the union of the current variables and checkpoint variables."""
-	import collections
-	import re
-
-	# current model variables
-	name_to_variable = collections.OrderedDict()
-	for name, var in dst_state.items():
-		name_to_variable[name] = var
-	
-	scope_map_num = 0
-	if scope_map is not None:
-		scope_map = scope_map.split(",")
-		scope_map_num = len(scope_map) // 2
-		for scope_map_idx in range(scope_map_num):
-			scope_map_id = scope_map_idx * 2
-			logging.info('assignment_map from scope {} to {}'.format(scope_map[scope_map_id], scope_map[scope_map_id+1]))
-	
-	assignment_map = {}
-	for name, var in src_state.items():
-
-		if scope_map:
-			for scope_map_idx in range(scope_map_num):
-				scope_map_id = scope_map_idx * 2
-				try:
-					idx = name.index(scope_map[scope_map_id])
-					new_name = scope_map[scope_map_id+1] + name[idx + len(scope_map[scope_map_id]):]
-					if new_name in name_to_variable:
-						assignment_map[name] = var
-				except:
-					continue
-		else:
-			if name in name_to_variable:
-				assignment_map[name] = var
-	
-	return assignment_map
-
 
 def load_pretrained_model(
 	path: str,
 	model: torch.nn.Module,
-	ignore_init_mismatch: bool,
+	ignore_init_mismatch: bool=True,
 	map_location: str = "cpu",
 	oss_bucket=None,
-	scope_map=None,
+	scope_map=[],
 	excludes=None,
+	ignore_mismatch=False,
+	**kwargs,
 ):
 	"""Load a model state and set it to the model.
 
@@ -108,57 +73,39 @@
 	
 	src_state = src_state["model"] if "model" in src_state else src_state
 	
+	if isinstance(scope_map, str):
+		scope_map = scope_map.split(",")
+	scope_map += ["module.", "None"]
+	
 	for k in dst_state.keys():
-		if not k.startswith("module.") and "module." + k in src_state.keys():
-			k_ddp = "module." + k
+		
+		k_src = k
+
+		if scope_map is not None:
+			src_prefix = ""
+			dst_prefix = ""
+			for i in range(0, len(scope_map), 2):
+				src_prefix = scope_map[i] if scope_map[i].lower() != "none" else ""
+				dst_prefix = scope_map[i+1] if scope_map[i+1].lower() != "none" else ""
+				
+				if dst_prefix == "" and (src_prefix + k) in src_state.keys():
+					k_src = src_prefix + k
+					if not k_src.startswith("module."):
+						print(f"init param, map: {k} from {k_src} in ckpt")
+				elif k.startswith(dst_prefix) and k.replace(dst_prefix, src_prefix, 1) in src_state.keys():
+					k_src = k.replace(dst_prefix, src_prefix, 1)
+					if not k_src.startswith("module."):
+						print(f"init param, map: {k} from {k_src} in ckpt")
+					
+		if k_src in src_state.keys():
+			if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape:
+				print(f"ignore_mismatch:{ignore_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}")
+			else:
+				dst_state[k] = src_state[k_src]
+
+
 		else:
-			k_ddp = k
-		if k_ddp in src_state:
-			dst_state[k] = src_state[k_ddp]
-		else:
-			print(f"Warning, miss key in ckpt: {k}, mapped: {k_ddp}")
+			print(f"Warning, miss key in ckpt: {k}, mapped: {k_src}")
 			
 	flag = obj.load_state_dict(dst_state, strict=True)
 	# print(flag)
-
-# def load_pretrained_model(
-# 	path: str,
-# 	model: torch.nn.Module,
-# 	ignore_init_mismatch: bool,
-# 	map_location: str = "cpu",
-# 	oss_bucket=None,
-# 	scope_map=None,
-# 	excludes=None,
-# ):
-# 	"""Load a model state and set it to the model.
-#
-# 	Args:
-# 		init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys>
-#
-# 	Examples:
-#
-# 	"""
-#
-# 	obj = model
-#
-# 	if oss_bucket is None:
-# 		src_state = torch.load(path, map_location=map_location)
-# 	else:
-# 		buffer = BytesIO(oss_bucket.get_object(path).read())
-# 		src_state = torch.load(buffer, map_location=map_location)
-# 	src_state = src_state["model"] if "model" in src_state else src_state
-#
-# 	if excludes is not None:
-# 		for e in excludes.split(","):
-# 			src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
-#
-# 	dst_state = obj.state_dict()
-# 	src_state = assigment_scope_map(dst_state, src_state, scope_map)
-#
-# 	if ignore_init_mismatch:
-# 		src_state = filter_state_dict(dst_state, src_state)
-#
-# 	logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
-# 	logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
-# 	dst_state.update(src_state)
-# 	obj.load_state_dict(dst_state, strict=True)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index d175fbe..3b20596 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -5,7 +5,8 @@
 from tqdm import tqdm
 from datetime import datetime
 import torch.distributed as dist
-from contextlib import nullcontext
+from torch.cuda.amp import autocast, GradScaler
+from contextlib import nullcontext, contextmanager
 # from torch.utils.tensorboard import SummaryWriter
 from tensorboardX import SummaryWriter
 from pathlib import Path
@@ -13,6 +14,15 @@
 from funasr.train_utils.device_funcs import to_device
 from funasr.train_utils.recursive_op import recursive_average
 from funasr.train_utils.average_nbest_models import average_checkpoints
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+@contextmanager
+def maybe_autocast(enabled):
+    if enabled:
+        with autocast():
+            yield
+    else:
+        yield
 
 class Trainer:
     """
@@ -36,8 +46,9 @@
                  dataloader_train,
                  dataloader_val,
                  local_rank,
-                 use_ddp=False,
-                 use_fsdp=False,
+                 use_ddp: bool = False,
+                 use_fsdp: bool = False,
+                 use_fp16: bool = False,
                  output_dir: str="./",
                  **kwargs):
         """
@@ -72,6 +83,11 @@
         self.kwargs = kwargs
         self.log_interval = kwargs.get("log_interval", 50)
         self.batch_total = 0
+        self.use_fp16 = use_fp16
+        self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
+        scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
+        scaler = ShardedGradScaler(enabled=use_fp16) if use_ddp else scaler
+        self.scaler = scaler
         
     
         try:
@@ -103,6 +119,8 @@
             'optimizer': self.optim.state_dict(),
             'scheduler': self.scheduler.state_dict(),
         }
+        if self.scaler:
+            state["scaler_state"] = self.scaler.state_dict()
         # Create output directory if it does not exist
         os.makedirs(self.output_dir, exist_ok=True)
         filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
@@ -141,6 +159,8 @@
             self.model.load_state_dict(dst_state)
             self.optim.load_state_dict(checkpoint['optimizer'])
             self.scheduler.load_state_dict(checkpoint['scheduler'])
+            if self.scaler and 'scaler_state' in checkpoint:
+                self.scaler.load_state_dict(checkpoint['scaler_state'])
             print(f"Checkpoint loaded successfully from '{ckpt}'")
         else:
             print(f"No checkpoint found at '{ckpt}', starting from scratch")
@@ -221,9 +241,10 @@
             my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
             with my_context():
                 time2 = time.perf_counter()
-
-                retval = self.model(**batch)
-                torch.cuda.empty_cache()
+                with maybe_autocast(self.use_fp16):
+                    retval = self.model(**batch)
+                    
+                if self.disable_gpu_cache: torch.cuda.empty_cache()
 
                 time3 = time.perf_counter()
                 speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
@@ -241,7 +262,10 @@
                     loss *= self.world_size
                 # Scale the loss since we're not updating for every mini-batch
                 loss = loss / accum_grad
-                loss.backward()
+                if self.use_fp16:
+                    self.scaler.scale(loss).backward()
+                else:
+                    loss.backward()
                 time4 = time.perf_counter()
                 speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
             
@@ -264,10 +288,14 @@
                 # Execute an optimization step (update model parameters)
                 if self.use_ddp or self.use_fsdp:
                     dist.barrier()
-                self.optim.step()
+                if self.use_fp16:
+                    self.scaler.step(self.optim)
+                    self.scaler.update()
+                else:
+                    self.optim.step()
                 self.scheduler.step()
                 # Clear gradients for the next accumulation stage
-                self.optim.zero_grad()
+                self.optim.zero_grad(set_to_none=True)
                 total_time = f"{time.perf_counter() - time5:0.3f}"
                 time5 = time.perf_counter()
                 speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
diff --git a/runtime/docs/SDK_advanced_guide_online_zh.md b/runtime/docs/SDK_advanced_guide_online_zh.md
index ac711a8..713f9bd 100644
--- a/runtime/docs/SDK_advanced_guide_online_zh.md
+++ b/runtime/docs/SDK_advanced_guide_online_zh.md
@@ -3,7 +3,7 @@
 
 [//]: # (FunASR鎻愪緵鍙究鎹锋湰鍦版垨鑰呬簯绔湇鍔″櫒閮ㄧ讲鐨勫疄鏃惰闊冲惉鍐欐湇鍔★紝鍐呮牳涓篎unASR宸插紑婧恟untime-SDK銆�)
 [//]: # (闆嗘垚浜嗚揪鎽╅櫌璇煶瀹為獙瀹ゅ湪Modelscope绀惧尯寮�婧愮殑璇煶绔偣妫�娴�&#40;VAD&#41;銆丳araformer-large闈炴祦寮忚闊宠瘑鍒�&#40;ASR&#41;銆丳araformer-large娴佸紡璇煶璇嗗埆&#40;ASR&#41;銆佹爣鐐�&#40;PUNC&#41; 绛夌浉鍏宠兘鍔涖�傝蒋浠跺寘鏃㈠彲浠ュ疄鏃跺湴杩涜璇煶杞枃瀛楋紝鑰屼笖鑳藉鍦ㄨ璇濆彞灏剧敤楂樼簿搴︾殑杞啓鏂囧瓧淇杈撳嚭锛岃緭鍑烘枃瀛楀甫鏈夋爣鐐癸紝鏀寔楂樺苟鍙戝璺姹�)
-FunASR瀹炴椂璇煶鍚啓杞欢鍖咃紝闆嗘垚浜嗗疄鏃剁増鏈殑璇煶绔偣妫�娴嬫ā鍨嬨�佽闊宠瘑鍒�佽闊宠瘑鍒�佹爣鐐归娴嬫ā鍨嬬瓑銆傞噰鐢ㄥ妯″瀷鍗忓悓锛屾棦鍙互瀹炴椂鐨勮繘琛岃闊宠浆鏂囧瓧锛屼篃鍙互鍦ㄨ璇濆彞灏剧敤楂樼簿搴﹁浆鍐欐枃瀛椾慨姝h緭鍑猴紝杈撳嚭鏂囧瓧甯︽湁鏍囩偣锛屾敮鎸佸璺姹傘�備緷鎹娇鐢ㄨ�呭満鏅笉鍚岋紝鏀寔瀹炴椂璇煶鍚啓鏈嶅姟锛坥nline锛夈�侀潪瀹炴椂涓�鍙ヨ瘽杞啓锛坥ffline锛変笌瀹炴椂涓庨潪瀹炴椂涓�浣撳寲鍗忓悓锛�2pass锛�3绉嶆湇鍔℃ā寮忋�傝蒋浠跺寘鎻愪緵鏈塰tml銆乸ython銆乧++銆乯ava涓巆#绛夊绉嶇紪绋嬭瑷�瀹㈡埛绔紝鐢ㄦ埛鍙互鐩存帴浣跨敤涓庤繘涓�姝ュ紑鍙戙��
+FunASR瀹炴椂璇煶鍚啓杞欢鍖咃紝闆嗘垚浜嗗疄鏃剁増鏈殑璇煶绔偣妫�娴嬫ā鍨嬨�佽闊宠瘑鍒�佹爣鐐归娴嬫ā鍨嬬瓑銆傞噰鐢ㄥ妯″瀷鍗忓悓锛屾棦鍙互瀹炴椂鐨勮繘琛岃闊宠浆鏂囧瓧锛屼篃鍙互鍦ㄨ璇濆彞灏剧敤楂樼簿搴﹁浆鍐欐枃瀛椾慨姝h緭鍑猴紝杈撳嚭鏂囧瓧甯︽湁鏍囩偣锛屾敮鎸佸璺姹傘�備緷鎹娇鐢ㄨ�呭満鏅笉鍚岋紝鏀寔瀹炴椂璇煶鍚啓鏈嶅姟锛坥nline锛夈�侀潪瀹炴椂涓�鍙ヨ瘽杞啓锛坥ffline锛変笌瀹炴椂涓庨潪瀹炴椂涓�浣撳寲鍗忓悓锛�2pass锛�3绉嶆湇鍔℃ā寮忋�傝蒋浠跺寘鎻愪緵鏈塰tml銆乸ython銆乧++銆乯ava涓巆#绛夊绉嶇紪绋嬭瑷�瀹㈡埛绔紝鐢ㄦ埛鍙互鐩存帴浣跨敤涓庤繘涓�姝ュ紑鍙戙��
 
 
 鏈枃妗d负FunASR瀹炴椂杞啓鏈嶅姟寮�鍙戞寚鍗椼�傚鏋滄偍鎯冲揩閫熶綋楠屽疄鏃惰闊冲惉鍐欐湇鍔★紝鍙弬鑰僛蹇�熶笂鎵媇(#蹇�熶笂鎵�)銆�
diff --git a/runtime/docs/docker_offline_cpu_en_lists b/runtime/docs/docker_offline_cpu_en_lists
index 8361fce..9212110 100644
--- a/runtime/docs/docker_offline_cpu_en_lists
+++ b/runtime/docs/docker_offline_cpu_en_lists
@@ -1,4 +1,5 @@
 DOCKER:
+  funasr-runtime-sdk-en-cpu-0.1.4
   funasr-runtime-sdk-en-cpu-0.1.3
   funasr-runtime-sdk-en-cpu-0.1.2
 DEFAULT_ASR_MODEL:
diff --git a/runtime/docs/docker_offline_cpu_zh_lists b/runtime/docs/docker_offline_cpu_zh_lists
index 520da51..5c0578f 100644
--- a/runtime/docs/docker_offline_cpu_zh_lists
+++ b/runtime/docs/docker_offline_cpu_zh_lists
@@ -1,5 +1,5 @@
 DOCKER:
-  funasr-runtime-sdk-cpu-0.4.2
+  funasr-runtime-sdk-cpu-0.4.3
   funasr-runtime-sdk-cpu-0.3.0
   funasr-runtime-sdk-cpu-0.2.2
 DEFAULT_ASR_MODEL:
diff --git a/runtime/docs/docker_online_cpu_zh_lists b/runtime/docs/docker_online_cpu_zh_lists
index eb6f1d3..49743ea 100644
--- a/runtime/docs/docker_online_cpu_zh_lists
+++ b/runtime/docs/docker_online_cpu_zh_lists
@@ -1,7 +1,7 @@
 DOCKER:
+  funasr-runtime-sdk-online-cpu-0.1.8
   funasr-runtime-sdk-online-cpu-0.1.7
   funasr-runtime-sdk-online-cpu-0.1.6
-  funasr-runtime-sdk-online-cpu-0.1.5
 DEFAULT_ASR_MODEL:
   damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
   damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx
diff --git a/runtime/html5/static/index.html b/runtime/html5/static/index.html
index d98c62b..de8139e 100644
--- a/runtime/html5/static/index.html
+++ b/runtime/html5/static/index.html
@@ -51,6 +51,12 @@
 
 				</div>
 				<br>
+				<div id="use_itn_div" style="border:2px solid #ccc;display:block;">
+					閫嗘枃鏈爣鍑嗗寲(ITN):<br/>
+					<label><input name="use_itn" type="radio" value="false" checked="true"/>鍚� </label>&nbsp;&nbsp;
+					<label><input name="use_itn" type="radio" value="true" />鏄� </label>
+			   </div>
+			   <br>
 		        <div  style="border:2px solid #ccc;">
 					鐑瘝璁剧疆(涓�琛屼竴涓叧閿瓧锛岀┖鏍奸殧寮�鏉冮噸,濡�"闃块噷宸村反 20")锛�
 					<br>
diff --git a/runtime/html5/static/main.js b/runtime/html5/static/main.js
index b3661cd..9a5a875 100644
--- a/runtime/html5/static/main.js
+++ b/runtime/html5/static/main.js
@@ -563,4 +563,14 @@
  
 		
 	}
+}
+
+function getUseITN() {
+	var obj = document.getElementsByName("use_itn");
+	for (var i = 0; i < obj.length; i++) {
+		if (obj[i].checked) {
+			return obj[i].value === "true";
+		}
+	}
+	return false;
 }
\ No newline at end of file
diff --git a/runtime/html5/static/wsconnecter.js b/runtime/html5/static/wsconnecter.js
index 30b99d4..db140ef 100644
--- a/runtime/html5/static/wsconnecter.js
+++ b/runtime/html5/static/wsconnecter.js
@@ -71,7 +71,7 @@
 			"wav_name":  "h5",
 			"is_speaking":  true,
 			"chunk_interval":10,
-			"itn":false,
+			"itn":getUseITN(),
 			"mode":getAsrMode(),
 			
 		};
diff --git a/runtime/python/websocket/funasr_wss_server.py b/runtime/python/websocket/funasr_wss_server.py
index 37ca6a9..015d87b 100644
--- a/runtime/python/websocket/funasr_wss_server.py
+++ b/runtime/python/websocket/funasr_wss_server.py
@@ -180,8 +180,8 @@
 					websocket.wav_name = messagejson.get("wav_name")
 				if "chunk_size" in messagejson:
 					chunk_size = messagejson["chunk_size"]
-			                if isinstance(chunk_size, str):
-			                    chunk_size = chunk_size.split(',')
+					if isinstance(chunk_size, str):
+						chunk_size = chunk_size.split(',')
 					websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
 				if "encoder_chunk_look_back" in messagejson:
 					websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
diff --git a/setup.py b/setup.py
index f703bb4..4e76c80 100644
--- a/setup.py
+++ b/setup.py
@@ -40,11 +40,11 @@
         "umap_learn",
         "jaconv",
         "hydra-core>=1.3.2",
+        "tensorboardX",
     ],
     # train: The modules invoked when training only.
     "train": [
         "editdistance",
-        "tensorboardX",
     ],
     # all: The modules should be optionally installled due to some reason.
     #      Please consider moving them to "install" occasionally

--
Gitblit v1.9.1