From 49f13908deaed06bb4b0a01631e85e2833f1f051 Mon Sep 17 00:00:00 2001
From: smohan-speech <smohan@mail.ustc.edu.cn>
Date: 星期日, 07 五月 2023 02:27:58 +0800
Subject: [PATCH] add speaker-attributed ASR task for alimeeting

---
 /dev/null |  291 ----------------------------------------------------------
 1 files changed, 0 insertions(+), 291 deletions(-)

diff --git a/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py b/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py
deleted file mode 100755
index 1fd63d6..0000000
--- a/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py
+++ /dev/null
@@ -1,243 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-from io import BytesIO
-from pathlib import Path
-from typing import Tuple, Optional
-
-import kaldiio
-import humanfriendly
-import numpy as np
-import resampy
-import soundfile
-from tqdm import tqdm
-from typeguard import check_argument_types
-
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.fileio.read_text import read_2column_text
-from funasr.fileio.sound_scp import SoundScpWriter
-
-
-def humanfriendly_or_none(value: str):
-    if value in ("none", "None", "NONE"):
-        return None
-    return humanfriendly.parse_size(value)
-
-
-def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
-    """
-
-    >>> str2int_tuple('3,4,5')
-    (3, 4, 5)
-
-    """
-    assert check_argument_types()
-    if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
-        return None
-    return tuple(map(int, integers.strip().split(",")))
-
-
-def main():
-    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
-    logging.basicConfig(level=logging.INFO, format=logfmt)
-    logging.info(get_commandline_args())
-
-    parser = argparse.ArgumentParser(
-        description='Create waves list from "wav.scp"',
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-    parser.add_argument("scp")
-    parser.add_argument("outdir")
-    parser.add_argument(
-        "--name",
-        default="wav",
-        help="Specify the prefix word of output file name " 'such as "wav.scp"',
-    )
-    parser.add_argument("--segments", default=None)
-    parser.add_argument(
-        "--fs",
-        type=humanfriendly_or_none,
-        default=None,
-        help="If the sampling rate specified, " "Change the sampling rate.",
-    )
-    parser.add_argument("--audio-format", default="wav")
-    group = parser.add_mutually_exclusive_group()
-    group.add_argument("--ref-channels", default=None, type=str2int_tuple)
-    group.add_argument("--utt2ref-channels", default=None, type=str)
-    args = parser.parse_args()
-
-    out_num_samples = Path(args.outdir) / f"utt2num_samples"
-
-    if args.ref_channels is not None:
-
-        def utt2ref_channels(x) -> Tuple[int, ...]:
-            return args.ref_channels
-
-    elif args.utt2ref_channels is not None:
-        utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)
-
-        def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
-            chs_str = d[x]
-            return tuple(map(int, chs_str.split()))
-
-    else:
-        utt2ref_channels = None
-
-    Path(args.outdir).mkdir(parents=True, exist_ok=True)
-    out_wavscp = Path(args.outdir) / f"{args.name}.scp"
-    if args.segments is not None:
-        # Note: kaldiio supports only wav-pcm-int16le file.
-        loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
-        if args.audio_format.endswith("ark"):
-            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
-            fscp = out_wavscp.open("w")
-        else:
-            writer = SoundScpWriter(
-                args.outdir,
-                out_wavscp,
-                format=args.audio_format,
-            )
-
-        with out_num_samples.open("w") as fnum_samples:
-            for uttid, (rate, wave) in tqdm(loader):
-                # wave: (Time,) or (Time, Nmic)
-                if wave.ndim == 2 and utt2ref_channels is not None:
-                    wave = wave[:, utt2ref_channels(uttid)]
-
-                if args.fs is not None and args.fs != rate:
-                    # FIXME(kamo): To use sox?
-                    wave = resampy.resample(
-                        wave.astype(np.float64), rate, args.fs, axis=0
-                    )
-                    wave = wave.astype(np.int16)
-                    rate = args.fs
-                if args.audio_format.endswith("ark"):
-                    if "flac" in args.audio_format:
-                        suf = "flac"
-                    elif "wav" in args.audio_format:
-                        suf = "wav"
-                    else:
-                        raise RuntimeError("wav.ark or flac")
-
-                    # NOTE(kamo): Using extended ark format style here.
-                    # This format is incompatible with Kaldi
-                    kaldiio.save_ark(
-                        fark,
-                        {uttid: (wave, rate)},
-                        scp=fscp,
-                        append=True,
-                        write_function=f"soundfile_{suf}",
-                    )
-
-                else:
-                    writer[uttid] = rate, wave
-                fnum_samples.write(f"{uttid} {len(wave)}\n")
-    else:
-        if args.audio_format.endswith("ark"):
-            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
-        else:
-            wavdir = Path(args.outdir) / f"data_{args.name}"
-            wavdir.mkdir(parents=True, exist_ok=True)
-
-        with Path(args.scp).open("r") as fscp, out_wavscp.open(
-            "w"
-        ) as fout, out_num_samples.open("w") as fnum_samples:
-            for line in tqdm(fscp):
-                uttid, wavpath = line.strip().split(None, 1)
-
-                if wavpath.endswith("|"):
-                    # Streaming input e.g. cat a.wav |
-                    with kaldiio.open_like_kaldi(wavpath, "rb") as f:
-                        with BytesIO(f.read()) as g:
-                            wave, rate = soundfile.read(g, dtype=np.int16)
-                            if wave.ndim == 2 and utt2ref_channels is not None:
-                                wave = wave[:, utt2ref_channels(uttid)]
-
-                        if args.fs is not None and args.fs != rate:
-                            # FIXME(kamo): To use sox?
-                            wave = resampy.resample(
-                                wave.astype(np.float64), rate, args.fs, axis=0
-                            )
-                            wave = wave.astype(np.int16)
-                            rate = args.fs
-
-                        if args.audio_format.endswith("ark"):
-                            if "flac" in args.audio_format:
-                                suf = "flac"
-                            elif "wav" in args.audio_format:
-                                suf = "wav"
-                            else:
-                                raise RuntimeError("wav.ark or flac")
-
-                            # NOTE(kamo): Using extended ark format style here.
-                            # This format is incompatible with Kaldi
-                            kaldiio.save_ark(
-                                fark,
-                                {uttid: (wave, rate)},
-                                scp=fout,
-                                append=True,
-                                write_function=f"soundfile_{suf}",
-                            )
-                        else:
-                            owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
-                            soundfile.write(owavpath, wave, rate)
-                            fout.write(f"{uttid} {owavpath}\n")
-                else:
-                    wave, rate = soundfile.read(wavpath, dtype=np.int16)
-                    if wave.ndim == 2 and utt2ref_channels is not None:
-                        wave = wave[:, utt2ref_channels(uttid)]
-                        save_asis = False
-
-                    elif args.audio_format.endswith("ark"):
-                        save_asis = False
-
-                    elif Path(wavpath).suffix == "." + args.audio_format and (
-                        args.fs is None or args.fs == rate
-                    ):
-                        save_asis = True
-
-                    else:
-                        save_asis = False
-
-                    if save_asis:
-                        # Neither --segments nor --fs are specified and
-                        # the line doesn't end with "|",
-                        # i.e. not using unix-pipe,
-                        # only in this case,
-                        # just using the original file as is.
-                        fout.write(f"{uttid} {wavpath}\n")
-                    else:
-                        if args.fs is not None and args.fs != rate:
-                            # FIXME(kamo): To use sox?
-                            wave = resampy.resample(
-                                wave.astype(np.float64), rate, args.fs, axis=0
-                            )
-                            wave = wave.astype(np.int16)
-                            rate = args.fs
-
-                        if args.audio_format.endswith("ark"):
-                            if "flac" in args.audio_format:
-                                suf = "flac"
-                            elif "wav" in args.audio_format:
-                                suf = "wav"
-                            else:
-                                raise RuntimeError("wav.ark or flac")
-
-                            # NOTE(kamo): Using extended ark format style here.
-                            # This format is not supported in Kaldi.
-                            kaldiio.save_ark(
-                                fark,
-                                {uttid: (wave, rate)},
-                                scp=fout,
-                                append=True,
-                                write_function=f"soundfile_{suf}",
-                            )
-                        else:
-                            owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
-                            soundfile.write(owavpath, wave, rate)
-                            fout.write(f"{uttid} {owavpath}\n")
-                fnum_samples.write(f"{uttid} {len(wave)}\n")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py b/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py
deleted file mode 100755
index b0c61e5..0000000
--- a/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/usr/bin/env python
-import sys
-
-
-def get_commandline_args(no_executable=True):
-    extra_chars = [
-        " ",
-        ";",
-        "&",
-        "|",
-        "<",
-        ">",
-        "?",
-        "*",
-        "~",
-        "`",
-        '"',
-        "'",
-        "\\",
-        "{",
-        "}",
-        "(",
-        ")",
-    ]
-
-    # Escape the extra characters for shell
-    argv = [
-        arg.replace("'", "'\\''")
-        if all(char not in arg for char in extra_chars)
-        else "'" + arg.replace("'", "'\\''") + "'"
-        for arg in sys.argv
-    ]
-
-    if no_executable:
-        return " ".join(argv[1:])
-    else:
-        return sys.executable + " " + " ".join(argv)
-
-
-def main():
-    print(get_commandline_args())
-
-
-if __name__ == "__main__":
-    main()
diff --git a/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh b/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh
deleted file mode 100755
index 15e4563..0000000
--- a/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh
+++ /dev/null
@@ -1,142 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-SECONDS=0
-log() {
-    local fname=${BASH_SOURCE[1]##*/}
-    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-help_message=$(cat << EOF
-Usage: $0 <in-wav.scp> <out-datadir> [<logdir> [<outdir>]]
-e.g.
-$0 data/test/wav.scp data/test_format/
-
-Format 'wav.scp': In short words,
-changing "kaldi-datadir" to "modified-kaldi-datadir"
-
-The 'wav.scp' format in kaldi is very flexible,
-e.g. It can use unix-pipe as describing that wav file,
-but it sometime looks confusing and make scripts more complex.
-This tools creates actual wav files from 'wav.scp'
-and also segments wav files using 'segments'.
-
-Options
-  --fs <fs>
-  --segments <segments>
-  --nj <nj>
-  --cmd <cmd>
-EOF
-)
-
-out_filename=wav.scp
-cmd=utils/run.pl
-nj=30
-fs=none
-segments=
-
-ref_channels=
-utt2ref_channels=
-
-audio_format=wav
-write_utt2num_samples=true
-
-log "$0 $*"
-. utils/parse_options.sh
-
-if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then
-    log "${help_message}"
-    log "Error: invalid command line arguments"
-    exit 1
-fi
-
-. ./path.sh  # Setup the environment
-
-scp=$1
-if [ ! -f "${scp}" ]; then
-    log "${help_message}"
-    echo "$0: Error: No such file: ${scp}"
-    exit 1
-fi
-dir=$2
-
-
-if [ $# -eq 2 ]; then
-    logdir=${dir}/logs
-    outdir=${dir}/data
-
-elif [ $# -eq 3 ]; then
-    logdir=$3
-    outdir=${dir}/data
-
-elif [ $# -eq 4 ]; then
-    logdir=$3
-    outdir=$4
-fi
-
-
-mkdir -p ${logdir}
-
-rm -f "${dir}/${out_filename}"
-
-
-opts=
-if [ -n "${utt2ref_channels}" ]; then
-    opts="--utt2ref-channels ${utt2ref_channels} "
-elif [ -n "${ref_channels}" ]; then
-    opts="--ref-channels ${ref_channels} "
-fi
-
-
-if [ -n "${segments}" ]; then
-    log "[info]: using ${segments}"
-    nutt=$(<${segments} wc -l)
-    nj=$((nj<nutt?nj:nutt))
-
-    split_segments=""
-    for n in $(seq ${nj}); do
-        split_segments="${split_segments} ${logdir}/segments.${n}"
-    done
-
-    utils/split_scp.pl "${segments}" ${split_segments}
-
-    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
-        pyscripts/audio/format_wav_scp.py \
-            ${opts} \
-            --fs ${fs} \
-            --audio-format "${audio_format}" \
-            "--segment=${logdir}/segments.JOB" \
-            "${scp}" "${outdir}/format.JOB"
-
-else
-    log "[info]: without segments"
-    nutt=$(<${scp} wc -l)
-    nj=$((nj<nutt?nj:nutt))
-
-    split_scps=""
-    for n in $(seq ${nj}); do
-        split_scps="${split_scps} ${logdir}/wav.${n}.scp"
-    done
-
-    utils/split_scp.pl "${scp}" ${split_scps}
-    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
-        pyscripts/audio/format_wav_scp.py \
-        ${opts} \
-        --fs "${fs}" \
-        --audio-format "${audio_format}" \
-        "${logdir}/wav.JOB.scp" ${outdir}/format.JOB""
-fi
-
-# Workaround for the NFS problem
-ls ${outdir}/format.* > /dev/null
-
-# concatenate the .scp files together.
-for n in $(seq ${nj}); do
-    cat "${outdir}/format.${n}/wav.scp" || exit 1;
-done > "${dir}/${out_filename}" || exit 1
-
-if "${write_utt2num_samples}"; then
-    for n in $(seq ${nj}); do
-        cat "${outdir}/format.${n}/utt2num_samples" || exit 1;
-    done > "${dir}/utt2num_samples"  || exit 1
-fi
-
-log "Successfully finished. [elapsed=${SECONDS}s]"
diff --git a/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh b/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh
deleted file mode 100755
index 9e08dba..0000000
--- a/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh
+++ /dev/null
@@ -1,116 +0,0 @@
-#!/usr/bin/env bash
-
-# 2020 @kamo-naoyuki
-# This file was copied from Kaldi and 
-# I deleted parts related to wav duration 
-# because we shouldn't use kaldi's command here
-# and we don't need the files actually.
-
-# Copyright 2013  Johns Hopkins University (author: Daniel Povey)
-#           2014  Tom Ko
-#           2018  Emotech LTD (author: Pawel Swietojanski)
-# Apache 2.0
-
-# This script operates on a directory, such as in data/train/,
-# that contains some subset of the following files:
-#  wav.scp
-#  spk2utt
-#  utt2spk
-#  text
-#
-# It generates the files which are used for perturbing the speed of the original data.
-
-export LC_ALL=C
-set -euo pipefail
-
-if [[ $# != 3 ]]; then
-    echo "Usage: perturb_data_dir_speed.sh <warping-factor> <srcdir> <destdir>"
-    echo "e.g.:"
-    echo " $0 0.9 data/train_si284 data/train_si284p"
-    exit 1
-fi
-
-factor=$1
-srcdir=$2
-destdir=$3
-label="sp"
-spk_prefix="${label}${factor}-"
-utt_prefix="${label}${factor}-"
-
-#check is sox on the path
-
-! command -v sox &>/dev/null && echo "sox: command not found" && exit 1;
-
-if [[ ! -f ${srcdir}/utt2spk ]]; then
-  echo "$0: no such file ${srcdir}/utt2spk"
-  exit 1;
-fi
-
-if [[ ${destdir} == "${srcdir}" ]]; then
-  echo "$0: this script requires <srcdir> and <destdir> to be different."
-  exit 1
-fi
-
-mkdir -p "${destdir}"
-
-<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map"
-<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map"
-<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map"
-if [[ ! -f ${srcdir}/utt2uniq ]]; then
-    <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq"
-else
-    <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq"
-fi
-
-
-<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
-  utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
-
-utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
-
-if [[ -f ${srcdir}/segments ]]; then
-
-  utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
-      utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
-          awk -v factor="${factor}" \
-            '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
-            >"${destdir}"/segments
-
-  utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
-      # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
-      awk -v factor="${factor}" \
-          '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
-            else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
-            else  {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
-             > "${destdir}"/wav.scp
-  if [[ -f ${srcdir}/reco2file_and_channel ]]; then
-      utils/apply_map.pl -f 1 "${destdir}"/reco_map \
-       <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
-  fi
-
-else # no segments->wav indexed by utterance.
-    if [[ -f ${srcdir}/wav.scp ]]; then
-        utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
-         # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
-         awk -v factor="${factor}" \
-           '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
-             else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
-             else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
-                 > "${destdir}"/wav.scp
-    fi
-fi
-
-if [[ -f ${srcdir}/text ]]; then
-    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
-fi
-if [[ -f ${srcdir}/spk2gender ]]; then
-    utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
-fi
-if [[ -f ${srcdir}/utt2lang ]]; then
-    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
-fi
-
-rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
-echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"
-
-utils/validate_data_dir.sh --no-feats --no-text "${destdir}"
diff --git a/funasr/losses/nll_loss.py b/funasr/losses/nll_loss.py
deleted file mode 100644
index 7e4e294..0000000
--- a/funasr/losses/nll_loss.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import torch
-from torch import nn
-
-class NllLoss(nn.Module):
-    """Nll loss.
-
-    :param int size: the number of class
-    :param int padding_idx: ignored class id
-    :param bool normalize_length: normalize loss by sequence length if True
-    :param torch.nn.Module criterion: loss function
-    """
-
-    def __init__(
-        self,
-        size,
-        padding_idx,
-        normalize_length=False,
-        criterion=nn.NLLLoss(reduction='none'),
-    ):
-        """Construct an LabelSmoothingLoss object."""
-        super(NllLoss, self).__init__()
-        self.criterion = criterion
-        self.padding_idx = padding_idx
-        self.size = size
-        self.true_dist = None
-        self.normalize_length = normalize_length
-
-    def forward(self, x, target):
-        """Compute loss between x and target.
-
-        :param torch.Tensor x: prediction (batch, seqlen, class)
-        :param torch.Tensor target:
-            target signal masked with self.padding_id (batch, seqlen)
-        :return: scalar float value
-        :rtype torch.Tensor
-        """
-        assert x.size(2) == self.size
-        batch_size = x.size(0)
-        x = x.view(-1, self.size)
-        target = target.view(-1)
-        with torch.no_grad():
-            ignore = target == self.padding_idx  # (B,)
-            total = len(target) - ignore.sum().item()
-            target = target.masked_fill(ignore, 0)  # avoid -1 index
-        kl = self.criterion(x , target)
-        denom = total if self.normalize_length else batch_size
-        return kl.masked_fill(ignore, 0).sum() / denom
diff --git a/funasr/models/decoder/decoder_layer_sa_asr.py b/funasr/models/decoder/decoder_layer_sa_asr.py
deleted file mode 100644
index 80afc51..0000000
--- a/funasr/models/decoder/decoder_layer_sa_asr.py
+++ /dev/null
@@ -1,169 +0,0 @@
-import torch
-from torch import nn
-
-from funasr.modules.layer_norm  import LayerNorm
-
-
-class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
-
-    def __init__(
-        self,
-        size,
-        self_attn,
-        src_attn,
-        feed_forward,
-        dropout_rate,
-        normalize_before=True,
-        concat_after=False,
-    ):
-        """Construct an DecoderLayer object."""
-        super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
-        self.size = size
-        self.self_attn = self_attn
-        self.src_attn = src_attn
-        self.feed_forward = feed_forward
-        self.norm1 = LayerNorm(size)
-        self.norm2 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
-        self.normalize_before = normalize_before
-        self.concat_after = concat_after
-        if self.concat_after:
-            self.concat_linear1 = nn.Linear(size + size, size)
-            self.concat_linear2 = nn.Linear(size + size, size)
-
-    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
-        
-        residual = tgt
-        if self.normalize_before:
-            tgt = self.norm1(tgt)
-
-        if cache is None:
-            tgt_q = tgt
-            tgt_q_mask = tgt_mask
-        else:
-            # compute only the last frame query keeping dim: max_time_out -> 1
-            assert cache.shape == (
-                tgt.shape[0],
-                tgt.shape[1] - 1,
-                self.size,
-            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
-            tgt_q = tgt[:, -1:, :]
-            residual = residual[:, -1:, :]
-            tgt_q_mask = None
-            if tgt_mask is not None:
-                tgt_q_mask = tgt_mask[:, -1:, :]
-
-        if self.concat_after:
-            tgt_concat = torch.cat(
-                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
-            )
-            x = residual + self.concat_linear1(tgt_concat)
-        else:
-            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
-        if not self.normalize_before:
-            x = self.norm1(x)
-        z = x
-        
-        residual = x
-        if self.normalize_before:
-            x = self.norm1(x)
-
-        skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
-
-        if self.concat_after:
-            x_concat = torch.cat(
-                (x, skip), dim=-1
-            )
-            x = residual + self.concat_linear2(x_concat)
-        else:
-            x = residual + self.dropout(skip)
-        if not self.normalize_before:
-            x = self.norm1(x)
-        
-        residual = x
-        if self.normalize_before:
-            x = self.norm2(x)
-        x = residual + self.dropout(self.feed_forward(x))
-        if not self.normalize_before:
-            x = self.norm2(x)
-
-        if cache is not None:
-            x = torch.cat([cache, x], dim=1)
-            
-        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
-
-class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
-    
-    def __init__(
-        self,
-        size,
-        d_size,
-        src_attn,
-        feed_forward,
-        dropout_rate,
-        normalize_before=True,
-        concat_after=False,
-    ):
-        """Construct an DecoderLayer object."""
-        super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
-        self.size = size
-        self.src_attn = src_attn
-        self.feed_forward = feed_forward
-        self.norm1 = LayerNorm(size)
-        self.norm2 = LayerNorm(size)
-        self.norm3 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
-        self.normalize_before = normalize_before
-        self.concat_after = concat_after
-        self.spk_linear = nn.Linear(d_size, size, bias=False)
-        if self.concat_after:
-            self.concat_linear1 = nn.Linear(size + size, size)
-            self.concat_linear2 = nn.Linear(size + size, size)
-
-    def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
-        
-        residual = tgt
-        if self.normalize_before:
-            tgt = self.norm1(tgt)
-
-        if cache is None:
-            tgt_q = tgt
-            tgt_q_mask = tgt_mask
-        else:
-            
-            tgt_q = tgt[:, -1:, :]
-            residual = residual[:, -1:, :]
-            tgt_q_mask = None
-            if tgt_mask is not None:
-                tgt_q_mask = tgt_mask[:, -1:, :]
-
-        x = tgt_q
-        if self.normalize_before:
-            x = self.norm2(x)
-        if self.concat_after:
-            x_concat = torch.cat(
-                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
-            )
-            x = residual + self.concat_linear2(x_concat)
-        else:
-            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
-        if not self.normalize_before:
-            x = self.norm2(x)
-        residual = x
-
-        if dn!=None:
-            x = x + self.spk_linear(dn)
-        if self.normalize_before:
-            x = self.norm3(x)
-        
-        x = residual + self.dropout(self.feed_forward(x))
-        if not self.normalize_before:
-            x = self.norm3(x)
-
-        if cache is not None:
-            x = torch.cat([cache, x], dim=1)
-
-        return x, tgt_mask, memory, memory_mask
-
-
-
diff --git a/funasr/models/decoder/transformer_decoder_sa_asr.py b/funasr/models/decoder/transformer_decoder_sa_asr.py
deleted file mode 100644
index 949f9c8..0000000
--- a/funasr/models/decoder/transformer_decoder_sa_asr.py
+++ /dev/null
@@ -1,291 +0,0 @@
-from typing import Any
-from typing import List
-from typing import Sequence
-from typing import Tuple
-
-import torch
-from typeguard import check_argument_types
-
-from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.attention import MultiHeadedAttention
-from funasr.modules.attention import CosineDistanceAttention
-from funasr.models.decoder.transformer_decoder import DecoderLayer
-from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeAsrDecoderFirstLayer
-from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeSpkDecoderFirstLayer
-from funasr.modules.dynamic_conv import DynamicConvolution
-from funasr.modules.dynamic_conv2d import DynamicConvolution2D
-from funasr.modules.embedding import PositionalEncoding
-from funasr.modules.layer_norm import LayerNorm
-from funasr.modules.lightconv import LightweightConvolution
-from funasr.modules.lightconv2d import LightweightConvolution2D
-from funasr.modules.mask import subsequent_mask
-from funasr.modules.positionwise_feed_forward import (
-    PositionwiseFeedForward,  # noqa: H301
-)
-from funasr.modules.repeat import repeat
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-from funasr.models.decoder.abs_decoder import AbsDecoder
-
-class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
-    
-    def __init__(
-        self,
-        vocab_size: int,
-        encoder_output_size: int,
-        spker_embedding_dim: int = 256,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        input_layer: str = "embed",
-        use_asr_output_layer: bool = True,
-        use_spk_output_layer: bool = True,
-        pos_enc_class=PositionalEncoding,
-        normalize_before: bool = True,
-    ):
-        assert check_argument_types()
-        super().__init__()
-        attention_dim = encoder_output_size
-
-        if input_layer == "embed":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Embedding(vocab_size, attention_dim),
-                pos_enc_class(attention_dim, positional_dropout_rate),
-            )
-        elif input_layer == "linear":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(vocab_size, attention_dim),
-                torch.nn.LayerNorm(attention_dim),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                pos_enc_class(attention_dim, positional_dropout_rate),
-            )
-        else:
-            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
-
-        self.normalize_before = normalize_before
-        if self.normalize_before:
-            self.after_norm = LayerNorm(attention_dim)
-        if use_asr_output_layer:
-            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
-        else:
-            self.asr_output_layer = None
-
-        if use_spk_output_layer:
-            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
-        else:
-            self.spk_output_layer = None
-
-        self.cos_distance_att = CosineDistanceAttention()
-
-        self.decoder1 = None
-        self.decoder2 = None
-        self.decoder3 = None
-        self.decoder4 = None
-
-    def forward(
-        self,
-        asr_hs_pad: torch.Tensor,
-        spk_hs_pad: torch.Tensor,
-        hlens: torch.Tensor,
-        ys_in_pad: torch.Tensor,
-        ys_in_lens: torch.Tensor,
-        profile: torch.Tensor,
-        profile_lens: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        
-        tgt = ys_in_pad
-        # tgt_mask: (B, 1, L)
-        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
-        # m: (1, L, L)
-        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
-        # tgt_mask: (B, L, L)
-        tgt_mask = tgt_mask & m
-
-        asr_memory = asr_hs_pad
-        spk_memory = spk_hs_pad
-        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
-        # Spk decoder
-        x = self.embed(tgt)
-
-        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
-            x, tgt_mask, asr_memory, spk_memory, memory_mask
-        )
-        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
-            x, tgt_mask, spk_memory, memory_mask
-        )
-        if self.normalize_before:
-            x = self.after_norm(x)
-        if self.spk_output_layer is not None:
-            x = self.spk_output_layer(x)
-        dn, weights = self.cos_distance_att(x, profile, profile_lens)
-        # Asr decoder
-        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
-            z, tgt_mask, asr_memory, memory_mask, dn
-        )
-        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
-            x, tgt_mask, asr_memory, memory_mask
-        )
-
-        if self.normalize_before:
-            x = self.after_norm(x)
-        if self.asr_output_layer is not None:
-            x = self.asr_output_layer(x)
-
-        olens = tgt_mask.sum(1)
-        return x, weights, olens
-
-
-    def forward_one_step(
-        self,
-        tgt: torch.Tensor,
-        tgt_mask: torch.Tensor,
-        asr_memory: torch.Tensor,
-        spk_memory: torch.Tensor,
-        profile: torch.Tensor,
-        cache: List[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-        
-        x = self.embed(tgt)
-
-        if cache is None:
-            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
-        new_cache = []
-        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
-                x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
-        )
-        new_cache.append(x)
-        for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
-            x, tgt_mask, spk_memory, _ = decoder(
-                x, tgt_mask, spk_memory, None, cache=c
-            )
-            new_cache.append(x)
-        if self.normalize_before:
-            x = self.after_norm(x)
-        else:
-            x = x
-        if self.spk_output_layer is not None:
-            x = self.spk_output_layer(x)
-        dn, weights = self.cos_distance_att(x, profile, None)
-
-        x, tgt_mask, asr_memory, _ = self.decoder3(
-            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
-        )
-        new_cache.append(x)
-
-        for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
-            x, tgt_mask, asr_memory, _ = decoder(
-                x, tgt_mask, asr_memory, None, cache=c
-            )
-            new_cache.append(x)
-
-        if self.normalize_before:
-            y = self.after_norm(x[:, -1])
-        else:
-            y = x[:, -1]
-        if self.asr_output_layer is not None:
-            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
-
-        return y, weights, new_cache
-
-    def score(self, ys, state, asr_enc, spk_enc, profile):
-        """Score."""
-        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
-        logp, weights, state = self.forward_one_step(
-            ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
-        )
-        return logp.squeeze(0), weights.squeeze(), state
-
-class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
-    def __init__(
-        self,
-        vocab_size: int,
-        encoder_output_size: int,
-        spker_embedding_dim: int = 256,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        asr_num_blocks: int = 6,
-        spk_num_blocks: int = 3,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        self_attention_dropout_rate: float = 0.0,
-        src_attention_dropout_rate: float = 0.0,
-        input_layer: str = "embed",
-        use_asr_output_layer: bool = True,
-        use_spk_output_layer: bool = True,
-        pos_enc_class=PositionalEncoding,
-        normalize_before: bool = True,
-        concat_after: bool = False,
-    ):
-        assert check_argument_types()
-        super().__init__(
-            vocab_size=vocab_size,
-            encoder_output_size=encoder_output_size,
-            spker_embedding_dim=spker_embedding_dim,
-            dropout_rate=dropout_rate,
-            positional_dropout_rate=positional_dropout_rate,
-            input_layer=input_layer,
-            use_asr_output_layer=use_asr_output_layer,
-            use_spk_output_layer=use_spk_output_layer,
-            pos_enc_class=pos_enc_class,
-            normalize_before=normalize_before,
-        )
-
-        attention_dim = encoder_output_size
-
-        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
-            attention_dim,
-            MultiHeadedAttention(
-                attention_heads, attention_dim, self_attention_dropout_rate
-            ),
-            MultiHeadedAttention(
-                attention_heads, attention_dim, src_attention_dropout_rate
-            ),
-            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
-            dropout_rate,
-            normalize_before,
-            concat_after,
-        )
-        self.decoder2 = repeat(
-            spk_num_blocks - 1,
-            lambda lnum: DecoderLayer(
-                attention_dim,
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, self_attention_dropout_rate
-                ),
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, src_attention_dropout_rate
-                ),
-                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        
-        
-        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
-            attention_dim,
-            spker_embedding_dim,
-            MultiHeadedAttention(
-                attention_heads, attention_dim, src_attention_dropout_rate
-            ),
-            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
-            dropout_rate,
-            normalize_before,
-            concat_after,
-        )
-        self.decoder4 = repeat(
-            asr_num_blocks - 1,
-            lambda lnum: DecoderLayer(
-                attention_dim,
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, self_attention_dropout_rate
-                ),
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, src_attention_dropout_rate
-                ),
-                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )

--
Gitblit v1.9.1