From 30c40c643c19f6e2ac8679fa76d09d0f9ceccc65 Mon Sep 17 00:00:00 2001
From: chenmengzheAAA <123789350+chenmengzheAAA@users.noreply.github.com>
Date: Thu, 14 Sep 2023 18:00:43 +0800
Subject: [PATCH] Update modelscope_models.md
---
funasr/models/encoder/sanm_encoder.py | 85 +++++++++++++++++++++++++++++++++++-------
 1 file changed, 70 insertions(+), 15 deletions(-)
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 2a68011..9e27d4a 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -8,7 +8,6 @@
import torch.nn as nn
import torch.nn.functional as F
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
-from typeguard import check_argument_types
import numpy as np
from funasr.torch_utils.device_funcs import to_device
from funasr.modules.nets_utils import make_pad_mask
@@ -27,9 +26,10 @@
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
+from funasr.modules.mask import subsequent_mask, vad_mask
+
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
def __init__(
@@ -146,11 +146,14 @@
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
+ lora_list: List[str] = None,
+ lora_rank: int = 8,
+ lora_alpha: int = 16,
+ lora_dropout: float = 0.1,
selfattention_layer_type: str = "sanm",
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
@@ -230,6 +233,10 @@
attention_dropout_rate,
kernel_size,
sanm_shfit,
+ lora_list,
+ lora_rank,
+ lora_alpha,
+ lora_dropout,
)
encoder_selfattn_layer_args = (
@@ -239,6 +246,10 @@
attention_dropout_rate,
kernel_size,
sanm_shfit,
+ lora_list,
+ lora_rank,
+ lora_alpha,
+ lora_dropout,
)
self.encoders0 = repeat(
1,
@@ -354,18 +365,9 @@
def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
if len(cache) == 0:
return feats
- # process last chunk
cache["feats"] = to_device(cache["feats"], device=feats.device)
overlap_feats = torch.cat((cache["feats"], feats), dim=1)
- if cache["is_final"]:
- cache["feats"] = overlap_feats[:, -cache["chunk_size"][0]:, :]
- if not cache["last_chunk"]:
- padding_length = sum(cache["chunk_size"]) - overlap_feats.shape[1]
- overlap_feats = overlap_feats.transpose(1, 2)
- overlap_feats = F.pad(overlap_feats, (0, padding_length))
- overlap_feats = overlap_feats.transpose(1, 2)
- else:
- cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+ cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
return overlap_feats
def forward_chunk(self,
@@ -609,7 +611,6 @@
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
@@ -641,6 +642,8 @@
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
+ elif input_layer == "pe_online":
+ self.embed = StreamSinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
@@ -825,6 +828,59 @@
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
+
+ def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
+ if len(cache) == 0:
+ return feats
+ cache["feats"] = to_device(cache["feats"], device=feats.device)
+ overlap_feats = torch.cat((cache["feats"], feats), dim=1)
+ cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+ return overlap_feats
+
+ def forward_chunk(self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ cache: dict = None,
+ ctc: CTC = None,
+ ):
+ xs_pad *= self.output_size() ** 0.5
+ if self.embed is None:
+ xs_pad = xs_pad
+ else:
+ xs_pad = self.embed(xs_pad, cache)
+ if cache["tail_chunk"]:
+ xs_pad = to_device(cache["feats"], device=xs_pad.device)
+ else:
+ xs_pad = self._add_overlap_chunk(xs_pad, cache)
+ encoder_outs = self.encoders0(xs_pad, None, None, None, None)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ intermediate_outs = []
+ if len(self.interctc_layer_idx) == 0:
+ encoder_outs = self.encoders(xs_pad, None, None, None, None)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ else:
+ for layer_idx, encoder_layer in enumerate(self.encoders):
+ encoder_outs = encoder_layer(xs_pad, None, None, None, None)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ if layer_idx + 1 in self.interctc_layer_idx:
+ encoder_out = xs_pad
+
+ # intermediate outputs are also normalized
+ if self.normalize_before:
+ encoder_out = self.after_norm(encoder_out)
+
+ intermediate_outs.append((layer_idx + 1, encoder_out))
+
+ if self.interctc_use_conditioning:
+ ctc_out = ctc.softmax(encoder_out)
+ xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+ if self.normalize_before:
+ xs_pad = self.after_norm(xs_pad)
+
+ if len(intermediate_outs) > 0:
+ return (xs_pad, intermediate_outs), None, None
+ return xs_pad, ilens, None
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
@@ -1013,7 +1069,6 @@
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
--
Gitblit v1.9.1