From 00ea1186f96e6732e2edb4fab6c0ed6896e3b352 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 19 Dec 2023 22:53:18 +0800
Subject: [PATCH] funasr2
---
funasr/bin/inference.py | 3
funasr/models/paraformer/model.py | 2
funasr/models/paraformer/template.yaml | 5
funasr/bin/train.py | 3
funasr/utils/register.py | 4
funasr/models/conformer/template.yaml | 117 +++++++++++++++++++++++
funasr/models/transformer/template.yaml | 111 ++++++++++++++++++++++
funasr/models/neat_contextual_paraformer/template.yaml | 45 +++++---
8 files changed, 265 insertions(+), 25 deletions(-)
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index 09e28f3..fd884cd 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -81,7 +81,7 @@
class AutoModel:
def __init__(self, **kwargs):
- registry_tables.print_register_tables()
+ registry_tables.print()
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
@@ -108,6 +108,7 @@
frontend_class = registry_tables.frontend_classes.get(frontend.lower())
frontend = frontend_class(**kwargs["frontend_conf"])
kwargs["frontend"] = frontend
+ kwargs["input_size"] = frontend.output_size()
# build model
model_class = registry_tables.model_classes.get(kwargs["model"].lower())
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 72fa9fa..8112002 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -39,7 +39,7 @@
# preprocess_config(kwargs)
# import pdb; pdb.set_trace()
# set random seed
- registry_tables.print_register_tables()
+ registry_tables.print()
set_all_random_seed(kwargs.get("seed", 0))
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
@@ -72,6 +72,7 @@
frontend_class = registry_tables.frontend_classes.get(frontend.lower())
frontend = frontend_class(**kwargs["frontend_conf"])
kwargs["frontend"] = frontend
+ kwargs["input_size"] = frontend.output_size()
# import pdb;
# pdb.set_trace()
diff --git a/funasr/models/conformer/template.yaml b/funasr/models/conformer/template.yaml
new file mode 100644
index 0000000..6094313
--- /dev/null
+++ b/funasr/models/conformer/template.yaml
@@ -0,0 +1,117 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# To print the registry tables:
+# from funasr.utils.register import registry_tables
+# registry_tables.print()
+
+# network architecture
+#model: funasr.models.paraformer.model:Paraformer
+model: Transformer
+model_conf:
+ ctc_weight: 0.3
+ lsm_weight: 0.1 # label smoothing option
+ length_normalized_loss: false
+
+# encoder
+encoder: ConformerEncoder
+encoder_conf:
+ output_size: 256 # dimension of attention
+ attention_heads: 4
+ linear_units: 2048 # the number of units of position-wise feed forward
+ num_blocks: 12 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.0
+ input_layer: conv2d # encoder architecture type
+ normalize_before: true
+ pos_enc_layer_type: rel_pos
+ selfattention_layer_type: rel_selfattn
+ activation_type: swish
+ macaron_style: true
+ use_cnn_module: true
+ cnn_module_kernel: 15
+
+# decoder
+decoder: TransformerDecoder
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ num_blocks: 6
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.0
+ src_attention_dropout_rate: 0.0
+
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
+specaug: SpecAug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 40
+ num_time_mask: 2
+
+train_conf:
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 150
+ val_scheduler_criterion:
+ - valid
+ - acc
+ best_model_criterion:
+ - - valid
+ - acc
+ - max
+ keep_nbest_models: 10
+ log_interval: 50
+
+optim: adam
+optim_conf:
+ lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+ index_ds: IndexDSJsonl
+ batch_sampler: DynamicBatchLocalShuffleSampler
+ batch_type: example # example or length
+ batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    max_token_length: 2048 # filter out samples if source_token_len+target_token_len > max_token_length
+ buffer_size: 500
+ shuffle: True
+ num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+ unk_symbol: <unk>
+ split_with_space: true
+
+
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin
+ reduce: true
+ ignore_nan_grad: true
+normalize: null
diff --git a/examples/industrial_data_pretraining/paraformer-large/conf/finetune.yaml b/funasr/models/neat_contextual_paraformer/template.yaml
similarity index 65%
rename from examples/industrial_data_pretraining/paraformer-large/conf/finetune.yaml
rename to funasr/models/neat_contextual_paraformer/template.yaml
index 880aad9..012ecf7 100644
--- a/examples/industrial_data_pretraining/paraformer-large/conf/finetune.yaml
+++ b/funasr/models/neat_contextual_paraformer/template.yaml
@@ -1,6 +1,12 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# To print the registry tables:
+# from funasr.utils.register import registry_tables
+# registry_tables.print()
# network architecture
-model: funasr.cli.models.paraformer:Paraformer
+model: NeatContextualParaformer
model_conf:
ctc_weight: 0.0
lsm_weight: 0.1
@@ -8,9 +14,10 @@
predictor_weight: 1.0
predictor_bias: 1
sampling_ratio: 0.75
+ inner_dim: 512
# encoder
-encoder: sanm
+encoder: SANMEncoder
encoder_conf:
output_size: 512
attention_heads: 4
@@ -26,8 +33,9 @@
sanm_shfit: 0
selfattention_layer_type: sanm
+
# decoder
-decoder: paraformer_decoder_sanm
+decoder: ContextualParaformerDecoder
decoder_conf:
attention_heads: 4
linear_units: 2048
@@ -40,7 +48,7 @@
kernel_size: 11
sanm_shfit: 0
-predictor: cif_predictor_v2
+predictor: CifPredictorV2
predictor_conf:
idim: 512
threshold: 1.0
@@ -49,7 +57,7 @@
tail_threshold: 0.45
# frontend related
-frontend: wav_frontend
+frontend: WavFrontend
frontend_conf:
fs: 16000
window: hamming
@@ -59,7 +67,7 @@
lfr_m: 7
lfr_n: 6
-specaug: specaug_lfr
+specaug: SpecAugLFR
specaug_conf:
apply_time_warp: false
time_warp_window: 5
@@ -97,21 +105,22 @@
scheduler_conf:
warmup_steps: 30000
-
+dataset: AudioDataset
dataset_conf:
- data_names: speech,text
- data_types: sound,text
+ index_ds: IndexDSJsonl
+ batch_sampler: DynamicBatchLocalShuffleSampler
+ batch_type: example # example or length
+ batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    max_token_length: 2048 # filter out samples if source_token_len+target_token_len > max_token_length
+ buffer_size: 500
shuffle: True
- shuffle_conf:
- shuffle_size: 2048
- sort_size: 500
- batch_conf:
- batch_type: example
- batch_size: 2
- num_workers: 8
+ num_workers: 0
-split_with_space: true
-input_size: 560
+tokenizer: CharTokenizer
+tokenizer_conf:
+ unk_symbol: <unk>
+ split_with_space: true
+
ctc_conf:
dropout_rate: 0.0
ctc_type: builtin
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index fad8385..03a0bd2 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -39,8 +39,6 @@
def __init__(
self,
# token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[str] = None,
- frontend_conf: Optional[Dict] = None,
specaug: Optional[str] = None,
specaug_conf: Optional[Dict] = None,
normalize: str = None,
diff --git a/funasr/models/paraformer/template.yaml b/funasr/models/paraformer/template.yaml
index 1909600..000f641 100644
--- a/funasr/models/paraformer/template.yaml
+++ b/funasr/models/paraformer/template.yaml
@@ -1,6 +1,10 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
+# To print the registry tables:
+# from funasr.utils.register import registry_tables
+# registry_tables.print()
+
# network architecture
#model: funasr.models.paraformer.model:Paraformer
model: Paraformer
@@ -117,7 +121,6 @@
split_with_space: true
-input_size: 560
ctc_conf:
dropout_rate: 0.0
ctc_type: builtin
diff --git a/funasr/models/transformer/template.yaml b/funasr/models/transformer/template.yaml
new file mode 100644
index 0000000..798e374
--- /dev/null
+++ b/funasr/models/transformer/template.yaml
@@ -0,0 +1,111 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# To print the registry tables:
+# from funasr.utils.register import registry_tables
+# registry_tables.print()
+
+# network architecture
+#model: funasr.models.paraformer.model:Paraformer
+model: Transformer
+model_conf:
+ ctc_weight: 0.3
+ lsm_weight: 0.1 # label smoothing option
+ length_normalized_loss: false
+
+# encoder
+encoder: TransformerEncoder
+encoder_conf:
+ output_size: 256 # dimension of attention
+ attention_heads: 4
+ linear_units: 2048 # the number of units of position-wise feed forward
+ num_blocks: 12 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.0
+ input_layer: conv2d # encoder architecture type
+ normalize_before: true
+
+# decoder
+decoder: TransformerDecoder
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ num_blocks: 6
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.0
+ src_attention_dropout_rate: 0.0
+
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
+specaug: SpecAug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 40
+ num_time_mask: 2
+
+train_conf:
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 150
+ val_scheduler_criterion:
+ - valid
+ - acc
+ best_model_criterion:
+ - - valid
+ - acc
+ - max
+ keep_nbest_models: 10
+ log_interval: 50
+
+optim: adam
+optim_conf:
+ lr: 0.002
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+ index_ds: IndexDSJsonl
+ batch_sampler: DynamicBatchLocalShuffleSampler
+ batch_type: example # example or length
+ batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    max_token_length: 2048 # filter out samples if source_token_len+target_token_len > max_token_length
+ buffer_size: 500
+ shuffle: True
+ num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+ unk_symbol: <unk>
+ split_with_space: true
+
+
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin
+ reduce: true
+ ignore_nan_grad: true
+normalize: null
diff --git a/funasr/utils/register.py b/funasr/utils/register.py
index 0dfcdab..6fe04f7 100644
--- a/funasr/utils/register.py
+++ b/funasr/utils/register.py
@@ -1,6 +1,6 @@
import logging
import inspect
-from dataclasses import dataclass, fields
+from dataclasses import dataclass
@dataclass
@@ -19,7 +19,7 @@
dataset_classes = {}
index_ds_classes = {}
- def print_register_tables(self,):
+ def print(self,):
print("\nregister_tables: \n")
fields = vars(self)
for classes_key, classes_dict in fields.items():
--
Gitblit v1.9.1