From b1c186fd00fef54bcad3aa1d073a1a313642d641 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Wed, 8 May 2024 00:31:29 +0800
Subject: [PATCH] Dev gzf exp (#1700)

---
 funasr/bin/inference.py                                         |    3 -
 funasr/datasets/sense_voice_datasets/datasets.py                |    2 
 funasr/models/sense_voice/model.py                              |   52 ++++++++++++++++-
 funasr/utils/misc.py                                            |   14 ++++
 funasr/bin/export.py                                            |    3 -
 funasr/models/sense_voice/decoder.py                            |    5 +
 funasr/models/sense_voice/search.py                             |    2 
 examples/industrial_data_pretraining/paraformer/demo.py         |   10 --
 funasr/auto/auto_model.py                                       |   29 +++++----
 examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py |    4 -
 funasr/models/sense_voice/whisper_lib/decoding.py               |    4 +
 11 files changed, 91 insertions(+), 37 deletions(-)

diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
index 0f30a37..21ce0cb 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -9,11 +9,9 @@
 
 model = AutoModel(model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch")
 
-mm = model.model
-for p in mm.parameters():
-    print(f"{p.numel()}")
 res = model.generate(input=wav_file)
 print(res)
+
 # [[beg1, end1], [beg2, end2], .., [begN, endN]]
 # beg/end: ms
 
diff --git a/examples/industrial_data_pretraining/paraformer/demo.py b/examples/industrial_data_pretraining/paraformer/demo.py
index f6f4c75..eb7e72f 100644
--- a/examples/industrial_data_pretraining/paraformer/demo.py
+++ b/examples/industrial_data_pretraining/paraformer/demo.py
@@ -14,14 +14,8 @@
 )
 
 res = model.generate(
-    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
-)
-res = model.generate(
-    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
-)
-
-res = model.generate(
-    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
+    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+    cache={},
 )
 
 print(res)
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 32fd560..577c328 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -26,6 +26,7 @@
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
 from funasr.utils import export_utils
+from funasr.utils import misc
 
 try:
     from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
@@ -35,14 +36,7 @@
 
 
 def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
-    """
-
-    :param input:
-    :param input_len:
-    :param data_type:
-    :param frontend:
-    :return:
-    """
+    """ """
     data_list = []
     key_list = []
     filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
@@ -73,7 +67,8 @@
                     key_list.append(key)
         else:
             if key is None:
-                key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+                # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+                key = misc.extract_filename_without_extension(data_in)
             data_list = [data_in]
             key_list = [key]
     elif isinstance(data_in, (list, tuple)):
@@ -90,10 +85,14 @@
         else:
             # [audio sample point, fbank, text]
             data_list = data_in
-            key_list = [
-                "rand_key_" + "".join(random.choice(chars) for _ in range(13))
-                for _ in range(len(data_in))
-            ]
+            key_list = []
+            for data_i in data_in:
+                if isinstance(data_i, str) and os.path.exists(data_i):
+                    key = misc.extract_filename_without_extension(data_i)
+                else:
+                    key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+                key_list.append(key)
+
     else:  # raw text; audio sample point, fbank; bytes
         if isinstance(data_in, bytes):  # audio bytes
             data_in = load_bytes(data_in)
@@ -108,6 +107,10 @@
 class AutoModel:
 
     def __init__(self, **kwargs):
+
+        log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+        logging.basicConfig(level=log_level)
+
         if not kwargs.get("disable_log", True):
             tables.print()
 
diff --git a/funasr/bin/export.py b/funasr/bin/export.py
index 6c9b49f..9d01401 100644
--- a/funasr/bin/export.py
+++ b/funasr/bin/export.py
@@ -17,9 +17,6 @@
             return cfg_item
 
     kwargs = to_plain_list(cfg)
-    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
-
-    logging.basicConfig(level=log_level)
 
     if kwargs.get("debug", False):
         import pdb
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index 2a1b6aa..39ee5c0 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -16,9 +16,6 @@
             return cfg_item
 
     kwargs = to_plain_list(cfg)
-    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
-
-    logging.basicConfig(level=log_level)
 
     if kwargs.get("debug", False):
         import pdb
diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py
index 5d80956..ee2f13d 100644
--- a/funasr/datasets/sense_voice_datasets/datasets.py
+++ b/funasr/datasets/sense_voice_datasets/datasets.py
@@ -112,7 +112,7 @@
 
             eos = self.tokenizer.encode(self.eos, allowed_special="all")  # [eos]
 
-            ids = prompt_ids + target_ids + eos
+            ids = prompt_ids + target_ids + eos  # [sos, task, lid, text, eos]
             ids_lengths = len(ids)
 
             text = torch.tensor(ids, dtype=torch.int64)
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index dd00ca8..03b7532 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -472,7 +472,7 @@
         is_pad_mask = kwargs.get("is_pad_mask", False)
         is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
 
-        fsmn_cache = cache[layer]["fsmn_cache"] if len(cache) > 0 or cache is None else None
+        fsmn_cache = cache[layer]["fsmn_cache"] if cache is not None and len(cache) > 0 else None
         # if fsmn_cache is not None:
         #     x = x[:, -1:]
         att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache)
@@ -599,5 +599,6 @@
     def score(self, ys, state, x):
         """Score."""
         ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
