From 3a101795429659be9fb540f31317dfe14e362045 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期二, 29 十月 2024 11:40:27 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main

---
 funasr/models/fsmn_vad_streaming/model.py                      |   46 +++++++++++-----------
 fun_text_processing/inverse_text_normalization/run_evaluate.py |   33 +++++++---------
 2 files changed, 37 insertions(+), 42 deletions(-)

diff --git a/fun_text_processing/inverse_text_normalization/run_evaluate.py b/fun_text_processing/inverse_text_normalization/run_evaluate.py
index 76e6e3c..bea92fa 100644
--- a/fun_text_processing/inverse_text_normalization/run_evaluate.py
+++ b/fun_text_processing/inverse_text_normalization/run_evaluate.py
@@ -9,16 +9,14 @@
     training_data_to_tokens,
 )
 
-
 """
 Runs Evaluation on data in the format of : <semiotic class>\t<unnormalized text>\t<`self` if trivial class or normalized text>
 like the Google text normalization data https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish
 """
 
-
 def parse_args():
     parser = ArgumentParser()
-    parser.add_argument("--input", help="input file path", type=str)
+    parser.add_argument("--input", help="input file path", type=str, required=True)
     parser.add_argument(
         "--lang",
         help="language",
@@ -39,15 +37,13 @@
     )
     return parser.parse_args()
 
-
 if __name__ == "__main__":
     # Example usage:
     # python run_evaluate.py --input=<INPUT> --cat=<CATEGORY> --filter
     args = parse_args()
     if args.lang == "en":
-        from fun_text_processing.inverse_text_normalization.en.clean_eval_data import (
-            filter_loaded_data,
-        )
+        from fun_text_processing.inverse_text_normalization.en.clean_eval_data import filter_loaded_data
+
     file_path = args.input
     inverse_normalizer = InverseNormalizer()
 
@@ -57,6 +53,7 @@
     if args.filter:
         training_data = filter_loaded_data(training_data)
 
+    # Evaluate at sentence level if no specific category is provided
     if args.category is None:
         print("Sentence level evaluation...")
         sentences_un_normalized, sentences_normalized, _ = training_data_to_sentences(training_data)
@@ -68,12 +65,12 @@
         )
         print("- Accuracy: " + str(sentences_accuracy))
 
+    # Evaluate at token level
     print("Token level evaluation...")
     tokens_per_type = training_data_to_tokens(training_data, category=args.category)
     token_accuracy = {}
-    for token_type in tokens_per_type:
+    for token_type, (tokens_un_normalized, tokens_normalized) in tokens_per_type.items():
         print("- Token type: " + token_type)
-        tokens_un_normalized, tokens_normalized = tokens_per_type[token_type]
         print("  - Data: " + str(len(tokens_normalized)) + " tokens")
         tokens_prediction = inverse_normalizer.inverse_normalize_list(tokens_normalized)
         print("  - Denormalized. Evaluating...")
@@ -81,9 +78,9 @@
             tokens_prediction, tokens_un_normalized, input=tokens_normalized
         )
         print("  - Accuracy: " + str(token_accuracy[token_type]))
