From 00ea1186f96e6732e2edb4fab6c0ed6896e3b352 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 19 十二月 2023 22:53:18 +0800
Subject: [PATCH] funasr2

---
 funasr/bin/inference.py                                |    3 
 funasr/models/paraformer/model.py                      |    2 
 funasr/models/paraformer/template.yaml                 |    5 
 funasr/bin/train.py                                    |    3 
 funasr/utils/register.py                               |    4 
 funasr/models/conformer/template.yaml                  |  117 +++++++++++++++++++++++
 funasr/models/transformer/template.yaml                |  111 ++++++++++++++++++++++
 funasr/models/neat_contextual_paraformer/template.yaml |   45 +++++---
 8 files changed, 265 insertions(+), 25 deletions(-)

diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index 09e28f3..fd884cd 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -81,7 +81,7 @@
 
 class AutoModel:
 	def __init__(self, **kwargs):
-		registry_tables.print_register_tables()
+		registry_tables.print()
 		assert "model" in kwargs
 		if "model_conf" not in kwargs:
 			logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
@@ -108,6 +108,7 @@
 			frontend_class = registry_tables.frontend_classes.get(frontend.lower())
 			frontend = frontend_class(**kwargs["frontend_conf"])
 			kwargs["frontend"] = frontend
+			kwargs["input_size"] = frontend.output_size()
 		
 		# build model
 		model_class = registry_tables.model_classes.get(kwargs["model"].lower())
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 72fa9fa..8112002 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -39,7 +39,7 @@
 	# preprocess_config(kwargs)
 	# import pdb; pdb.set_trace()
 	# set random seed
-	registry_tables.print_register_tables()
+	registry_tables.print()
 	set_all_random_seed(kwargs.get("seed", 0))
 	torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
 	torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
@@ -72,6 +72,7 @@
 		frontend_class = registry_tables.frontend_classes.get(frontend.lower())
 		frontend = frontend_class(**kwargs["frontend_conf"])
 		kwargs["frontend"] = frontend
+		kwargs["input_size"] = frontend.output_size()
 	
 	# import pdb;
 	# pdb.set_trace()
diff --git a/funasr/models/conformer/template.yaml b/funasr/models/conformer/template.yaml
new file mode 100644
index 0000000..6094313
--- /dev/null
+++ b/funasr/models/conformer/template.yaml
@@ -0,0 +1,117 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.utils.register import registry_tables
+# registry_tables.print()
+
+# network architecture
+#model: funasr.models.paraformer.model:Paraformer
+model: Transformer
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
+
+# encoder
+encoder: ConformerEncoder
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d # encoder architecture type
+    normalize_before: true
+    pos_enc_layer_type: rel_pos
+    selfattention_layer_type: rel_selfattn
+    activation_type: swish
+    macaron_style: true
+    use_cnn_module: true
+    cnn_module_kernel: 15
+
+# decoder
+decoder: TransformerDecoder
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
+specaug: SpecAug
+specaug_conf:
+    apply_time_warp: true
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    num_freq_mask: 2
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 40
+    num_time_mask: 2
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 150
+  val_scheduler_criterion:
+      - valid
+      - acc
+  best_model_criterion:
+  -   - valid
+      - acc
+      - max
+  keep_nbest_models: 10
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_type: example # example or length
+    batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
+    buffer_size: 500
+    shuffle: True
+    num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+  split_with_space: true
+
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin
+    reduce: true
+    ignore_nan_grad: true
+normalize: null
diff --git a/examples/industrial_data_pretraining/paraformer-large/conf/finetune.yaml b/funasr/models/neat_contextual_paraformer/template.yaml
similarity index 65%
rename from examples/industrial_data_pretraining/paraformer-large/conf/finetune.yaml
rename to funasr/models/neat_contextual_paraformer/template.yaml
index 880aad9..012ecf7 100644
--- a/examples/industrial_data_pretraining/paraformer-large/conf/finetune.yaml
+++ b/funasr/models/neat_contextual_paraformer/template.yaml
@@ -1,6 +1,12 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.utils.register import registry_tables
+# registry_tables.print()
 
 # network architecture
-model: funasr.cli.models.paraformer:Paraformer
+model: NeatContextualParaformer
 model_conf:
     ctc_weight: 0.0
     lsm_weight: 0.1
@@ -8,9 +14,10 @@
     predictor_weight: 1.0
     predictor_bias: 1
     sampling_ratio: 0.75
+    inner_dim: 512
 
 # encoder
-encoder: sanm
+encoder: SANMEncoder
 encoder_conf:
     output_size: 512
     attention_heads: 4
