From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
funasr/models/llm_asr/adaptor.py | 72 ++++++++++++++++++++++++-----------
1 files changed, 49 insertions(+), 23 deletions(-)
diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr/adaptor.py
index 9b79ed2..4348213 100644
--- a/funasr/models/llm_asr/adaptor.py
+++ b/funasr/models/llm_asr/adaptor.py
@@ -51,18 +51,40 @@
self.linear = nn.Linear(configuration.hidden_size, self.llm_dim)
self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5)
+
+ self.second_per_frame = 0.333333
+ self.second_stride = 0.333333
- def forward(self, x, atts):
- query = self.query.expand(x.shape[0], -1, -1)
+ def split_frames(self, speech_embeds):
+ B, T, C = speech_embeds.shape
+ kernel = round(T * self.second_per_frame / 30.0)
+ stride = round(T * self.second_stride / 30.0)
+ kernel = (1, kernel)
+ stride = (1, stride)
+ speech_embeds_tr = speech_embeds.transpose(1, 2).unsqueeze(2)
+ speech_embeds_overlap = torch.nn.functional.unfold(speech_embeds_tr, kernel_size=kernel, dilation=1, padding=0, stride=stride)
+ _, _, L = speech_embeds_overlap.shape
+ speech_embeds_overlap = speech_embeds_overlap.view(B, -1, kernel[1], L)
+ speech_embeds_overlap = torch.permute(speech_embeds_overlap, [0, 3, 2, 1])
+ speech_embeds = speech_embeds_overlap.reshape(-1, kernel[1], C)
+ speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long, device=speech_embeds.device)
+ return speech_embeds, speech_atts
+ def forward(self, x):
+ B, T, C = x.size()
+ encoder_out_feat, attention_mask = self.split_frames(x)
+ query = self.query.expand(encoder_out_feat.shape[0], -1, -1)
+
+
query_output = self.qformer(
query_embeds=query,
- encoder_hidden_states=x,
- encoder_attention_mask=atts,
+ encoder_hidden_states=encoder_out_feat,
+ encoder_attention_mask=attention_mask,
return_dict=True,
)
query_proj = self.norm(self.linear(query_output.last_hidden_state))
+ query_proj = query_proj.view(B, -1, query_proj.size(2)).contiguous()
return query_proj
@@ -83,25 +105,27 @@
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
- self.blocks = nn.ModuleList(
- [
- EncoderLayer(
- llm_dim,
- MultiHeadedAttention(
- kwargs.get("attention_heads", 8),
+ self.blocks = None
+ if kwargs.get("n_layer", 2) > 0:
+ self.blocks = nn.ModuleList(
+ [
+ EncoderLayer(
llm_dim,
- kwargs.get("attention_dropout_rate", 0.0),
- ),
- PositionwiseFeedForward(
- llm_dim,
- llm_dim // 4,
+ MultiHeadedAttention(
+ kwargs.get("attention_heads", 8),
+ llm_dim,
+ kwargs.get("attention_dropout_rate", 0.0),
+ ),
+ PositionwiseFeedForward(
+ llm_dim,
+ llm_dim // 4,
+ kwargs.get("dropout_rate", 0.0),
+ ),
kwargs.get("dropout_rate", 0.0),
- ),
- kwargs.get("dropout_rate", 0.0),
- )
- for i in range(kwargs.get("n_layer", 2))
- ]
- )
+ )
+ for i in range(kwargs.get("n_layer", 2))
+ ]
+ )
def forward(self, x, ilens=None):
@@ -123,6 +147,8 @@
olens = None
olens = (ilens - 1) // self.k + 1
masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
- for layer, block in enumerate(self.blocks):
- x, masks = block(x, masks)
+
+ if self.blocks is not None:
+ for layer, block in enumerate(self.blocks):
+ x, masks = block(x, masks)
return x, olens
--
Gitblit v1.9.1