From 97d648c255316ec1fff5d82e46749076faabdd2d Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: Mon, 15 Jan 2024 15:41:25 +0800
Subject: [PATCH] Optimize code, update model revisions, and refresh example scripts
---
 examples/industrial_data_pretraining/bicif_paraformer/demo.py       |  22 -
 examples/industrial_data_pretraining/bicif_paraformer/infer.sh      |  10
 examples/industrial_data_pretraining/campplus_sv/demo.py            |   4
 examples/industrial_data_pretraining/contextual_paraformer/demo.py  |   2
 examples/industrial_data_pretraining/contextual_paraformer/infer.sh |   2
 examples/industrial_data_pretraining/ct_transformer/demo.py         |   4
 examples/industrial_data_pretraining/ct_transformer/infer.sh        |   4
 examples/industrial_data_pretraining/monotonic_aligner/demo.py      |   2
 examples/industrial_data_pretraining/monotonic_aligner/infer.sh     |   2
 examples/industrial_data_pretraining/paraformer-zh-spk/demo.py      |   6
 examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh     |   6
 examples/industrial_data_pretraining/paraformer/demo.py             |   4
 examples/industrial_data_pretraining/paraformer/infer.sh            |   2
 examples/industrial_data_pretraining/paraformer_streaming/demo.py   |   2
 examples/industrial_data_pretraining/paraformer_streaming/infer.sh  |   2
 examples/industrial_data_pretraining/seaco_paraformer/demo.py       |   4
 examples/industrial_data_pretraining/seaco_paraformer/infer.sh      |   4
 funasr/bin/inference.py                                             |   2
 funasr/models/bat/model.py                                          |  18
 funasr/models/bicif_paraformer/cif_predictor.py                     |  38 +-
 funasr/models/bicif_paraformer/model.py                             |  48 +-
 funasr/models/campplus/cluster_backend.py                           |  19
 funasr/models/campplus/components.py                                |  80 ++--
 funasr/models/campplus/model.py                                     |  41 +-
 funasr/models/campplus/utils.py                                     |   8
 funasr/models/contextual_paraformer/decoder.py                      |  42 +-
 funasr/models/contextual_paraformer/model.py                        |  62 +--
 funasr/models/ct_transformer/model.py                               |  44 +-
 funasr/models/ct_transformer/utils.py                               |   8
 funasr/models/ct_transformer_streaming/attention.py                 |  18
 funasr/models/ct_transformer_streaming/encoder.py                   |  62 +--
 funasr/models/ct_transformer_streaming/model.py                     |  38 +-
 funasr/models/emotion2vec/audio.py                                  |  14
 funasr/models/emotion2vec/model.py                                  |  50 +-
 funasr/models/emotion2vec/modules.py                                |   8
 funasr/models/emotion2vec/timm_modules.py                           |  12
 funasr/models/fsmn_vad_streaming/model.py                           |  29
 funasr/models/monotonic_aligner/model.py                            |  17
 funasr/models/paraformer/cif_predictor.py                           |  32 +
 funasr/models/paraformer/decoder.py                                 |  33 +-
 funasr/models/paraformer/model.py                                   |  45 +-
 funasr/models/paraformer/search.py                                  |  19
 funasr/models/paraformer_streaming/model.py                         |  52 +--
 funasr/models/seaco_paraformer/model.py                             |  38 +-
44 files changed, 472 insertions(+), 487 deletions(-)
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/demo.py b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
index 57edb68..60718de 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
@@ -6,28 +6,14 @@
from funasr import AutoModel
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
- model_revision="v2.0.0",
+ model_revision="v2.0.2",
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.2",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
- punc_model_revision="v2.0.1",
- spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
+ punc_model_revision="v2.0.2",
+ spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
+ spk_model_revision="v2.0.2",
)
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60)
print(res)
-
-'''try asr with speaker label with
-model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
- model_revision="v2.0.0",
- vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
- vad_model_revision="v2.0.2",
- punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
- punc_model_revision="v2.0.1",
- spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
- spk_mode='punc_segment',
- )
-
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_speaker_demo.wav", batch_size_s=300, batch_size_threshold_s=60)
-print(res)
-'''
\ No newline at end of file
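
Note on the comment block removed above: the speaker-label variant it documented can still be run with the hub model id this patch introduces in place of the old local path. A minimal sketch, assuming spk_mode='punc_segment' carries over from the removed example and that the spk_model revision pin matches the other submodels:

    from funasr import AutoModel

    model = AutoModel(
        model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        model_revision="v2.0.2",
        vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
        vad_model_revision="v2.0.2",
        punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
        punc_model_revision="v2.0.2",
        spk_model="damo/speech_campplus_sv_zh-cn_16k-common",  # hub id replacing the old local path
        spk_model_revision="v2.0.2",   # assumption: same pin as the other submodels
        spk_mode="punc_segment",       # carried over from the removed example
    )
    res = model(
        input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_speaker_demo.wav",
        batch_size_s=300,
        batch_size_threshold_s=60,
    )
    print(res)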
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/infer.sh b/examples/industrial_data_pretraining/bicif_paraformer/infer.sh
index 57c5838..09e1c83 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/infer.sh
+++ b/examples/industrial_data_pretraining/bicif_paraformer/infer.sh
@@ -1,10 +1,12 @@
model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-model_revision="v2.0.0"
+model_revision="v2.0.2"
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
-vad_model_revision="v2.0.0"
+vad_model_revision="v2.0.2"
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.1"
+punc_model_revision="v2.0.2"
+spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
+spk_model_revision="v2.0.2"
python funasr/bin/inference.py \
+model=${model} \
@@ -13,6 +15,8 @@
+vad_model_revision=${vad_model_revision} \
+punc_model=${punc_model} \
+punc_model_revision=${punc_model_revision} \
++spk_model=${spk_model} \
++spk_model_revision=${spk_model_revision} \
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav" \
+output_dir="./outputs/debug" \
+device="cpu" \
diff --git a/examples/industrial_data_pretraining/campplus_sv/demo.py b/examples/industrial_data_pretraining/campplus_sv/demo.py
index 0b5588f..6a7f105 100644
--- a/examples/industrial_data_pretraining/campplus_sv/demo.py
+++ b/examples/industrial_data_pretraining/campplus_sv/demo.py
@@ -5,7 +5,9 @@
from funasr import AutoModel
-model = AutoModel(model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common")
+model = AutoModel(model="damo/speech_campplus_sv_zh-cn_16k-common",
+ model_revision="v2.0.2",
+ )
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.py b/examples/industrial_data_pretraining/contextual_paraformer/demo.py
index c705ca8..78693c5 100644
--- a/examples/industrial_data_pretraining/contextual_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/contextual_paraformer/demo.py
@@ -5,7 +5,7 @@
from funasr import AutoModel
-model = AutoModel(model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", model_revision="v2.0.0")
+model = AutoModel(model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", model_revision="v2.0.2")
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
hotword='达摩院 魔搭')
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/infer.sh b/examples/industrial_data_pretraining/contextual_paraformer/infer.sh
index 158ce8a..b20742b 100644
--- a/examples/industrial_data_pretraining/contextual_paraformer/infer.sh
+++ b/examples/industrial_data_pretraining/contextual_paraformer/infer.sh
@@ -1,6 +1,6 @@
model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
-model_revision="v2.0.0"
+model_revision="v2.0.2"
python funasr/bin/inference.py \
+model=${model} \
diff --git a/examples/industrial_data_pretraining/ct_transformer/demo.py b/examples/industrial_data_pretraining/ct_transformer/demo.py
index 23965e0..d648f3d 100644
--- a/examples/industrial_data_pretraining/ct_transformer/demo.py
+++ b/examples/industrial_data_pretraining/ct_transformer/demo.py
@@ -5,7 +5,7 @@
from funasr import AutoModel
-model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.1")
+model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.2")
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
print(res)
@@ -13,7 +13,7 @@
from funasr import AutoModel
-model = AutoModel(model="damo/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.1")
+model = AutoModel(model="damo/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.2")
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer/infer.sh b/examples/industrial_data_pretraining/ct_transformer/infer.sh
index 4b4e949..33bb5c1 100644
--- a/examples/industrial_data_pretraining/ct_transformer/infer.sh
+++ b/examples/industrial_data_pretraining/ct_transformer/infer.sh
@@ -1,9 +1,9 @@
model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-model_revision="v2.0.1"
+model_revision="v2.0.2"
model="damo/punc_ct-transformer_cn-en-common-vocab471067-large"
-model_revision="v2.0.1"
+model_revision="v2.0.2"
python funasr/bin/inference.py \
+model=${model} \
diff --git a/examples/industrial_data_pretraining/monotonic_aligner/demo.py b/examples/industrial_data_pretraining/monotonic_aligner/demo.py
index f5df457..def6b7d 100644
--- a/examples/industrial_data_pretraining/monotonic_aligner/demo.py
+++ b/examples/industrial_data_pretraining/monotonic_aligner/demo.py
@@ -5,7 +5,7 @@
from funasr import AutoModel
-model = AutoModel(model="damo/speech_timestamp_prediction-v1-16k-offline", model_revision="v2.0.0")
+model = AutoModel(model="damo/speech_timestamp_prediction-v1-16k-offline", model_revision="v2.0.2")
res = model(input=("https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
"娆㈣繋澶у鏉ュ埌榄旀惌绀惧尯杩涜浣撻獙"),
diff --git a/examples/industrial_data_pretraining/monotonic_aligner/infer.sh b/examples/industrial_data_pretraining/monotonic_aligner/infer.sh
index 34fd1f9..dcc8722 100644
--- a/examples/industrial_data_pretraining/monotonic_aligner/infer.sh
+++ b/examples/industrial_data_pretraining/monotonic_aligner/infer.sh
@@ -1,6 +1,6 @@
model="damo/speech_timestamp_prediction-v1-16k-offline"
-model_revision="v2.0.0"
+model_revision="v2.0.2"
python funasr/bin/inference.py \
+model=${model} \
diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
index fc3a635..aa895eb 100644
--- a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
@@ -6,13 +6,13 @@
from funasr import AutoModel
model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
- model_revision="v2.0.0",
+ model_revision="v2.0.2",
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.2",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
- punc_model_revision="v2.0.1",
+ punc_model_revision="v2.0.2",
spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
- spk_model_revision="v2.0.0"
+ spk_model_revision="v2.0.2"
)
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh b/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh
index f3fa90d..98a325d 100644
--- a/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh
@@ -1,12 +1,12 @@
model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-model_revision="v2.0.0"
+model_revision="v2.0.2"
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
vad_model_revision="v2.0.2"
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.1"
+punc_model_revision="v2.0.2"
spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
-spk_model_revision="v2.0.0"
+spk_model_revision="v2.0.2"
python funasr/bin/inference.py \
+model=${model} \
diff --git a/examples/industrial_data_pretraining/paraformer/demo.py b/examples/industrial_data_pretraining/paraformer/demo.py
index 12b963f..20f0f64 100644
--- a/examples/industrial_data_pretraining/paraformer/demo.py
+++ b/examples/industrial_data_pretraining/paraformer/demo.py
@@ -5,7 +5,7 @@
from funasr import AutoModel
-model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.0")
+model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.2")
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)
@@ -13,7 +13,7 @@
from funasr import AutoFrontend
-frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.0")
+frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.2")
fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
diff --git a/examples/industrial_data_pretraining/paraformer/infer.sh b/examples/industrial_data_pretraining/paraformer/infer.sh
index 9436628..6d3732f 100644
--- a/examples/industrial_data_pretraining/paraformer/infer.sh
+++ b/examples/industrial_data_pretraining/paraformer/infer.sh
@@ -1,6 +1,6 @@
model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-model_revision="v2.0.0"
+model_revision="v2.0.2"
python funasr/bin/inference.py \
+model=${model} \
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
index 65f182c..8f7eef3 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -9,7 +9,7 @@
encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
-model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.0")
+model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.2")
cache = {}
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
chunk_size=chunk_size,
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/infer.sh b/examples/industrial_data_pretraining/paraformer_streaming/infer.sh
index 77e839b..225f2a9 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/infer.sh
+++ b/examples/industrial_data_pretraining/paraformer_streaming/infer.sh
@@ -1,6 +1,6 @@
model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
-model_revision="v2.0.0"
+model_revision="v2.0.2"
python funasr/bin/inference.py \
+model=${model} \
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index 7f1fdb5..3a13126 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -6,11 +6,11 @@
from funasr import AutoModel
model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
- model_revision="v2.0.0",
+ model_revision="v2.0.2",
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.2",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
- punc_model_revision="v2.0.1",
+ punc_model_revision="v2.0.2",
)
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/infer.sh b/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
index ac5c190..61029e1 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
+++ b/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
@@ -1,10 +1,10 @@
model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-model_revision="v2.0.0"
+model_revision="v2.0.2"
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
vad_model_revision="v2.0.2"
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.1"
+punc_model_revision="v2.0.2"
python funasr/bin/inference.py \
+model=${model} \
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index 48957dd..cefee55 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -245,7 +245,7 @@
time1 = time.perf_counter()
with torch.no_grad():
- results, meta_data = model.generate(**batch, **kwargs)
+ results, meta_data = model.inference(**batch, **kwargs)
time2 = time.perf_counter()
asr_result_list.extend(results)
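
This one-line change renames the model entry point invoked by the CLI; correspondingly, every model class touched below renames its generate() method to inference(). A minimal sketch of the calling convention, mirroring the loop in funasr/bin/inference.py (the contents of `batch` are illustrative, not the library's exact contract):

    import time
    import torch

    def timed_inference(model, batch, **kwargs):
        # No-grad inference plus wall-clock timing, as in the CLI loop above.
        t0 = time.perf_counter()
        with torch.no_grad():
            results, meta_data = model.inference(**batch, **kwargs)
        return results, meta_data, time.perf_counter() - t0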
diff --git a/funasr/models/bat/model.py b/funasr/models/bat/model.py
index d814e31..3fed9aa 100644
--- a/funasr/models/bat/model.py
+++ b/funasr/models/bat/model.py
@@ -1,23 +1,27 @@
-"""Boundary Aware Transducer (BAT) model."""
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
-import logging
-from contextlib import contextmanager
-from typing import Dict, List, Optional, Tuple, Union
import torch
+import logging
import torch.nn as nn
-from packaging.version import parse as V
+
+from typing import Dict, List, Optional, Tuple, Union
+
+
+from torch.cuda.amp import autocast
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
-from funasr.metrics.compute_acc import th_accuracy
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.train_utils.device_funcs import force_gatherable
-from torch.cuda.amp import autocast
+
diff --git a/funasr/models/bicif_paraformer/cif_predictor.py b/funasr/models/bicif_paraformer/cif_predictor.py
index 5a1488e..e7b3ba9 100644
--- a/funasr/models/bicif_paraformer/cif_predictor.py
+++ b/funasr/models/bicif_paraformer/cif_predictor.py
@@ -1,17 +1,15 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import torch
-from torch import nn
-from torch import Tensor
-import logging
-import numpy as np
-from funasr.train_utils.device_funcs import to_device
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.scama.utils import sequence_mask
-from typing import Optional, Tuple
from funasr.register import tables
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
-class mae_loss(nn.Module):
+class mae_loss(torch.nn.Module):
def __init__(self, normalize_length=False):
super(mae_loss, self).__init__()
@@ -95,7 +93,7 @@
return fires
@tables.register("predictor_classes", "CifPredictorV3")
-class CifPredictorV3(nn.Module):
+class CifPredictorV3(torch.nn.Module):
def __init__(self,
idim,
l_order,
@@ -116,9 +114,9 @@
):
super(CifPredictorV3, self).__init__()
- self.pad = nn.ConstantPad1d((l_order, r_order), 0)
- self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
- self.cif_output = nn.Linear(idim, 1)
+ self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
+ self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
+ self.cif_output = torch.nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
@@ -131,14 +129,14 @@
self.upsample_type = upsample_type
self.use_cif1_cnn = use_cif1_cnn
if self.upsample_type == 'cnn':
- self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
- self.cif_output2 = nn.Linear(idim, 1)
+ self.upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
+ self.cif_output2 = torch.nn.Linear(idim, 1)
elif self.upsample_type == 'cnn_blstm':
- self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
- self.blstm = nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
- self.cif_output2 = nn.Linear(idim*2, 1)
+ self.upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
+ self.blstm = torch.nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
+ self.cif_output2 = torch.nn.Linear(idim*2, 1)
elif self.upsample_type == 'cnn_attn':
- self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
+ self.upsample_cnn = torch.nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
from funasr.models.transformer.encoder import EncoderLayer as TransformerEncoderLayer
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
@@ -157,7 +155,7 @@
True, #normalize_before,
False, #concat_after,
)
- self.cif_output2 = nn.Linear(idim, 1)
+ self.cif_output2 = torch.nn.Linear(idim, 1)
self.smooth_factor2 = smooth_factor2
self.noise_threshold2 = noise_threshold2
diff --git a/funasr/models/bicif_paraformer/model.py b/funasr/models/bicif_paraformer/model.py
index 318f1df..01f19c6 100644
--- a/funasr/models/bicif_paraformer/model.py
+++ b/funasr/models/bicif_paraformer/model.py
@@ -3,34 +3,36 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-import logging
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
import copy
-import torch
-import torch.nn as nn
-import random
-import numpy as np
import time
+import torch
+import logging
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict, List, Optional, Tuple
-from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.metrics.compute_acc import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-
-from funasr.models.paraformer.search import Hypothesis
-
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
-
-
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.model import Paraformer
+from funasr.models.paraformer.search import Hypothesis
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
@tables.register("model_classes", "BiCifParaformer")
class BiCifParaformer(Paraformer):
@@ -215,7 +217,7 @@
return loss, stats, weight
- def generate(self,
+ def inference(self,
data_in,
data_lengths=None,
key: list = None,
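
The LooseVersion-gated autocast import added to this file recurs in campplus, ct_transformer, ct_transformer_streaming, and emotion2vec below (contextual_paraformer already had it). Extracted as a standalone sketch, it is a compatibility shim that degrades to a no-op on old PyTorch:

    from contextlib import contextmanager
    from distutils.version import LooseVersion
    import torch

    if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
        from torch.cuda.amp import autocast
    else:
        # torch < 1.6.0 has no AMP; substitute a no-op context manager so
        # that `with autocast(...):` blocks still execute.
        @contextmanager
        def autocast(enabled=True):
            yield

Note that distutils is deprecated as of Python 3.10; packaging.version.parse is the usual modern replacement (which, coincidentally, funasr/models/bat/model.py drops in this same patch).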
diff --git a/funasr/models/campplus/cluster_backend.py b/funasr/models/campplus/cluster_backend.py
index 47b45d2..3bac0a0 100644
--- a/funasr/models/campplus/cluster_backend.py
+++ b/funasr/models/campplus/cluster_backend.py
@@ -1,14 +1,17 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
-from typing import Any, Dict, Union
-
+import umap
+import scipy
+import torch
+import sklearn
import hdbscan
import numpy as np
-import scipy
-import sklearn
-import umap
+
from sklearn.cluster._kmeans import k_means
-from torch import nn
class SpectralCluster:
@@ -129,7 +132,7 @@
return labels
-class ClusterBackend(nn.Module):
+class ClusterBackend(torch.nn.Module):
r"""Perfom clustering for input embeddings and output the labels.
Args:
model_dir: A model dir.
diff --git a/funasr/models/campplus/components.py b/funasr/models/campplus/components.py
index 43d366e..8db9aef 100644
--- a/funasr/models/campplus/components.py
+++ b/funasr/models/campplus/components.py
@@ -1,41 +1,43 @@
-# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
-# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
-from torch import nn
-class BasicResBlock(nn.Module):
+class BasicResBlock(torch.nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicResBlock, self).__init__()
- self.conv1 = nn.Conv2d(in_planes,
+ self.conv1 = torch.nn.Conv2d(in_planes,
planes,
kernel_size=3,
stride=(stride, 1),
padding=1,
bias=False)
- self.bn1 = nn.BatchNorm2d(planes)
- self.conv2 = nn.Conv2d(planes,
+ self.bn1 = torch.nn.BatchNorm2d(planes)
+ self.conv2 = torch.nn.Conv2d(planes,
planes,
kernel_size=3,
stride=1,
padding=1,
bias=False)
- self.bn2 = nn.BatchNorm2d(planes)
+ self.bn2 = torch.nn.BatchNorm2d(planes)
- self.shortcut = nn.Sequential()
+ self.shortcut = torch.nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
- self.shortcut = nn.Sequential(
- nn.Conv2d(in_planes,
+ self.shortcut = torch.nn.Sequential(
+ torch.nn.Conv2d(in_planes,
self.expansion * planes,
kernel_size=1,
stride=(stride, 1),
bias=False),
- nn.BatchNorm2d(self.expansion * planes))
+ torch.nn.BatchNorm2d(self.expansion * planes))
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
@@ -45,7 +47,7 @@
return out
-class FCM(nn.Module):
+class FCM(torch.nn.Module):
def __init__(self,
block=BasicResBlock,
num_blocks=[2, 2],
@@ -53,14 +55,14 @@
feat_dim=80):
super(FCM, self).__init__()
self.in_planes = m_channels
- self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
- self.bn1 = nn.BatchNorm2d(m_channels)
+ self.conv1 = torch.nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = torch.nn.BatchNorm2d(m_channels)
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
- self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
- self.bn2 = nn.BatchNorm2d(m_channels)
+ self.conv2 = torch.nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
+ self.bn2 = torch.nn.BatchNorm2d(m_channels)
self.out_channels = m_channels * (feat_dim // 8)
def _make_layer(self, block, planes, num_blocks, stride):
@@ -69,7 +71,7 @@
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
- return nn.Sequential(*layers)
+ return torch.nn.Sequential(*layers)
def forward(self, x):
x = x.unsqueeze(1)
@@ -84,17 +86,17 @@
def get_nonlinear(config_str, channels):
- nonlinear = nn.Sequential()
+ nonlinear = torch.nn.Sequential()
for name in config_str.split('-'):
if name == 'relu':
- nonlinear.add_module('relu', nn.ReLU(inplace=True))
+ nonlinear.add_module('relu', torch.nn.ReLU(inplace=True))
elif name == 'prelu':
- nonlinear.add_module('prelu', nn.PReLU(channels))
+ nonlinear.add_module('prelu', torch.nn.PReLU(channels))
elif name == 'batchnorm':
- nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels))
+ nonlinear.add_module('batchnorm', torch.nn.BatchNorm1d(channels))
elif name == 'batchnorm_':
nonlinear.add_module('batchnorm',
- nn.BatchNorm1d(channels, affine=False))
+ torch.nn.BatchNorm1d(channels, affine=False))
else:
raise ValueError('Unexpected module ({}).'.format(name))
return nonlinear
@@ -109,12 +111,12 @@
return stats
-class StatsPool(nn.Module):
+class StatsPool(torch.nn.Module):
def forward(self, x):
return statistics_pooling(x)
-class TDNNLayer(nn.Module):
+class TDNNLayer(torch.nn.Module):
def __init__(self,
in_channels,
out_channels,
@@ -129,7 +131,7 @@
assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
kernel_size)
padding = (kernel_size - 1) // 2 * dilation
- self.linear = nn.Conv1d(in_channels,
+ self.linear = torch.nn.Conv1d(in_channels,
out_channels,
kernel_size,
stride=stride,
@@ -144,7 +146,7 @@
return x
-class CAMLayer(nn.Module):
+class CAMLayer(torch.nn.Module):
def __init__(self,
bn_channels,
out_channels,
@@ -155,17 +157,17 @@
bias,
reduction=2):
super(CAMLayer, self).__init__()
- self.linear_local = nn.Conv1d(bn_channels,
+ self.linear_local = torch.nn.Conv1d(bn_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
- self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
- self.relu = nn.ReLU(inplace=True)
- self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
- self.sigmoid = nn.Sigmoid()
+ self.linear1 = torch.nn.Conv1d(bn_channels, bn_channels // reduction, 1)
+ self.relu = torch.nn.ReLU(inplace=True)
+ self.linear2 = torch.nn.Conv1d(bn_channels // reduction, out_channels, 1)
+ self.sigmoid = torch.nn.Sigmoid()
def forward(self, x):
y = self.linear_local(x)
@@ -187,7 +189,7 @@
return seg
-class CAMDenseTDNNLayer(nn.Module):
+class CAMDenseTDNNLayer(torch.nn.Module):
def __init__(self,
in_channels,
out_channels,
@@ -204,7 +206,7 @@
padding = (kernel_size - 1) // 2 * dilation
self.memory_efficient = memory_efficient
self.nonlinear1 = get_nonlinear(config_str, in_channels)
- self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
+ self.linear1 = torch.nn.Conv1d(in_channels, bn_channels, 1, bias=False)
self.nonlinear2 = get_nonlinear(config_str, bn_channels)
self.cam_layer = CAMLayer(bn_channels,
out_channels,
@@ -226,7 +228,7 @@
return x
-class CAMDenseTDNNBlock(nn.ModuleList):
+class CAMDenseTDNNBlock(torch.nn.ModuleList):
def __init__(self,
num_layers,
in_channels,
@@ -257,7 +259,7 @@
return x
-class TransitLayer(nn.Module):
+class TransitLayer(torch.nn.Module):
def __init__(self,
in_channels,
out_channels,
@@ -265,7 +267,7 @@
config_str='batchnorm-relu'):
super(TransitLayer, self).__init__()
self.nonlinear = get_nonlinear(config_str, in_channels)
- self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
+ self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias)
def forward(self, x):
x = self.nonlinear(x)
@@ -273,14 +275,14 @@
return x
-class DenseLayer(nn.Module):
+class DenseLayer(torch.nn.Module):
def __init__(self,
in_channels,
out_channels,
bias=False,
config_str='batchnorm-relu'):
super(DenseLayer, self).__init__()
- self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
+ self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias)
self.nonlinear = get_nonlinear(config_str, out_channels)
def forward(self, x):
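
The bulk of this file's churn is one mechanical pattern: dropping the `from torch import nn` alias in favor of fully qualified torch.nn names. Illustrated on a toy module (not taken from the patch):

    import torch

    class TinyBlock(torch.nn.Module):              # was: class TinyBlock(nn.Module)
        def __init__(self, dim: int):
            super().__init__()
            self.linear = torch.nn.Linear(dim, dim)  # was: nn.Linear(dim, dim)
            self.relu = torch.nn.ReLU(inplace=True)  # was: nn.ReLU(inplace=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.relu(self.linear(x))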
diff --git a/funasr/models/campplus/model.py b/funasr/models/campplus/model.py
index 25ef3d7..6706c84 100644
--- a/funasr/models/campplus/model.py
+++ b/funasr/models/campplus/model.py
@@ -1,25 +1,34 @@
-# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
-# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
-import os
import time
import torch
-import logging
import numpy as np
-import torch.nn as nn
from collections import OrderedDict
-from typing import Union, Dict, List, Tuple, Optional
+from contextlib import contextmanager
+from distutils.version import LooseVersion
-from funasr.utils.load_utils import load_audio_text_image_video
-from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
-from funasr.models.campplus.components import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
- BasicResBlock, get_nonlinear, FCM
from funasr.models.campplus.utils import extract_feature
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.models.campplus.components import DenseLayer, StatsPool, \
+ TDNNLayer, CAMDenseTDNNBlock, TransitLayer, get_nonlinear, FCM
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
@tables.register("model_classes", "CAMPPlus")
-class CAMPPlus(nn.Module):
+class CAMPPlus(torch.nn.Module):
def __init__(self,
feat_dim=80,
embedding_size=192,
@@ -36,7 +45,7 @@
channels = self.head.out_channels
self.output_level = output_level
- self.xvector = nn.Sequential(
+ self.xvector = torch.nn.Sequential(
OrderedDict([
('tdnn',
@@ -82,10 +91,10 @@
assert self.output_level == 'frame', '`output_level` should be set to \'segment\' or \'frame\'. '
for m in self.modules():
- if isinstance(m, (nn.Conv1d, nn.Linear)):
- nn.init.kaiming_normal_(m.weight.data)
+ if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
+ torch.nn.init.kaiming_normal_(m.weight.data)
if m.bias is not None:
- nn.init.zeros_(m.bias)
+ torch.nn.init.zeros_(m.bias)
def forward(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
@@ -95,7 +104,7 @@
x = x.transpose(1, 2)
return x
- def generate(self,
+ def inference(self,
data_in,
data_lengths=None,
key: list=None,
diff --git a/funasr/models/campplus/utils.py b/funasr/models/campplus/utils.py
index 9964356..c81cb7e 100644
--- a/funasr/models/campplus/utils.py
+++ b/funasr/models/campplus/utils.py
@@ -1,5 +1,8 @@
-# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
-# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
import io
import os
@@ -14,6 +17,7 @@
from typing import Generator, Union
from abc import ABCMeta, abstractmethod
import torchaudio.compliance.kaldi as Kaldi
+
from funasr.models.transformer.utils.nets_utils import pad_list
diff --git a/funasr/models/contextual_paraformer/decoder.py b/funasr/models/contextual_paraformer/decoder.py
index 5ec2756..c872547 100644
--- a/funasr/models/contextual_paraformer/decoder.py
+++ b/funasr/models/contextual_paraformer/decoder.py
@@ -1,22 +1,24 @@
-from typing import List
-from typing import Tuple
-import logging
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import torch
-import torch.nn as nn
+import logging
import numpy as np
-
-from funasr.models.scama import utils as myutils
-
-from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
-from funasr.models.transformer.embedding import PositionalEncoding
-from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.models.transformer.utils.repeat import repeat
-from funasr.models.paraformer.decoder import DecoderLayerSANM, ParaformerSANMDecoder
+from typing import Tuple
from funasr.register import tables
+from funasr.models.scama import utils as myutils
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.paraformer.decoder import DecoderLayerSANM, ParaformerSANMDecoder
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
-class ContextualDecoderLayer(nn.Module):
+
+class ContextualDecoderLayer(torch.nn.Module):
def __init__(
self,
size,
@@ -38,12 +40,12 @@
self.norm2 = LayerNorm(size)
if src_attn is not None:
self.norm3 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
+ self.dropout = torch.nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
- self.concat_linear1 = nn.Linear(size + size, size)
- self.concat_linear2 = nn.Linear(size + size, size)
+ self.concat_linear1 = torch.nn.Linear(size + size, size)
+ self.concat_linear2 = torch.nn.Linear(size + size, size)
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None,):
# tgt = self.dropout(tgt)
@@ -73,7 +75,7 @@
return x, tgt_mask, x_self_attn, x_src_attn
-class ContextualBiasDecoder(nn.Module):
+class ContextualBiasDecoder(torch.nn.Module):
def __init__(
self,
size,
@@ -87,7 +89,7 @@
self.src_attn = src_attn
if src_attn is not None:
self.norm3 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
+ self.dropout = torch.nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
@@ -183,7 +185,7 @@
concat_after,
),
)
- self.dropout = nn.Dropout(dropout_rate)
+ self.dropout = torch.nn.Dropout(dropout_rate)
self.bias_decoder = ContextualBiasDecoder(
size=attention_dim,
src_attn=MultiHeadedAttentionCrossAtt(
diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index 67d4fb0..abbac8c 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -1,42 +1,34 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import os
+import re
+import time
+import torch
+import codecs
import logging
+import tempfile
+import requests
+import numpy as np
+from typing import Dict, Tuple
from contextlib import contextmanager
from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
-import torch
-import torch.nn as nn
-import random
-import numpy as np
-import time
-# from funasr.layers.abs_normalize import AbsNormalize
+
+from funasr.register import tables
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
-# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.frontends.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.paraformer.cif_predictor import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.models.paraformer.model import Paraformer
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.search import Hypothesis
+from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.metrics.compute_acc import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.paraformer.cif_predictor import CifPredictorV3
-from funasr.models.paraformer.search import Hypothesis
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
@@ -46,14 +38,7 @@
@contextmanager
def autocast(enabled=True):
yield
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
-
-from funasr.models.paraformer.model import Paraformer
-
-from funasr.register import tables
@tables.register("model_classes", "ContextualParaformer")
class ContextualParaformer(Paraformer):
@@ -316,7 +301,7 @@
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
- def generate(self,
+ def inference(self,
data_in,
data_lengths=None,
key: list = None,
@@ -324,7 +309,6 @@
frontend=None,
**kwargs,
):
-
# init beamsearch
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index 5fb3ed4..285f5cc 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -1,22 +1,34 @@
-from typing import Any
-from typing import List
-from typing import Tuple
-from typing import Optional
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import torch
import numpy as np
import torch.nn.functional as F
-
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.train_utils.device_funcs import force_gatherable
-from funasr.train_utils.device_funcs import to_device
-import torch
-import torch.nn as nn
-from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
-from funasr.utils.load_utils import load_audio_text_image_video
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Any, List, Tuple, Optional
from funasr.register import tables
+from funasr.train_utils.device_funcs import to_device
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
@tables.register("model_classes", "CTTransformer")
-class CTTransformer(nn.Module):
+class CTTransformer(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
@@ -45,11 +57,11 @@
punc_weight = [1] * punc_size
- self.embed = nn.Embedding(vocab_size, embed_unit)
+ self.embed = torch.nn.Embedding(vocab_size, embed_unit)
encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(**encoder_conf)
- self.decoder = nn.Linear(att_unit, punc_size)
+ self.decoder = torch.nn.Linear(att_unit, punc_size)
self.encoder = encoder
self.punc_list = punc_list
self.punc_weight = punc_weight
@@ -211,7 +223,7 @@
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
- def generate(self,
+ def inference(self,
data_in,
data_lengths=None,
key: list = None,
diff --git a/funasr/models/ct_transformer/utils.py b/funasr/models/ct_transformer/utils.py
index c5f85e6..01b1850 100644
--- a/funasr/models/ct_transformer/utils.py
+++ b/funasr/models/ct_transformer/utils.py
@@ -1,4 +1,10 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import re
+
def split_to_mini_sentence(words: list, word_limit: int = 20):
assert word_limit > 1
@@ -12,8 +18,6 @@
if length % word_limit > 0:
sentences.append(words[sentence_len * word_limit:])
return sentences
-
-
def split_words(text: str, jieba_usr_dict=None, **kwargs):
if jieba_usr_dict:
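
For reference, split_to_mini_sentence (touched above) chunks a word list into segments of at most word_limit items, appending the remainder as a final short segment. A hypothetical usage, with the output shape inferred from the remainder-handling lines visible in the hunk:

    from funasr.models.ct_transformer.utils import split_to_mini_sentence

    words = [f"w{i}" for i in range(45)]
    sentences = split_to_mini_sentence(words, word_limit=20)
    assert [len(s) for s in sentences] == [20, 20, 5]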
diff --git a/funasr/models/ct_transformer_streaming/attention.py b/funasr/models/ct_transformer_streaming/attention.py
index 382334e..3177eca 100644
--- a/funasr/models/ct_transformer_streaming/attention.py
+++ b/funasr/models/ct_transformer_streaming/attention.py
@@ -1,22 +1,10 @@
#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
-# Copyright 2019 Shigeki Karita
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-"""Multi-Head Attention layer definition."""
-
-import math
-
-import numpy
import torch
-from torch import nn
-import torch.nn.functional as F
-from typing import Optional, Tuple
-
from funasr.models.sanm.attention import MultiHeadedAttentionSANM
-
-
class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
diff --git a/funasr/models/ct_transformer_streaming/encoder.py b/funasr/models/ct_transformer_streaming/encoder.py
index 32ee2f2..95e2a4b 100644
--- a/funasr/models/ct_transformer_streaming/encoder.py
+++ b/funasr/models/ct_transformer_streaming/encoder.py
@@ -1,39 +1,29 @@
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-import logging
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from funasr.models.scama.chunk_utilis import overlap_chunk
-import numpy as np
-from funasr.train_utils.device_funcs import to_device
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.sanm.attention import MultiHeadedAttention
-from funasr.models.ct_transformer_streaming.attention import MultiHeadedAttentionSANMwithMask
-from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
-from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
-from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
-from funasr.models.transformer.positionwise_feed_forward import (
- PositionwiseFeedForward, # noqa: H301
-)
-from funasr.models.transformer.utils.repeat import repeat
-from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
-from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
-from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
-from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
-from funasr.models.transformer.utils.subsampling import TooShortUttError
-from funasr.models.transformer.utils.subsampling import check_short_utt
-from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
-from funasr.models.ctc.ctc import CTC
+import torch
+from typing import List, Optional, Tuple
from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.sanm.attention import MultiHeadedAttention
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.utils.subsampling import check_short_utt
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.embedding import SinusoidalPositionEncoder
+from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
+from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
+from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
+from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from funasr.models.ct_transformer_streaming.attention import MultiHeadedAttentionSANMwithMask
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8
-class EncoderLayerSANM(nn.Module):
+
+class EncoderLayerSANM(torch.nn.Module):
def __init__(
self,
in_size,
@@ -51,13 +41,13 @@
self.feed_forward = feed_forward
self.norm1 = LayerNorm(in_size)
self.norm2 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
+ self.dropout = torch.nn.Dropout(dropout_rate)
self.in_size = in_size
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
- self.concat_linear = nn.Linear(size + size, size)
+ self.concat_linear = torch.nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
self.dropout_rate = dropout_rate
@@ -156,7 +146,7 @@
@tables.register("encoder_classes", "SANMVadEncoder")
-class SANMVadEncoder(nn.Module):
+class SANMVadEncoder(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -306,7 +296,7 @@
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
- self.dropout = nn.Dropout(dropout_rate)
+ self.dropout = torch.nn.Dropout(dropout_rate)
def output_size(self) -> int:
return self._output_size
diff --git a/funasr/models/ct_transformer_streaming/model.py b/funasr/models/ct_transformer_streaming/model.py
index 5254d15..217767a 100644
--- a/funasr/models/ct_transformer_streaming/model.py
+++ b/funasr/models/ct_transformer_streaming/model.py
@@ -1,20 +1,28 @@
-from typing import Any
-from typing import List
-from typing import Tuple
-from typing import Optional
-import numpy as np
-import torch.nn.functional as F
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.train_utils.device_funcs import force_gatherable
-from funasr.train_utils.device_funcs import to_device
import torch
-import torch.nn as nn
-from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
-from funasr.utils.load_utils import load_audio_text_image_video
-from funasr.models.ct_transformer.model import CTTransformer
+import numpy as np
+from contextlib import contextmanager
+from distutils.version import LooseVersion
from funasr.register import tables
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.ct_transformer.model import CTTransformer
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
@tables.register("model_classes", "CTTransformerStreaming")
class CTTransformerStreaming(CTTransformer):
@@ -47,10 +55,8 @@
def with_vad(self):
return True
-
-
- def generate(self,
+ def inference(self,
data_in,
data_lengths=None,
key: list = None,
diff --git a/funasr/models/emotion2vec/audio.py b/funasr/models/emotion2vec/audio.py
index 316d372..d21500b 100644
--- a/funasr/models/emotion2vec/audio.py
+++ b/funasr/models/emotion2vec/audio.py
@@ -3,25 +3,21 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from typing import List, Tuple
-from functools import partial
import torch
-import torch.nn as nn
-import torch.nn.functional as F
import numpy as np
+import torch.nn as nn
+from functools import partial
+import torch.nn.functional as F
+from typing import Callable, Dict
-from typing import Callable, Dict, Optional
from funasr.models.emotion2vec.fairseq_modules import (
LayerNorm,
SamePad,
TransposeLast,
ConvFeatureExtractionModel,
)
-
-from funasr.models.emotion2vec.base import ModalitySpecificEncoder, get_alibi_bias
from funasr.models.emotion2vec.modules import Modality, BlockEncoder, Decoder1d
-
-
+from funasr.models.emotion2vec.base import ModalitySpecificEncoder, get_alibi_bias
class AudioEncoder(ModalitySpecificEncoder):
diff --git a/funasr/models/emotion2vec/model.py b/funasr/models/emotion2vec/model.py
index 315c1cc..de8113c 100644
--- a/funasr/models/emotion2vec/model.py
+++ b/funasr/models/emotion2vec/model.py
@@ -4,29 +4,35 @@
# MIT License (https://opensource.org/licenses/MIT)
# Modified from https://github.com/ddlBoJack/emotion2vec/tree/main
-import logging
import os
-from functools import partial
-import numpy as np
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-from funasr.models.emotion2vec.modules import AltBlock
-from funasr.models.emotion2vec.audio import AudioEncoder
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-
-from omegaconf import OmegaConf
import time
-
-logger = logging.getLogger(__name__)
+import torch
+import logging
+import numpy as np
+from functools import partial
+from omegaconf import OmegaConf
+import torch.nn.functional as F
+from contextlib import contextmanager
+from distutils.version import LooseVersion
from funasr.register import tables
+from funasr.models.emotion2vec.modules import AltBlock
+from funasr.models.emotion2vec.audio import AudioEncoder
+from funasr.utils.load_utils import load_audio_text_image_video
+
+
+logger = logging.getLogger(__name__)
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
@tables.register("model_classes", "Emotion2vec")
-class Emotion2vec(nn.Module):
+class Emotion2vec(torch.nn.Module):
"""
Author: Ziyang Ma, Zhisheng Zheng, Jiaxin Ye, Jinchao Li, Zhifu Gao, Shiliang Zhang, Xie Chen
emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation
@@ -39,7 +45,7 @@
self.cfg = cfg
make_layer_norm = partial(
- nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
+ torch.nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
)
def make_block(drop_path, dim=None, heads=None):
@@ -59,7 +65,7 @@
)
self.alibi_biases = {}
- self.modality_encoders = nn.ModuleDict()
+ self.modality_encoders = torch.nn.ModuleDict()
enc = AudioEncoder(
cfg.modalities.audio,
@@ -77,11 +83,11 @@
self.loss_beta = cfg.get("loss_beta")
self.loss_scale = cfg.get("loss_scale")
- self.dropout_input = nn.Dropout(cfg.get("dropout_input"))
+ self.dropout_input = torch.nn.Dropout(cfg.get("dropout_input"))
dpr = np.linspace(cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth"))
- self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])
+ self.blocks = torch.nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])
self.norm = None
if cfg.get("layer_norm_first"):
@@ -183,7 +189,7 @@
)
return res
- def generate(self,
+ def inference(self,
data_in,
data_lengths=None,
key: list = None,
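
Note on the hunk above: generate is renamed to inference uniformly across the model classes in this patch (the same rename appears in fsmn_vad_streaming, monotonic_aligner, paraformer, paraformer_streaming and seaco_paraformer below), so external callers must follow suit. A minimal caller sketch, assuming an already-constructed Emotion2vec instance named model; only the keyword names data_in / data_lengths / key come from the signature shown, the values and the return shape are illustrative:

# hedged sketch: argument values and return shape are assumptions
results = model.inference(
    data_in="utt1.wav",    # audio input routed through load_audio_text_image_video
    data_lengths=None,
    key=["utt1"],          # per-utterance identifiers
)
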
diff --git a/funasr/models/emotion2vec/modules.py b/funasr/models/emotion2vec/modules.py
index 33947f2..fcf99eb 100644
--- a/funasr/models/emotion2vec/modules.py
+++ b/funasr/models/emotion2vec/modules.py
@@ -4,9 +4,10 @@
# LICENSE file in the root directory of this source tree.
import torch
-import torch.nn as nn
-import torch.nn.functional as F
import numpy as np
+import torch.nn as nn
+from enum import Enum, auto
+import torch.nn.functional as F
from dataclasses import dataclass
from funasr.models.emotion2vec.fairseq_modules import (
LayerNorm,
@@ -14,12 +15,11 @@
TransposeLast,
)
-from enum import Enum, auto
+
class Modality(Enum):
AUDIO = auto()
-
@dataclass
class D2vDecoderConfig:
decoder_dim: int = 384
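
D2vDecoderConfig is a typed dataclass config, so decoder settings carry defaults and can be overridden per experiment. A trimmed, runnable sketch (decoder_dim is the only field visible in this hunk; the full class defines more):

from dataclasses import dataclass

@dataclass
class D2vDecoderConfig:
    decoder_dim: int = 384

cfg = D2vDecoderConfig()                  # defaults: decoder_dim == 384
wide = D2vDecoderConfig(decoder_dim=512)  # per-experiment override
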
diff --git a/funasr/models/emotion2vec/timm_modules.py b/funasr/models/emotion2vec/timm_modules.py
index 1f6285a..b26da52 100644
--- a/funasr/models/emotion2vec/timm_modules.py
+++ b/funasr/models/emotion2vec/timm_modules.py
@@ -1,12 +1,8 @@
-from itertools import repeat
-import collections.abc
-from functools import partial
-from typing import Optional, Tuple
-import numpy as np
-
-import torch
import torch.nn as nn
-import torch.nn.functional as F
+import collections.abc
+from itertools import repeat
+from functools import partial
+
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index c87558c..193feb0 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -1,30 +1,32 @@
-from enum import Enum
-from typing import List, Tuple, Dict, Any
-import logging
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import os
import json
+import time
+import math
import torch
from torch import nn
-import math
-from typing import Optional
-import time
-from funasr.register import tables
-from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
-from funasr.utils.datadir_writer import DatadirWriter
-
+from enum import Enum
from dataclasses import dataclass
+from funasr.register import tables
+from typing import List, Tuple, Dict, Any, Optional
+
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+
class VadStateMachine(Enum):
kVadInStateStartPointNotDetected = 1
kVadInStateInSpeechSegment = 2
kVadInStateEndPointDetected = 3
-
class FrameState(Enum):
kFrameStateInvalid = -1
kFrameStateSpeech = 1
kFrameStateSil = 0
-
# final voice/unvoice state per frame
class AudioChangeState(Enum):
@@ -34,7 +36,6 @@
kChangeStateSil2Speech = 3
kChangeStateNoBegin = 4
kChangeStateInvalid = 5
-
class VadDetectMode(Enum):
kVadSingleUtteranceDetectMode = 0
@@ -514,7 +515,7 @@
cache["stats"] = stats
return cache
- def generate(self,
+ def inference(self,
data_in,
data_lengths=None,
key: list = None,
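
The enums above encode the streaming VAD control flow: per-frame classifications (FrameState) drive transitions of the utterance-level VadStateMachine, with AudioChangeState tracking sil/speech boundaries. A hypothetical single-frame transition sketch, reusing the enums defined in the hunk; the real model debounces over windows rather than switching on one frame:

def step(state, frame_state):
    # speech onset: leave "start point not detected"
    if (state is VadStateMachine.kVadInStateStartPointNotDetected
            and frame_state is FrameState.kFrameStateSpeech):
        return VadStateMachine.kVadInStateInSpeechSegment
    # trailing silence inside a segment: declare the end point
    if (state is VadStateMachine.kVadInStateInSpeechSegment
            and frame_state is FrameState.kFrameStateSil):
        return VadStateMachine.kVadInStateEndPointDetected
    return state
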
diff --git a/funasr/models/monotonic_aligner/model.py b/funasr/models/monotonic_aligner/model.py
index 6309732..77d95a0 100644
--- a/funasr/models/monotonic_aligner/model.py
+++ b/funasr/models/monotonic_aligner/model.py
@@ -1,22 +1,27 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import time
import copy
import torch
from torch.cuda.amp import autocast
from typing import Union, Dict, List, Tuple, Optional
+from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
-from funasr.register import tables
-from funasr.models.ctc.ctc import CTC
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-@tables.register("model_classes", "monotonicaligner")
+@tables.register("model_classes", "MonotonicAligner")
class MonotonicAligner(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -143,7 +148,7 @@
return encoder_out, encoder_out_lens
- def generate(self,
+ def inference(self,
data_in,
data_lengths=None,
key: list=None,
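
The registration key changes from "monotonicaligner" to "MonotonicAligner", so any config that selects this model by name must use the new CamelCase key: registry lookups are exact string matches. A generic, runnable sketch of the decorator-registry pattern that tables.register follows (illustrative only, not FunASR's actual tables implementation):

class Registry:
    def __init__(self):
        self._groups = {}

    def register(self, group, name):
        def deco(cls):
            self._groups.setdefault(group, {})[name] = cls
            return cls
        return deco

    def get(self, group, name):
        return self._groups[group][name]  # a stale key like "monotonicaligner" -> KeyError

tables = Registry()

@tables.register("model_classes", "MonotonicAligner")
class MonotonicAligner:
    pass
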
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index b06fa43..a5086c3 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -1,23 +1,25 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import torch
-from torch import nn
-from torch import Tensor
import logging
import numpy as np
-from funasr.train_utils.device_funcs import to_device
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.scama.utils import sequence_mask
-from typing import Optional, Tuple
from funasr.register import tables
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+
@tables.register("predictor_classes", "CifPredictor")
-class CifPredictor(nn.Module):
+class CifPredictor(torch.nn.Module):
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
super().__init__()
- self.pad = nn.ConstantPad1d((l_order, r_order), 0)
- self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
- self.cif_output = nn.Linear(idim, 1)
+ self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
+ self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
+ self.cif_output = torch.nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
@@ -137,7 +139,7 @@
return predictor_alignments.detach(), predictor_alignments_length.detach()
@tables.register("predictor_classes", "CifPredictorV2")
-class CifPredictorV2(nn.Module):
+class CifPredictorV2(torch.nn.Module):
def __init__(self,
idim,
l_order,
@@ -153,9 +155,9 @@
):
super(CifPredictorV2, self).__init__()
- self.pad = nn.ConstantPad1d((l_order, r_order), 0)
- self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
- self.cif_output = nn.Linear(idim, 1)
+ self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
+ self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
+ self.cif_output = torch.nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
@@ -426,7 +428,7 @@
return var_dict_torch_update
-class mae_loss(nn.Module):
+class mae_loss(torch.nn.Module):
def __init__(self, normalize_length=False):
super(mae_loss, self).__init__()
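
CifPredictor's threshold (default 1.0) drives continuous integrate-and-fire: per-frame weights alpha are accumulated, and a token boundary fires whenever the running sum crosses the threshold, with the overshoot carried into the next token. A simplified scalar sketch of that firing rule (illustrative; the real predictor is batched and differentiable):

def fire_positions(alphas, threshold=1.0):
    """Indices of frames where the integrated weight crosses the threshold."""
    fired, acc = [], 0.0
    for i, alpha in enumerate(alphas):
        acc += alpha
        if acc >= threshold:
            fired.append(i)
            acc -= threshold  # carry the overshoot into the next token
    return fired

print(fire_positions([0.3, 0.5, 0.4, 0.9, 0.2]))  # -> [2, 3]
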
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index 1df27e8..68018a0 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -1,25 +1,26 @@
-from typing import List
-from typing import Tuple
-import logging
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import torch
-import torch.nn as nn
-import numpy as np
+from typing import List, Tuple
+from funasr.register import tables
from funasr.models.scama import utils as myutils
-from funasr.models.transformer.decoder import BaseTransformerDecoder
-
-from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
-from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.decoder import DecoderLayer
-from funasr.models.transformer.attention import MultiHeadedAttention
+from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.decoder import BaseTransformerDecoder
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
-from funasr.register import tables
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
-class DecoderLayerSANM(nn.Module):
+
+class DecoderLayerSANM(torch.nn.Module):
"""Single decoder layer module.
Args:
@@ -62,12 +63,12 @@
self.norm2 = LayerNorm(size)
if src_attn is not None:
self.norm3 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
+ self.dropout = torch.nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
- self.concat_linear1 = nn.Linear(size + size, size)
- self.concat_linear2 = nn.Linear(size + size, size)
+ self.concat_linear1 = torch.nn.Linear(size + size, size)
+ self.concat_linear2 = torch.nn.Linear(size + size, size)
self.reserve_attn=False
self.attn_mat = []
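
The concat_linear1/concat_linear2 pair above serves the concat_after option: instead of the usual residual add of the (dropped-out) attention output, the layer concatenates input and attention output and projects the doubled width back to size before the residual. A schematic, runnable sketch under that assumption, with nn.Identity standing in for the attention module:

import torch
import torch.nn as nn

class ConcatAfterBlock(nn.Module):
    def __init__(self, size, attn):
        super().__init__()
        self.attn = attn
        self.concat_linear = nn.Linear(size + size, size)

    def forward(self, x):
        attn_out = self.attn(x)                      # stand-in for self/src attention
        x_concat = torch.cat((x, attn_out), dim=-1)  # (..., 2 * size)
        return x + self.concat_linear(x_concat)      # residual after projection

block = ConcatAfterBlock(8, nn.Identity())
y = block(torch.randn(2, 5, 8))
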
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 2cd9c88..f92441d 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -1,35 +1,30 @@
-import os
-import logging
-from typing import Union, Dict, List, Tuple, Optional
-
-import torch
-import torch.nn as nn
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import time
-
-from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
-)
-
-from funasr.models.paraformer.cif_predictor import mae_loss
-
-from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.metrics.compute_acc import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-
-from funasr.models.paraformer.search import Hypothesis
-
+import torch
+import logging
from torch.cuda.amp import autocast
+from typing import Union, Dict, List, Tuple, Optional
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.search import Hypothesis
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+
@tables.register("model_classes", "Paraformer")
-class Paraformer(nn.Module):
+class Paraformer(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
@@ -439,7 +434,7 @@
# scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
self.beam_search = beam_search
- def generate(self,
+ def inference(self,
data_in,
data_lengths=None,
key: list=None,
diff --git a/funasr/models/paraformer/search.py b/funasr/models/paraformer/search.py
index 250baad..31f4de2 100644
--- a/funasr/models/paraformer/search.py
+++ b/funasr/models/paraformer/search.py
@@ -1,17 +1,16 @@
-from itertools import chain
-import logging
-from typing import Any
-from typing import Dict
-from typing import List
-from typing import NamedTuple
-from typing import Tuple
-from typing import Union
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import torch
+import logging
+from itertools import chain
+from typing import Any, Dict, List, NamedTuple, Tuple, Union
from funasr.metrics.common import end_detect
-from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
-from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
+from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface, ScorerInterface
+
class Hypothesis(NamedTuple):
"""Hypothesis data type."""
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index e6f3038..bf45269 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -1,35 +1,29 @@
-import os
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import time
+import torch
import logging
+from typing import Dict, Tuple
from contextlib import contextmanager
from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
-import torch
-import torch.nn as nn
-import random
-import numpy as np
-import time
-# from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
-)
+from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.model import Paraformer
+from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
-
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.metrics.compute_acc import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.models.paraformer.search import Hypothesis
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -38,15 +32,7 @@
@contextmanager
def autocast(enabled=True):
yield
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.models.ctc.ctc import CTC
-from funasr.models.paraformer.model import Paraformer
-
-from funasr.register import tables
@tables.register("model_classes", "ParaformerStreaming")
class ParaformerStreaming(Paraformer):
@@ -499,7 +485,7 @@
return results
- def generate(self,
+ def inference(self,
data_in,
data_lengths=None,
key: list = None,
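
This file keeps the same torch-version guard as emotion2vec/model.py above: on torch>=1.6.0 the real torch.cuda.amp.autocast is imported, otherwise a no-op context manager stands in so every call site reads identically. A runnable sketch of just the fallback and its call-site pattern:

from contextlib import contextmanager

@contextmanager
def autocast(enabled=True):  # the torch<1.6.0 stand-in; real autocast replaces it on 1.6+
    yield

with autocast(False):
    pass  # forward pass would run here, unchanged across torch versions
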
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index db5c7dd..a1ce310 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -1,3 +1,8 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import os
import re
import time
@@ -8,24 +13,24 @@
import tempfile
import requests
import numpy as np
-from typing import Dict
-from typing import List
-from typing import Tuple
-from typing import Union
-from typing import Optional
+from typing import Dict, Tuple
from contextlib import contextmanager
from distutils.version import LooseVersion
-from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
-)
+from funasr.register import tables
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.models.paraformer.model import Paraformer
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.models.bicif_paraformer.model import BiCifParaformer
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.metrics.compute_acc import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-from funasr.models.paraformer.search import Hypothesis
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
@@ -35,13 +40,6 @@
@contextmanager
def autocast(enabled=True):
yield
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
-
-from funasr.models.paraformer.model import Paraformer
-from funasr.models.bicif_paraformer.model import BiCifParaformer
-from funasr.register import tables
@tables.register("model_classes", "SeacoParaformer")
@@ -306,7 +304,7 @@
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
'''
- def generate(self,
+ def inference(self,
data_in,
data_lengths=None,
key: list = None,
--
Gitblit v1.9.1