From 3e333c0abf31825e84d9673faf5e77601ced1112 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期四, 16 三月 2023 16:49:03 +0800
Subject: [PATCH] space between tokens
---
funasr/models_transducer/espnet_transducer_model_unified.py | 4 ++--
funasr/tasks/asr_transducer.py | 7 +++++++
funasr/models_transducer/error_calculator.py | 1 -
3 files changed, 9 insertions(+), 3 deletions(-)
diff --git a/funasr/models_transducer/error_calculator.py b/funasr/models_transducer/error_calculator.py
index 17dbf36..34b1dc7 100644
--- a/funasr/models_transducer/error_calculator.py
+++ b/funasr/models_transducer/error_calculator.py
@@ -137,7 +137,6 @@
for i, char_pred_i in enumerate(char_pred):
pred = char_pred_i.replace(" ", "")
target = char_target[i].replace(" ", "")
-
distances.append(editdistance.eval(pred, target))
lens.append(len(target))
diff --git a/funasr/models_transducer/espnet_transducer_model_unified.py b/funasr/models_transducer/espnet_transducer_model_unified.py
index efe3f4e..6df86f8 100644
--- a/funasr/models_transducer/espnet_transducer_model_unified.py
+++ b/funasr/models_transducer/espnet_transducer_model_unified.py
@@ -455,7 +455,8 @@
gather=True,
)
- if not self.training and (self.report_cer or self.report_wer):
+ #if not self.training and (self.report_cer or self.report_wer):
+ if self.report_cer or self.report_wer:
if self.error_calculator is None:
self.error_calculator = ErrorCalculator(
self.decoder,
@@ -468,7 +469,6 @@
)
cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
-
return loss_transducer, cer_transducer, wer_transducer
return loss_transducer, None, None
diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py
index 3c7a782..be14455 100644
--- a/funasr/tasks/asr_transducer.py
+++ b/funasr/tasks/asr_transducer.py
@@ -138,6 +138,12 @@
help="Integer-string mapper for tokens.",
)
group.add_argument(
+ "--split_with_space",
+ type=str2bool,
+ default=True,
+ help="whether to split text using <space>",
+ )
+ group.add_argument(
"--input_size",
type=int_or_none,
default=None,
@@ -289,6 +295,7 @@
non_linguistic_symbols=args.non_linguistic_symbols,
text_cleaner=args.cleaner,
g2p_type=args.g2p,
+ split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
rir_apply_prob=args.rir_apply_prob
if hasattr(args, "rir_apply_prob")
--
Gitblit v1.9.1