@@ -26,8 +33,9 @@
     sanm_shfit: 0
     selfattention_layer_type: sanm
 
+
 # decoder
-decoder: paraformer_decoder_sanm
+decoder: ContextualParaformerDecoder
 decoder_conf:
     attention_heads: 4
     linear_units: 2048
@@ -40,7 +48,7 @@
     kernel_size: 11
     sanm_shfit: 0
 
-predictor: cif_predictor_v2
+predictor: CifPredictorV2
 predictor_conf:
     idim: 512
     threshold: 1.0
@@ -49,7 +57,7 @@
     tail_threshold: 0.45
 
 # frontend related
-frontend: wav_frontend
+frontend: WavFrontend
 frontend_conf:
     fs: 16000
     window: hamming
@@ -59,7 +67,7 @@
     lfr_m: 7
     lfr_n: 6
 
-specaug: specaug_lfr
+specaug: SpecAugLFR
 specaug_conf:
     apply_time_warp: false
     time_warp_window: 5
@@ -97,21 +105,22 @@
 scheduler_conf:
    warmup_steps: 30000
 
-
+dataset: AudioDataset
 dataset_conf:
-    data_names: speech,text
-    data_types: sound,text
+    index_ds: IndexDSJsonl
+    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_type: example # example or length
+    batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
+    buffer_size: 500
     shuffle: True
-    shuffle_conf:
-        shuffle_size: 2048
-        sort_size: 500
-    batch_conf:
-        batch_type: example
-        batch_size: 2
-    num_workers: 8
+    num_workers: 0
 
-split_with_space: true
-input_size: 560
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+  split_with_space: true
+
 ctc_conf:
     dropout_rate: 0.0
     ctc_type: builtin
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index fad8385..03a0bd2 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -39,8 +39,6 @@
 	def __init__(
 		self,
 		# token_list: Union[Tuple[str, ...], List[str]],
-		frontend: Optional[str] = None,
-		frontend_conf: Optional[Dict] = None,
 		specaug: Optional[str] = None,
 		specaug_conf: Optional[Dict] = None,
 		normalize: str = None,
diff --git a/funasr/models/paraformer/template.yaml b/funasr/models/paraformer/template.yaml
index 1909600..000f641 100644
--- a/funasr/models/paraformer/template.yaml
+++ b/funasr/models/paraformer/template.yaml
@@ -1,6 +1,10 @@
 # This is an example that demonstrates how to configure a model file.
 # You can modify the configuration according to your own requirements.
 
+# to print the register_table:
+# from funasr.utils.register import registry_tables
+# registry_tables.print()
+
 # network architecture
 #model: funasr.models.paraformer.model:Paraformer
 model: Paraformer
@@ -117,7 +121,6 @@
   split_with_space: true
 
 
-input_size: 560
 ctc_conf:
     dropout_rate: 0.0
     ctc_type: builtin
diff --git a/funasr/models/transformer/template.yaml b/funasr/models/transformer/template.yaml
new file mode 100644
index 0000000..798e374
--- /dev/null
+++ b/funasr/models/transformer/template.yaml
@@ -0,0 +1,111 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.utils.register import registry_tables
+# registry_tables.print()
+
+# network architecture
+#model: funasr.models.paraformer.model:Paraformer
+model: Transformer
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
+
+# encoder
+encoder: TransformerEncoder
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d # encoder architecture type
+    normalize_before: true
+
+# decoder
+decoder: TransformerDecoder
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
+specaug: SpecAug
+specaug_conf:
+    apply_time_warp: true
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    num_freq_mask: 2
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 40
+    num_time_mask: 2
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 150
+  val_scheduler_criterion:
+      - valid
+      - acc
+  best_model_criterion:
+  -   - valid
+      - acc
+      - max
+  keep_nbest_models: 10
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.002
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_type: example # example or length
+    batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
+    buffer_size: 500
+    shuffle: True
+    num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+  split_with_space: true
+
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin
+    reduce: true
+    ignore_nan_grad: true
+normalize: null
diff --git a/funasr/utils/register.py b/funasr/utils/register.py
index 0dfcdab..6fe04f7 100644
--- a/funasr/utils/register.py
+++ b/funasr/utils/register.py
@@ -1,6 +1,6 @@
 import logging
 import inspect
-from dataclasses import dataclass, fields
+from dataclasses import dataclass
 
 
 @dataclass
@@ -19,7 +19,7 @@
     dataset_classes = {}
     index_ds_classes = {}
 
-    def print_register_tables(self,):
+    def print(self,):
         print("\nregister_tables: \n")
         fields = vars(self)
         for classes_key, classes_dict in fields.items():

--
Gitblit v1.9.1