-    token_count_per_type = {
-        token_type: len(tokens_per_type[token_type][0]) for token_type in tokens_per_type
-    }
+
+    # Calculate weighted token accuracy
+    token_count_per_type = {token_type: len(tokens) for token_type, (tokens, _) in tokens_per_type.items()}
     token_weighted_accuracy = [
         token_count_per_type[token_type] * accuracy
         for token_type, accuracy in token_accuracy.items()
@@ -96,19 +93,17 @@
         if token_type not in known_types:
             raise ValueError("Unexpected token type: " + token_type)
 
+    # Output table summarizing evaluation results if no specific category is provided
     if args.category is None:
         c1 = ["Class", "sent level"] + known_types
         c2 = ["Num Tokens", len(sentences_normalized)] + [
-            token_count_per_type[known_type] if known_type in tokens_per_type else "0"
-            for known_type in known_types
+            str(token_count_per_type.get(known_type, 0)) for known_type in known_types
         ]
-        c3 = ["Denormalization", sentences_accuracy] + [
-            token_accuracy[known_type] if known_type in token_accuracy else "0"
-            for known_type in known_types
+        c3 = ["Denormalization", str(sentences_accuracy)] + [
+            str(token_accuracy.get(known_type, "0")) for known_type in known_types
         ]
-
         for i in range(len(c1)):
-            print(f"{str(c1[i]):10s} | {str(c2[i]):10s} | {str(c3[i]):5s}")
+            print(f"{c1[i]:10s} | {c2[i]:10s} | {c3[i]:5s}")
     else:
         print(f"numbers\t{token_count_per_type[args.category]}")
         print(f"Denormalization\t{token_accuracy[args.category]}")
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 04689be..bfffca8 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -8,6 +8,7 @@
 import time
 import math
 import torch
+import numpy as np
 from torch import nn
 from enum import Enum
 from dataclasses import dataclass
@@ -334,18 +335,17 @@
             cache["stats"].data_buf_all = torch.cat(
                 (cache["stats"].data_buf_all, cache["stats"].waveform[0])
             )
-        for offset in range(
-            0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length
-        ):
-            cache["stats"].decibel.append(
-                10
-                * math.log10(
-                    (cache["stats"].waveform[0][offset : offset + frame_sample_length])
-                    .square()
-                    .sum()
-                    + 0.000001
-                )
-            )
+            
+        waveform_numpy = cache["stats"].waveform.numpy()
+
+        offsets = np.arange(0, waveform_numpy.shape[1] - frame_sample_length + 1, frame_shift_length)
+        frames = waveform_numpy[0, offsets[:, np.newaxis] + np.arange(frame_sample_length)]
+
+        decibel_numpy = 10 * np.log10(np.sum(np.square(frames), axis=1) + 0.000001)
+        decibel_numpy = decibel_numpy.tolist()
+
+        cache["stats"].decibel.extend(decibel_numpy)
+
 
     def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
         scores = self.encoder(feats, cache=cache["encoder"]).to("cpu")  # return B * T * D
@@ -406,7 +406,6 @@
         cur_seg = cache["stats"].output_data_buf[-1]
         if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
             print("warning\n")
-        out_pos = len(cur_seg.buffer)  # cur_seg.buff鐜板湪娌″仛浠讳綍鎿嶄綔
         data_to_pop = 0
         if end_point_is_sent_end:
             data_to_pop = expected_sample_number
@@ -420,12 +419,6 @@
             expected_sample_number = len(cache["stats"].data_buf)
 
         cur_seg.doa = 0
-        for sample_cpy_out in range(0, data_to_pop):
-            # cur_seg.buffer[out_pos ++] = data_buf_.back();
-            out_pos += 1
-        for sample_cpy_out in range(data_to_pop, expected_sample_number):
-            # cur_seg.buffer[out_pos++] = data_buf_.back()
-            out_pos += 1
         if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
             print("Something wrong with the VAD algorithm\n")
         cache["stats"].data_buf_start_frame += frm_cnt
@@ -512,10 +505,17 @@
         assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
         if len(cache["stats"].sil_pdf_ids) > 0:
             assert len(cache["stats"].scores) == 1  # 鍙敮鎸乥atch_size = 1鐨勬祴璇�
-            sil_pdf_scores = [
-                cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids
-            ]
-            sum_score = sum(sil_pdf_scores)
+            """
+            - Change type of `sum_score` to float. The reason is that `sum_score` is a tensor with single element.
+              and `torch.Tensor` is slower `float` when tensor has only one element.
+            - Put the iteration of `sil_pdf_ids` inside `sum()` to reduce the overhead of creating a new list.
+            - The default `sil_pdf_ids` is [0], the `if` statement is used to reduce the overhead of expression
+              generation, which result in a mere (~2%) performance gain.
+            """
+            if len(cache["stats"].sil_pdf_ids) > 1:
+                sum_score = sum(cache["stats"].scores[0][t][sil_pdf_id].item() for sil_pdf_id in cache["stats"].sil_pdf_ids)
+            else:
+                sum_score = cache["stats"].scores[0][t][cache["stats"].sil_pdf_ids[0]].item()
             noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
             total_score = 1.0
             sum_score = total_score - sum_score

--
Gitblit v1.9.1