From 3e333c0abf31825e84d9673faf5e77601ced1112 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: Thu, 16 Mar 2023 16:49:03 +0800
Subject: [PATCH] Add --split_with_space option (space between tokens); report CER/WER during training

---
 funasr/models_transducer/espnet_transducer_model_unified.py |    4 ++--
 funasr/tasks/asr_transducer.py                              |    7 +++++++
 funasr/models_transducer/error_calculator.py                |    1 -
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/funasr/models_transducer/error_calculator.py b/funasr/models_transducer/error_calculator.py
index 17dbf36..34b1dc7 100644
--- a/funasr/models_transducer/error_calculator.py
+++ b/funasr/models_transducer/error_calculator.py
@@ -137,7 +137,6 @@
         for i, char_pred_i in enumerate(char_pred):
             pred = char_pred_i.replace(" ", "")
             target = char_target[i].replace(" ", "")
-
             distances.append(editdistance.eval(pred, target))
             lens.append(len(target))
 
diff --git a/funasr/models_transducer/espnet_transducer_model_unified.py b/funasr/models_transducer/espnet_transducer_model_unified.py
index efe3f4e..6df86f8 100644
--- a/funasr/models_transducer/espnet_transducer_model_unified.py
+++ b/funasr/models_transducer/espnet_transducer_model_unified.py
@@ -455,7 +455,8 @@
                 gather=True,
         )
 
-        if not self.training and (self.report_cer or self.report_wer):
+        #if not self.training and (self.report_cer or self.report_wer):
+        if self.report_cer or self.report_wer:
             if self.error_calculator is None:
                 self.error_calculator = ErrorCalculator(
                     self.decoder,
@@ -468,7 +469,6 @@
                 )
 
             cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
-
             return loss_transducer, cer_transducer, wer_transducer
 
         return loss_transducer, None, None
diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py
index 3c7a782..be14455 100644
--- a/funasr/tasks/asr_transducer.py
+++ b/funasr/tasks/asr_transducer.py
@@ -138,6 +138,12 @@
             help="Integer-string mapper for tokens.",
         )
         group.add_argument(
+            "--split_with_space",
+            type=str2bool,
+            default=True,
+            help="whether to split text using <space>",
+        )
+        group.add_argument(
             "--input_size",
             type=int_or_none,
             default=None,
@@ -289,6 +295,7 @@
                 non_linguistic_symbols=args.non_linguistic_symbols,
                 text_cleaner=args.cleaner,
                 g2p_type=args.g2p,
+                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
                 rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
                 rir_apply_prob=args.rir_apply_prob
                 if hasattr(args, "rir_apply_prob")

--
Gitblit v1.9.1