From 0ec7f0aea6a2ef86a607539eedb97334bdea56b7 Mon Sep 17 00:00:00 2001
From: yhliang <68215459+yhliang-aslp@users.noreply.github.com>
Date: 星期四, 18 五月 2023 11:41:44 +0800
Subject: [PATCH] Merge pull request #525 from alibaba-damo-academy/dev_lyh

---
 egs/alimeeting/sa-asr/asr_local.sh                        |    8 ++++----
 funasr/models/frontend/default.py                         |    4 ++--
 egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py |    4 ++--
 egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh       |   11 ++++++-----
 egs/alimeeting/sa-asr/path.sh                             |    2 +-
 5 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/egs/alimeeting/sa-asr/asr_local.sh b/egs/alimeeting/sa-asr/asr_local.sh
index 30401b9..05599b7 100755
--- a/egs/alimeeting/sa-asr/asr_local.sh
+++ b/egs/alimeeting/sa-asr/asr_local.sh
@@ -1153,10 +1153,10 @@
         mkdir -p ${sa_asr_exp}/log
         INIT_FILE=${sa_asr_exp}/ddp_init
         
-        if [ ! -f "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" ]; then
+        if [ ! -f "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pb" ]; then
             # download xvector extractor model file
             python local/download_xvector_model.py exp
-            log "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth"
+            log "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pb"
         fi
         
         if [ -f $INIT_FILE ];then
@@ -1195,8 +1195,8 @@
                     --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.3:decoder.decoder4.2" \
                     --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.4:decoder.decoder4.3" \
                     --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.5:decoder.decoder4.4" \
-                    --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:encoder:spk_encoder"   \
-                    --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:decoder:spk_encoder:decoder.output_dense"   \
+                    --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pb:encoder:spk_encoder"   \
+                    --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pb:decoder:spk_encoder:decoder.output_dense"   \
                     --valid_data_path_and_name_and_type "${_asr_valid_dir}/${_scp},speech,${_type}" \
                     --valid_data_path_and_name_and_type "${_asr_valid_dir}/text,text,text" \
                     --valid_data_path_and_name_and_type "${_asr_valid_dir}/oracle_profile_nopadding.scp,profile,npy" \
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
index 7d39cdc..c13ee42 100755
--- a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
+++ b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
@@ -61,9 +61,9 @@
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 
     log "stage 1:process alimeeting near dir"
     
-    find -L $near_raw_dir/audio_dir -iname "*.wav" >  $near_dir/wavlist
+    find -L $near_raw_dir/audio_dir -iname "*.wav" | sort >  $near_dir/wavlist
     awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' > $near_dir/uttid   
-    find -L $near_raw_dir/textgrid_dir  -iname "*.TextGrid" > $near_dir/textgrid.flist
+    find -L $near_raw_dir/textgrid_dir  -iname "*.TextGrid" | sort > $near_dir/textgrid.flist
     n1_wav=$(wc -l < $near_dir/wavlist)
     n2_text=$(wc -l < $near_dir/textgrid.flist)
     log  near file found $n1_wav wav and $n2_text text.
@@ -90,9 +90,9 @@
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     log "stage 2:process alimeeting far dir"
     
-    find -L $far_raw_dir/audio_dir -iname "*.wav" >  $far_dir/wavlist
+    find -L $far_raw_dir/audio_dir -iname "*.wav" | sort >  $far_dir/wavlist
     awk -F '/' '{print $NF}' $far_dir/wavlist | awk -F '.' '{print $1}' > $far_dir/uttid   
-    find -L $far_raw_dir/textgrid_dir  -iname "*.TextGrid" > $far_dir/textgrid.flist
+    find -L $far_raw_dir/textgrid_dir  -iname "*.TextGrid" | sort > $far_dir/textgrid.flist
     n1_wav=$(wc -l < $far_dir/wavlist)
     n2_text=$(wc -l < $far_dir/textgrid.flist)
     log  far file found $n1_wav wav and $n2_text text.
@@ -120,7 +120,8 @@
 
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     log "stage 3: finali data process"
-
+    local/fix_data_dir.sh $near_dir
+    local/fix_data_dir.sh $far_dir
     local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
     local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
 
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py b/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
index b70a32a..186f1de 100644
--- a/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
+++ b/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
@@ -42,8 +42,8 @@
             global_spk_list_tmp = global_spk_list[: ]
             for spk in meeting_map_tmp[meeting]:
                 global_spk_list_tmp.remove(spk)
-                padding_spk = random.sample(global_spk_list_tmp, 4 - num)
-                meeting_map_tmp[meeting] = meeting_map_tmp[meeting] + padding_spk
+            padding_spk = random.sample(global_spk_list_tmp, 4 - num)
+            meeting_map_tmp[meeting] = meeting_map_tmp[meeting] + padding_spk
     
     meeting_map = {}
     os.system('mkdir -p ' + path + '/oracle_profile_padding')
diff --git a/egs/alimeeting/sa-asr/path.sh b/egs/alimeeting/sa-asr/path.sh
index 5721f3f..dfc2b78 100755
--- a/egs/alimeeting/sa-asr/path.sh
+++ b/egs/alimeeting/sa-asr/path.sh
@@ -2,4 +2,4 @@
 
 # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
 export PYTHONIOENCODING=UTF-8
-export PATH=$FUNASR_DIR/funasr/bin:$PATH
\ No newline at end of file
+export PATH=$FUNASR_DIR/funasr/bin:./utils:$PATH
\ No newline at end of file
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index 2e1b0c4..c4dd7c5 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -102,8 +102,8 @@
         if input_stft.dim() == 4:
             # h: (B, T, C, F) -> h: (B, T, F)
             if self.training:
-                if self.use_channel == None:
-                    input_stft = input_stft[:, :, 0, :]
+                if self.use_channel is not None:
+                    input_stft = input_stft[:, :, self.use_channel, :]
                 else:
                     # Select 1ch randomly
                     ch = np.random.randint(input_stft.size(2))

--
Gitblit v1.9.1