From 49f13908deaed06bb4b0a01631e85e2833f1f051 Mon Sep 17 00:00:00 2001
From: smohan-speech <smohan@mail.ustc.edu.cn>
Date: Sun, 07 May 2023 02:27:58 +0800
Subject: [PATCH] add speaker-attributed ASR task for alimeeting
---
 egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py       | 243 ----------
 egs/alimeeting/sa-asr/pyscripts/utils/print_args.py           |  45 ---
 egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh         | 142 ------
 egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh | 116 -----
 funasr/losses/nll_loss.py                                     |  47 ---
 funasr/models/decoder/decoder_layer_sa_asr.py                 | 169 -------
 funasr/models/decoder/transformer_decoder_sa_asr.py           | 291 ------------
 7 files changed, 0 insertions(+), 1053 deletions(-)
diff --git a/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py b/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py
deleted file mode 100755
index 1fd63d6..0000000
--- a/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py
+++ /dev/null
@@ -1,243 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-from io import BytesIO
-from pathlib import Path
-from typing import Tuple, Optional
-
-import kaldiio
-import humanfriendly
-import numpy as np
-import resampy
-import soundfile
-from tqdm import tqdm
-from typeguard import check_argument_types
-
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.fileio.read_text import read_2column_text
-from funasr.fileio.sound_scp import SoundScpWriter
-
-
-def humanfriendly_or_none(value: str):
- if value in ("none", "None", "NONE"):
- return None
- return humanfriendly.parse_size(value)
-
-
-def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
- """
-
- >>> str2int_tuple('3,4,5')
- (3, 4, 5)
-
- """
- assert check_argument_types()
- if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
- return None
- return tuple(map(int, integers.strip().split(",")))
-
-
-def main():
- logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
- logging.basicConfig(level=logging.INFO, format=logfmt)
- logging.info(get_commandline_args())
-
- parser = argparse.ArgumentParser(
- description='Create a list of wav files from "wav.scp"',
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
- parser.add_argument("scp")
- parser.add_argument("outdir")
- parser.add_argument(
- "--name",
- default="wav",
- help="Specify the prefix word of output file name " 'such as "wav.scp"',
- )
- parser.add_argument("--segments", default=None)
- parser.add_argument(
- "--fs",
- type=humanfriendly_or_none,
- default=None,
- help="If the sampling rate specified, " "Change the sampling rate.",
- )
- parser.add_argument("--audio-format", default="wav")
- group = parser.add_mutually_exclusive_group()
- group.add_argument("--ref-channels", default=None, type=str2int_tuple)
- group.add_argument("--utt2ref-channels", default=None, type=str)
- args = parser.parse_args()
-
- out_num_samples = Path(args.outdir) / "utt2num_samples"
-
- if args.ref_channels is not None:
-
- def utt2ref_channels(x) -> Tuple[int, ...]:
- return args.ref_channels
-
- elif args.utt2ref_channels is not None:
- utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)
-
- def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
- chs_str = d[x]
- return tuple(map(int, chs_str.split()))
-
- else:
- utt2ref_channels = None
-
- Path(args.outdir).mkdir(parents=True, exist_ok=True)
- out_wavscp = Path(args.outdir) / f"{args.name}.scp"
- if args.segments is not None:
- # Note: kaldiio supports only wav-pcm-int16le files.
- loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
- if args.audio_format.endswith("ark"):
- fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
- fscp = out_wavscp.open("w")
- else:
- writer = SoundScpWriter(
- args.outdir,
- out_wavscp,
- format=args.audio_format,
- )
-
- with out_num_samples.open("w") as fnum_samples:
- for uttid, (rate, wave) in tqdm(loader):
- # wave: (Time,) or (Time, Nmic)
- if wave.ndim == 2 and utt2ref_channels is not None:
- wave = wave[:, utt2ref_channels(uttid)]
-
- if args.fs is not None and args.fs != rate:
- # FIXME(kamo): To use sox?
- wave = resampy.resample(
- wave.astype(np.float64), rate, args.fs, axis=0
- )
- wave = wave.astype(np.int16)
- rate = args.fs
- if args.audio_format.endswith("ark"):
- if "flac" in args.audio_format:
- suf = "flac"
- elif "wav" in args.audio_format:
- suf = "wav"
- else:
- raise RuntimeError("wav.ark or flac")
-
- # NOTE(kamo): Using extended ark format style here.
- # This format is incompatible with Kaldi
- kaldiio.save_ark(
- fark,
- {uttid: (wave, rate)},
- scp=fscp,
- append=True,
- write_function=f"soundfile_{suf}",
- )
-
- else:
- writer[uttid] = rate, wave
- fnum_samples.write(f"{uttid} {len(wave)}\n")
- else:
- if args.audio_format.endswith("ark"):
- fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
- else:
- wavdir = Path(args.outdir) / f"data_{args.name}"
- wavdir.mkdir(parents=True, exist_ok=True)
-
- with Path(args.scp).open("r") as fscp, out_wavscp.open(
- "w"
- ) as fout, out_num_samples.open("w") as fnum_samples:
- for line in tqdm(fscp):
- uttid, wavpath = line.strip().split(None, 1)
-
- if wavpath.endswith("|"):
- # Streaming input e.g. cat a.wav |
- with kaldiio.open_like_kaldi(wavpath, "rb") as f:
- with BytesIO(f.read()) as g:
- wave, rate = soundfile.read(g, dtype=np.int16)
- if wave.ndim == 2 and utt2ref_channels is not None:
- wave = wave[:, utt2ref_channels(uttid)]
-
- if args.fs is not None and args.fs != rate:
- # FIXME(kamo): To use sox?
- wave = resampy.resample(
- wave.astype(np.float64), rate, args.fs, axis=0
- )
- wave = wave.astype(np.int16)
- rate = args.fs
-
- if args.audio_format.endswith("ark"):
- if "flac" in args.audio_format:
- suf = "flac"
- elif "wav" in args.audio_format:
- suf = "wav"
- else:
- raise RuntimeError("wav.ark or flac")
-
- # NOTE(kamo): Using extended ark format style here.
- # This format is incompatible with Kaldi
- kaldiio.save_ark(
- fark,
- {uttid: (wave, rate)},
- scp=fout,
- append=True,
- write_function=f"soundfile_{suf}",
- )
- else:
- owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
- soundfile.write(owavpath, wave, rate)
- fout.write(f"{uttid} {owavpath}\n")
- else:
- wave, rate = soundfile.read(wavpath, dtype=np.int16)
- if wave.ndim == 2 and utt2ref_channels is not None:
- wave = wave[:, utt2ref_channels(uttid)]
- save_asis = False
-
- elif args.audio_format.endswith("ark"):
- save_asis = False
-
- elif Path(wavpath).suffix == "." + args.audio_format and (
- args.fs is None or args.fs == rate
- ):
- save_asis = True
-
- else:
- save_asis = False
-
- if save_asis:
- # The original file is used as is only when it already has
- # the target format and sampling rate, no ark output is
- # requested, and the line doesn't end with "|" (no unix pipe).
- fout.write(f"{uttid} {wavpath}\n")
- else:
- if args.fs is not None and args.fs != rate:
- # FIXME(kamo): To use sox?
- wave = resampy.resample(
- wave.astype(np.float64), rate, args.fs, axis=0
- )
- wave = wave.astype(np.int16)
- rate = args.fs
-
- if args.audio_format.endswith("ark"):
- if "flac" in args.audio_format:
- suf = "flac"
- elif "wav" in args.audio_format:
- suf = "wav"
- else:
- raise RuntimeError("wav.ark or flac")
-
- # NOTE(kamo): Using extended ark format style here.
- # This format is not supported in Kaldi.
- kaldiio.save_ark(
- fark,
- {uttid: (wave, rate)},
- scp=fout,
- append=True,
- write_function=f"soundfile_{suf}",
- )
- else:
- owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
- soundfile.write(owavpath, wave, rate)
- fout.write(f"{uttid} {owavpath}\n")
- fnum_samples.write(f"{uttid} {len(wave)}\n")
-
-
-if __name__ == "__main__":
- main()
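
For reference, the core transformation in format_wav_scp.py above is the
resample-and-requantize step applied to each utterance. A minimal standalone
sketch of that step, assuming resampy and soundfile are installed (the file
names here are hypothetical):

    import numpy as np
    import resampy
    import soundfile

    # Read PCM audio as int16, as the script above does.
    wave, rate = soundfile.read("input.wav", dtype="int16")
    target_fs = 16000  # hypothetical target rate
    if rate != target_fs:
        # resampy needs float data: convert, resample along the time axis, cast back.
        wave = resampy.resample(wave.astype(np.float64), rate, target_fs, axis=0)
        wave = wave.astype(np.int16)
    soundfile.write("output.wav", wave, target_fs)
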
diff --git a/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py b/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py
deleted file mode 100755
index b0c61e5..0000000
--- a/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/usr/bin/env python
-import sys
-
-
-def get_commandline_args(no_executable=True):
- extra_chars = [
- " ",
- ";",
- "&",
- "|",
- "<",
- ">",
- "?",
- "*",
- "~",
- "`",
- '"',
- "'",
- "\\",
- "{",
- "}",
- "(",
- ")",
- ]
-
- # Escape the extra characters for shell
- argv = [
- arg.replace("'", "'\\''")
- if all(char not in arg for char in extra_chars)
- else "'" + arg.replace("'", "'\\''") + "'"
- for arg in sys.argv
- ]
-
- if no_executable:
- return " ".join(argv[1:])
- else:
- return sys.executable + " " + " ".join(argv)
-
-
-def main():
- print(get_commandline_args())
-
-
-if __name__ == "__main__":
- main()
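
To illustrate the escaping rule above (any argument containing a listed
metacharacter is wrapped in single quotes, with embedded quotes escaped), a
small hedged sketch; the argv values are made up:

    import sys

    # Simulate a command line; "hello world" contains a space, so it is quoted.
    sys.argv = ["print_args.py", "--msg", "hello world", "plain"]
    print(get_commandline_args())       # --msg 'hello world' plain
    print(get_commandline_args(False))  # <python> print_args.py --msg 'hello world' plain
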
diff --git a/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh b/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh
deleted file mode 100755
index 15e4563..0000000
--- a/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh
+++ /dev/null
@@ -1,142 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-SECONDS=0
-log() {
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-help_message=$(cat << EOF
-Usage: $0 <in-wav.scp> <out-datadir> [<logdir> [<outdir>]]
-e.g.
-$0 data/test/wav.scp data/test_format/
-
-Format 'wav.scp': in short, convert a "kaldi-datadir"
-into a "modified-kaldi-datadir".
-
-The 'wav.scp' format in Kaldi is very flexible;
-e.g. a wav entry can be a unix pipe command,
-but that sometimes looks confusing and makes scripts more complex.
-This tool creates actual wav files from 'wav.scp'
-and also segments the wav files using 'segments'.
-
-Options
- --fs <fs>
- --segments <segments>
- --nj <nj>
- --cmd <cmd>
-EOF
-)
-
-out_filename=wav.scp
-cmd=utils/run.pl
-nj=30
-fs=none
-segments=
-
-ref_channels=
-utt2ref_channels=
-
-audio_format=wav
-write_utt2num_samples=true
-
-log "$0 $*"
-. utils/parse_options.sh
-
-if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then
- log "${help_message}"
- log "Error: invalid command line arguments"
- exit 1
-fi
-
-. ./path.sh # Setup the environment
-
-scp=$1
-if [ ! -f "${scp}" ]; then
- log "${help_message}"
- echo "$0: Error: No such file: ${scp}"
- exit 1
-fi
-dir=$2
-
-
-if [ $# -eq 2 ]; then
- logdir=${dir}/logs
- outdir=${dir}/data
-
-elif [ $# -eq 3 ]; then
- logdir=$3
- outdir=${dir}/data
-
-elif [ $# -eq 4 ]; then
- logdir=$3
- outdir=$4
-fi
-
-
-mkdir -p ${logdir}
-
-rm -f "${dir}/${out_filename}"
-
-
-opts=
-if [ -n "${utt2ref_channels}" ]; then
- opts="--utt2ref-channels ${utt2ref_channels} "
-elif [ -n "${ref_channels}" ]; then
- opts="--ref-channels ${ref_channels} "
-fi
-
-
-if [ -n "${segments}" ]; then
- log "[info]: using ${segments}"
- nutt=$(<${segments} wc -l)
- nj=$((nj<nutt?nj:nutt))
-
- split_segments=""
- for n in $(seq ${nj}); do
- split_segments="${split_segments} ${logdir}/segments.${n}"
- done
-
- utils/split_scp.pl "${segments}" ${split_segments}
-
- ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
- pyscripts/audio/format_wav_scp.py \
- ${opts} \
- --fs ${fs} \
- --audio-format "${audio_format}" \
- "--segment=${logdir}/segments.JOB" \
- "${scp}" "${outdir}/format.JOB"
-
-else
- log "[info]: without segments"
- nutt=$(<${scp} wc -l)
- nj=$((nj<nutt?nj:nutt))
-
- split_scps=""
- for n in $(seq ${nj}); do
- split_scps="${split_scps} ${logdir}/wav.${n}.scp"
- done
-
- utils/split_scp.pl "${scp}" ${split_scps}
- ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
- pyscripts/audio/format_wav_scp.py \
- ${opts} \
- --fs "${fs}" \
- --audio-format "${audio_format}" \
- "${logdir}/wav.JOB.scp" ${outdir}/format.JOB""
-fi
-
-# Workaround for the NFS problem
-ls ${outdir}/format.* > /dev/null
-
-# concatenate the .scp files together.
-for n in $(seq ${nj}); do
- cat "${outdir}/format.${n}/wav.scp" || exit 1;
-done > "${dir}/${out_filename}" || exit 1
-
-if "${write_utt2num_samples}"; then
- for n in $(seq ${nj}); do
- cat "${outdir}/format.${n}/utt2num_samples" || exit 1;
- done > "${dir}/utt2num_samples" || exit 1
-fi
-
-log "Successfully finished. [elapsed=${SECONDS}s]"
diff --git a/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh b/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh
deleted file mode 100755
index 9e08dba..0000000
--- a/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh
+++ /dev/null
@@ -1,116 +0,0 @@
-#!/usr/bin/env bash
-
-# 2020 @kamo-naoyuki
-# This file was copied from Kaldi.
-# The parts related to wav duration were removed,
-# because Kaldi commands should not be used here
-# and those files are not actually needed.
-
-# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
-# 2014 Tom Ko
-# 2018 Emotech LTD (author: Pawel Swietojanski)
-# Apache 2.0
-
-# This script operates on a directory, such as in data/train/,
-# that contains some subset of the following files:
-# wav.scp
-# spk2utt
-# utt2spk
-# text
-#
-# It generates the files which are used for perturbing the speed of the original data.
-
-export LC_ALL=C
-set -euo pipefail
-
-if [[ $# != 3 ]]; then
- echo "Usage: perturb_data_dir_speed.sh <warping-factor> <srcdir> <destdir>"
- echo "e.g.:"
- echo " $0 0.9 data/train_si284 data/train_si284p"
- exit 1
-fi
-
-factor=$1
-srcdir=$2
-destdir=$3
-label="sp"
-spk_prefix="${label}${factor}-"
-utt_prefix="${label}${factor}-"
-
-# Check that sox is on the PATH.
-
-! command -v sox &>/dev/null && echo "sox: command not found" && exit 1;
-
-if [[ ! -f ${srcdir}/utt2spk ]]; then
- echo "$0: no such file ${srcdir}/utt2spk"
- exit 1;
-fi
-
-if [[ ${destdir} == "${srcdir}" ]]; then
- echo "$0: this script requires <srcdir> and <destdir> to be different."
- exit 1
-fi
-
-mkdir -p "${destdir}"
-
-<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map"
-<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map"
-<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map"
-if [[ ! -f ${srcdir}/utt2uniq ]]; then
- <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq"
-else
- <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq"
-fi
-
-
-<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
- utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
-
-utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
-
-if [[ -f ${srcdir}/segments ]]; then
-
- utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
- utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
- awk -v factor="${factor}" \
- '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, s, e);} }' \
- >"${destdir}"/segments
-
- utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
- # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
- awk -v factor="${factor}" \
- '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
- else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
- else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
- > "${destdir}"/wav.scp
- if [[ -f ${srcdir}/reco2file_and_channel ]]; then
- utils/apply_map.pl -f 1 "${destdir}"/reco_map \
- <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
- fi
-
-else # no segments: wav.scp is indexed by utterance.
- if [[ -f ${srcdir}/wav.scp ]]; then
- utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
- # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
- awk -v factor="${factor}" \
- '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
- else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
- else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
- > "${destdir}"/wav.scp
- fi
-fi
-
-if [[ -f ${srcdir}/text ]]; then
- utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
-fi
-if [[ -f ${srcdir}/spk2gender ]]; then
- utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
-fi
-if [[ -f ${srcdir}/utt2lang ]]; then
- utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
-fi
-
-rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
-echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"
-
-utils/validate_data_dir.sh --no-feats --no-text "${destdir}"
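
The utt_map/spk_map construction above amounts to prefixing every id with
"sp<factor>-". A minimal Python sketch of the same mapping, under the same
two-column file convention (the input path is hypothetical):

    # Build a Kaldi-style utt_map: one "<utt> sp0.9-<utt>" line per utterance.
    factor = "0.9"
    prefix = f"sp{factor}-"
    with open("data/train/utt2spk") as fin, open("utt_map", "w") as fout:
        for line in fin:
            utt = line.split()[0]
            fout.write(f"{utt} {prefix}{utt}\n")
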
diff --git a/funasr/losses/nll_loss.py b/funasr/losses/nll_loss.py
deleted file mode 100644
index 7e4e294..0000000
--- a/funasr/losses/nll_loss.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import torch
-from torch import nn
-
-class NllLoss(nn.Module):
- """Nll loss.
-
- :param int size: the number of class
- :param int padding_idx: ignored class id
- :param bool normalize_length: normalize loss by sequence length if True
- :param torch.nn.Module criterion: loss function
- """
-
- def __init__(
- self,
- size,
- padding_idx,
- normalize_length=False,
- criterion=nn.NLLLoss(reduction='none'),
- ):
- """Construct an LabelSmoothingLoss object."""
- super(NllLoss, self).__init__()
- self.criterion = criterion
- self.padding_idx = padding_idx
- self.size = size
- self.true_dist = None
- self.normalize_length = normalize_length
-
- def forward(self, x, target):
- """Compute loss between x and target.
-
- :param torch.Tensor x: prediction (batch, seqlen, class)
- :param torch.Tensor target:
- target signal masked with self.padding_id (batch, seqlen)
- :return: scalar float value
- :rtype: torch.Tensor
- """
- assert x.size(2) == self.size
- batch_size = x.size(0)
- x = x.view(-1, self.size)
- target = target.view(-1)
- with torch.no_grad():
- ignore = target == self.padding_idx # (B,)
- total = len(target) - ignore.sum().item()
- target = target.masked_fill(ignore, 0) # avoid -1 index
- nll = self.criterion(x, target)  # per-token negative log-likelihood
- denom = total if self.normalize_length else batch_size
- return nll.masked_fill(ignore, 0).sum() / denom
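
A hedged usage sketch for NllLoss above; shapes follow the forward()
docstring, and the input must hold log-probabilities because nn.NLLLoss
expects them. The sizes are illustrative:

    import torch
    from funasr.losses.nll_loss import NllLoss  # module path as in this file

    loss_fn = NllLoss(size=10, padding_idx=-1, normalize_length=True)
    log_probs = torch.log_softmax(torch.randn(2, 5, 10), dim=-1)  # (batch, seqlen, class)
    target = torch.randint(0, 10, (2, 5))                         # (batch, seqlen)
    target[:, -1] = -1                                            # padded positions are ignored
    loss = loss_fn(log_probs, target)                             # scalar tensor
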
diff --git a/funasr/models/decoder/decoder_layer_sa_asr.py b/funasr/models/decoder/decoder_layer_sa_asr.py
deleted file mode 100644
index 80afc51..0000000
--- a/funasr/models/decoder/decoder_layer_sa_asr.py
+++ /dev/null
@@ -1,169 +0,0 @@
-import torch
-from torch import nn
-
-from funasr.modules.layer_norm import LayerNorm
-
-
-class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
-
- def __init__(
- self,
- size,
- self_attn,
- src_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- ):
- """Construct an DecoderLayer object."""
- super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
- self.size = size
- self.self_attn = self_attn
- self.src_attn = src_attn
- self.feed_forward = feed_forward
- self.norm1 = LayerNorm(size)
- self.norm2 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.normalize_before = normalize_before
- self.concat_after = concat_after
- if self.concat_after:
- self.concat_linear1 = nn.Linear(size + size, size)
- self.concat_linear2 = nn.Linear(size + size, size)
-
- def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
-
- residual = tgt
- if self.normalize_before:
- tgt = self.norm1(tgt)
-
- if cache is None:
- tgt_q = tgt
- tgt_q_mask = tgt_mask
- else:
- # compute only the last frame query keeping dim: max_time_out -> 1
- assert cache.shape == (
- tgt.shape[0],
- tgt.shape[1] - 1,
- self.size,
- ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
- tgt_q = tgt[:, -1:, :]
- residual = residual[:, -1:, :]
- tgt_q_mask = None
- if tgt_mask is not None:
- tgt_q_mask = tgt_mask[:, -1:, :]
-
- if self.concat_after:
- tgt_concat = torch.cat(
- (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
- )
- x = residual + self.concat_linear1(tgt_concat)
- else:
- x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
- if not self.normalize_before:
- x = self.norm1(x)
- z = x
-
- residual = x
- if self.normalize_before:
- x = self.norm1(x)
-
- skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
-
- if self.concat_after:
- x_concat = torch.cat(
- (x, skip), dim=-1
- )
- x = residual + self.concat_linear2(x_concat)
- else:
- x = residual + self.dropout(skip)
- if not self.normalize_before:
- x = self.norm1(x)
-
- residual = x
- if self.normalize_before:
- x = self.norm2(x)
- x = residual + self.dropout(self.feed_forward(x))
- if not self.normalize_before:
- x = self.norm2(x)
-
- if cache is not None:
- x = torch.cat([cache, x], dim=1)
-
- return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
-
-class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
-
- def __init__(
- self,
- size,
- d_size,
- src_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- ):
- """Construct an DecoderLayer object."""
- super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
- self.size = size
- self.src_attn = src_attn
- self.feed_forward = feed_forward
- self.norm1 = LayerNorm(size)
- self.norm2 = LayerNorm(size)
- self.norm3 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.normalize_before = normalize_before
- self.concat_after = concat_after
- self.spk_linear = nn.Linear(d_size, size, bias=False)
- if self.concat_after:
- self.concat_linear1 = nn.Linear(size + size, size)
- self.concat_linear2 = nn.Linear(size + size, size)
-
- def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
-
- residual = tgt
- if self.normalize_before:
- tgt = self.norm1(tgt)
-
- if cache is None:
- tgt_q = tgt
- tgt_q_mask = tgt_mask
- else:
-
- tgt_q = tgt[:, -1:, :]
- residual = residual[:, -1:, :]
- tgt_q_mask = None
- if tgt_mask is not None:
- tgt_q_mask = tgt_mask[:, -1:, :]
-
- x = tgt_q
- if self.normalize_before:
- x = self.norm2(x)
- if self.concat_after:
- x_concat = torch.cat(
- (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
- )
- x = residual + self.concat_linear2(x_concat)
- else:
- x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
- if not self.normalize_before:
- x = self.norm2(x)
- residual = x
-
- if dn is not None:
- x = x + self.spk_linear(dn)
- if self.normalize_before:
- x = self.norm3(x)
-
- x = residual + self.dropout(self.feed_forward(x))
- if not self.normalize_before:
- x = self.norm3(x)
-
- if cache is not None:
- x = torch.cat([cache, x], dim=1)
-
- return x, tgt_mask, memory, memory_mask
-
-
-
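
An illustrative instantiation of SpeakerAttributeSpkDecoderFirstLayer above,
assuming the FunASR attention and feed-forward modules imported by
transformer_decoder_sa_asr.py below; all shapes are made up:

    import torch
    from funasr.modules.attention import MultiHeadedAttention
    from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward

    size, heads, units = 256, 4, 2048
    layer = SpeakerAttributeSpkDecoderFirstLayer(
        size,
        MultiHeadedAttention(heads, size, 0.0),  # self-attention over the target
        MultiHeadedAttention(heads, size, 0.0),  # cross-attention: ASR keys, SPK values
        PositionwiseFeedForward(size, units, 0.1),
        dropout_rate=0.1,
    )
    tgt = torch.randn(2, 7, size)       # embedded target, (batch, tgt_len, size)
    asr_mem = torch.randn(2, 50, size)  # ASR encoder memory
    spk_mem = torch.randn(2, 50, size)  # speaker encoder memory
    x, _, _, _, _, z = layer(tgt, None, asr_mem, spk_mem, None)
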
diff --git a/funasr/models/decoder/transformer_decoder_sa_asr.py b/funasr/models/decoder/transformer_decoder_sa_asr.py
deleted file mode 100644
index 949f9c8..0000000
--- a/funasr/models/decoder/transformer_decoder_sa_asr.py
+++ /dev/null
@@ -1,291 +0,0 @@
-from typing import List
-from typing import Tuple
-
-import torch
-from typeguard import check_argument_types
-
-from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.attention import MultiHeadedAttention
-from funasr.modules.attention import CosineDistanceAttention
-from funasr.models.decoder.transformer_decoder import DecoderLayer
-from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeAsrDecoderFirstLayer
-from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeSpkDecoderFirstLayer
-from funasr.modules.dynamic_conv import DynamicConvolution
-from funasr.modules.dynamic_conv2d import DynamicConvolution2D
-from funasr.modules.embedding import PositionalEncoding
-from funasr.modules.layer_norm import LayerNorm
-from funasr.modules.lightconv import LightweightConvolution
-from funasr.modules.lightconv2d import LightweightConvolution2D
-from funasr.modules.mask import subsequent_mask
-from funasr.modules.positionwise_feed_forward import (
- PositionwiseFeedForward, # noqa: H301
-)
-from funasr.modules.repeat import repeat
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-from funasr.models.decoder.abs_decoder import AbsDecoder
-
-class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
-
- def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- spker_embedding_dim: int = 256,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- input_layer: str = "embed",
- use_asr_output_layer: bool = True,
- use_spk_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- ):
- assert check_argument_types()
- super().__init__()
- attention_dim = encoder_output_size
-
- if input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(vocab_size, attention_dim),
- pos_enc_class(attention_dim, positional_dropout_rate),
- )
- elif input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(vocab_size, attention_dim),
- torch.nn.LayerNorm(attention_dim),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(attention_dim, positional_dropout_rate),
- )
- else:
- raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
-
- self.normalize_before = normalize_before
- if self.normalize_before:
- self.after_norm = LayerNorm(attention_dim)
- if use_asr_output_layer:
- self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
- else:
- self.asr_output_layer = None
-
- if use_spk_output_layer:
- self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
- else:
- self.spk_output_layer = None
-
- self.cos_distance_att = CosineDistanceAttention()
-
- self.decoder1 = None
- self.decoder2 = None
- self.decoder3 = None
- self.decoder4 = None
-
- def forward(
- self,
- asr_hs_pad: torch.Tensor,
- spk_hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
- profile: torch.Tensor,
- profile_lens: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
-
- tgt = ys_in_pad
- # tgt_mask: (B, 1, L)
- tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
- # m: (1, L, L)
- m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
- # tgt_mask: (B, L, L)
- tgt_mask = tgt_mask & m
-
- asr_memory = asr_hs_pad
- spk_memory = spk_hs_pad
- memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
- # Spk decoder
- x = self.embed(tgt)
-
- x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
- x, tgt_mask, asr_memory, spk_memory, memory_mask
- )
- x, tgt_mask, spk_memory, memory_mask = self.decoder2(
- x, tgt_mask, spk_memory, memory_mask
- )
- if self.normalize_before:
- x = self.after_norm(x)
- if self.spk_output_layer is not None:
- x = self.spk_output_layer(x)
- dn, weights = self.cos_distance_att(x, profile, profile_lens)
- # Asr decoder
- x, tgt_mask, asr_memory, memory_mask = self.decoder3(
- z, tgt_mask, asr_memory, memory_mask, dn
- )
- x, tgt_mask, asr_memory, memory_mask = self.decoder4(
- x, tgt_mask, asr_memory, memory_mask
- )
-
- if self.normalize_before:
- x = self.after_norm(x)
- if self.asr_output_layer is not None:
- x = self.asr_output_layer(x)
-
- olens = tgt_mask.sum(1)
- return x, weights, olens
-
-
- def forward_one_step(
- self,
- tgt: torch.Tensor,
- tgt_mask: torch.Tensor,
- asr_memory: torch.Tensor,
- spk_memory: torch.Tensor,
- profile: torch.Tensor,
- cache: List[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-
- x = self.embed(tgt)
-
- if cache is None:
- cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
- new_cache = []
- x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
- x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
- )
- new_cache.append(x)
- for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
- x, tgt_mask, spk_memory, _ = decoder(
- x, tgt_mask, spk_memory, None, cache=c
- )
- new_cache.append(x)
- if self.normalize_before:
- x = self.after_norm(x)
- if self.spk_output_layer is not None:
- x = self.spk_output_layer(x)
- dn, weights = self.cos_distance_att(x, profile, None)
-
- x, tgt_mask, asr_memory, _ = self.decoder3(
- z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
- )
- new_cache.append(x)
-
- for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
- x, tgt_mask, asr_memory, _ = decoder(
- x, tgt_mask, asr_memory, None, cache=c
- )
- new_cache.append(x)
-
- if self.normalize_before:
- y = self.after_norm(x[:, -1])
- else:
- y = x[:, -1]
- if self.asr_output_layer is not None:
- y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
-
- return y, weights, new_cache
-
- def score(self, ys, state, asr_enc, spk_enc, profile):
- """Score."""
- ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
- logp, weights, state = self.forward_one_step(
- ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
- )
- return logp.squeeze(0), weights.squeeze(), state
-
-class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
- def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- spker_embedding_dim: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- asr_num_blocks: int = 6,
- spk_num_blocks: int = 3,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_asr_output_layer: bool = True,
- use_spk_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- ):
- assert check_argument_types()
- super().__init__(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- spker_embedding_dim=spker_embedding_dim,
- dropout_rate=dropout_rate,
- positional_dropout_rate=positional_dropout_rate,
- input_layer=input_layer,
- use_asr_output_layer=use_asr_output_layer,
- use_spk_output_layer=use_spk_output_layer,
- pos_enc_class=pos_enc_class,
- normalize_before=normalize_before,
- )
-
- attention_dim = encoder_output_size
-
- self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
- attention_dim,
- MultiHeadedAttention(
- attention_heads, attention_dim, self_attention_dropout_rate
- ),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- )
- self.decoder2 = repeat(
- spk_num_blocks - 1,
- lambda lnum: DecoderLayer(
- attention_dim,
- MultiHeadedAttention(
- attention_heads, attention_dim, self_attention_dropout_rate
- ),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
-
- self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
- attention_dim,
- spker_embedding_dim,
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- )
- self.decoder4 = repeat(
- asr_num_blocks - 1,
- lambda lnum: DecoderLayer(
- attention_dim,
- MultiHeadedAttention(
- attention_heads, attention_dim, self_attention_dropout_rate
- ),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
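
Finally, a hedged end-to-end sketch of the decoder's forward pass above; the
dimensions are illustrative and rely on the default spker_embedding_dim of
256:

    import torch
    from funasr.models.decoder.transformer_decoder_sa_asr import SAAsrTransformerDecoder

    decoder = SAAsrTransformerDecoder(vocab_size=100, encoder_output_size=256)
    B, T, L, S = 2, 50, 7, 4  # batch, encoder frames, target length, profile speakers
    asr_hs = torch.randn(B, T, 256)  # ASR encoder output
    spk_hs = torch.randn(B, T, 256)  # speaker encoder output
    hlens = torch.full((B,), T, dtype=torch.long)
    ys = torch.randint(0, 100, (B, L))  # target token ids
    ys_lens = torch.full((B,), L, dtype=torch.long)
    profile = torch.randn(B, S, 256)  # speaker profiles
    profile_lens = torch.full((B,), S, dtype=torch.long)
    logits, spk_weights, olens = decoder(
        asr_hs, spk_hs, hlens, ys, ys_lens, profile, profile_lens
    )
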
--
Gitblit v1.9.1