-        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
+        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=None)
+        logp = torch.log_softmax(logp, dim=-1)
         return logp.squeeze(0)[-1, :], state
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index d5e4130..0230638 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -378,14 +378,19 @@
         stats = {}
 
         # 1. Forward decoder
+        # ys_pad: [sos, task, lid, text, eos]
         decoder_out = self.model.decoder(
             x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
         )
 
         # 2. Compute attention loss
-        mask = torch.ones_like(ys_pad) * (-1)
-        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
-        ys_pad_mask[ys_pad_mask == 0] = -1
+        mask = torch.ones_like(ys_pad) * (-1)  # [sos, task, lid, text, eos]: [-1, -1, -1, -1]
+        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(
+            torch.int64
+        )  # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1] + [-1, -1, 0, 0, 0]
+        ys_pad_mask[ys_pad_mask == 0] = -1  # [-1, -1, lid, text, eos]
+        # decoder_out: [sos, task, lid, text]
+        # ys_pad_mask: [-1, lid, text, eos]
         loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
 
         with torch.no_grad():
@@ -797,6 +802,16 @@
                 data_type=kwargs.get("data_type", "sound"),
                 tokenizer=tokenizer,
             )
+
+            if (
+                isinstance(kwargs.get("data_type", None), (list, tuple))
+                and len(kwargs.get("data_type", [])) > 1
+            ):
+                audio_sample_list, text_token_int_list = audio_sample_list
+                text_token_int = text_token_int_list[0]
+            else:
+                text_token_int = None
+
             time2 = time.perf_counter()
             meta_data["load_data"] = f"{time2 - time1:0.3f}"
             speech, speech_lengths = extract_fbank(
@@ -832,6 +847,37 @@
             speech[None, :, :].permute(0, 2, 1), speech_lengths
         )
 
+        if text_token_int is not None:
+            i = 0
+            results = []
+            ibest_writer = None
+            if kwargs.get("output_dir") is not None:
+                if not hasattr(self, "writer"):
+                    self.writer = DatadirWriter(kwargs.get("output_dir"))
+                ibest_writer = self.writer[f"1best_recog"]
+
+            # 1. Forward decoder
+            ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
+                None, :
+            ]
+            ys_pad_lens = torch.tensor([len(sos_int + text_token_int)], dtype=torch.int64).to(
+                kwargs["device"]
+            )[None, :]
+            decoder_out = self.model.decoder(
+                x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
+            )
+
+            token_int = decoder_out.argmax(-1)[0, :].tolist()
+            text = tokenizer.decode(token_int)
+
+            result_i = {"key": key[i], "text": text}
+            results.append(result_i)
+
+            if ibest_writer is not None:
+                # ibest_writer["token"][key[i]] = " ".join(token)
+                ibest_writer["text"][key[i]] = text
+            return results, meta_data
+
         # c. Passed the encoder result and the beam search
         nbest_hyps = self.beam_search(
             x=encoder_out[0],
diff --git a/funasr/models/sense_voice/search.py b/funasr/models/sense_voice/search.py
index 98d02db..694e569 100644
--- a/funasr/models/sense_voice/search.py
+++ b/funasr/models/sense_voice/search.py
@@ -370,6 +370,8 @@
             # post process of one iteration
             running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
             # end detection
+            # if len(ended_hyps) > 0:
+            #     print(f"ended_hyps: {ended_hyps}")
             if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
                 logging.info(f"end detected at {i}")
                 break
diff --git a/funasr/models/sense_voice/whisper_lib/decoding.py b/funasr/models/sense_voice/whisper_lib/decoding.py
index 382a180..609d6a6 100644
--- a/funasr/models/sense_voice/whisper_lib/decoding.py
+++ b/funasr/models/sense_voice/whisper_lib/decoding.py
@@ -62,8 +62,10 @@
 
     else:
         x = x.to(mel.device)
+    # FIX(funasr): sense voice
+    # logits = model.logits(x[:, :-1], mel)[:, -1]
+    logits = model.logits(x[:, :], mel)[:, -1]
 
-    logits = model.logits(x[:, :-1], mel)[:, -1]
     # collect detected languages; suppress all non-language tokens
     mask = torch.ones(logits.shape[-1], dtype=torch.bool)
     mask[list(tokenizer.all_language_tokens)] = False
diff --git a/funasr/utils/misc.py b/funasr/utils/misc.py
index 5eaa4f8..9f01955 100644
--- a/funasr/utils/misc.py
+++ b/funasr/utils/misc.py
@@ -78,3 +78,17 @@
     #     config_json = os.path.join(model_path, "configuration.json")
     #     if os.path.exists(config_json):
     #         shutil.copy(config_json, os.path.join(kwargs.get("output_dir", "./"), "configuration.json"))
+
+
+def extract_filename_without_extension(file_path):
+    """
+    Extract the file name from the given file path (without the directory path and extension).
+    :param file_path: the full file path
+    :return: the file name (without the directory path and extension)
+    """
+    # First, use os.path.basename to get the file-name part of the path (including the extension)
+    filename_with_extension = os.path.basename(file_path)
+    # Then, use os.path.splitext to separate the file name from the extension
+    filename, extension = os.path.splitext(filename_with_extension)
+    # Return the file name without the extension
+    return filename

--
Gitblit v1.9.1