From 31eed1834f9ff17d6246008f64d3e061f58ef80a Mon Sep 17 00:00:00 2001
From: 凌匀 <ailsa.zly@alibaba-inc.com>
Date: Mon, 27 Feb 2023 13:33:55 +0800
Subject: [PATCH] in_cache & support soundfile read
---
funasr/models/encoder/fsmn_encoder.py | 44 +++++++++++++++++---------------------------
 1 file changed, 17 insertions(+), 27 deletions(-)
diff --git a/funasr/models/encoder/fsmn_encoder.py b/funasr/models/encoder/fsmn_encoder.py
index 54a113d..c749dc4 100755
--- a/funasr/models/encoder/fsmn_encoder.py
+++ b/funasr/models/encoder/fsmn_encoder.py
@@ -79,14 +79,12 @@
else:
self.conv_right = None
- def forward(self, input: torch.Tensor, in_cache=None):
+ def forward(self, input: torch.Tensor, cache: torch.Tensor):
x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1) # B D T C
- if in_cache is None: # offline
- y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
- else:
- y_left = torch.cat((in_cache, x_per), dim=2)
- in_cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
+
+ y_left = torch.cat((cache, x_per), dim=2)
+ cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
y_left = self.conv_left(y_left)
out = x_per + y_left
@@ -100,7 +98,7 @@
out_per = out.permute(0, 3, 2, 1)
output = out_per.squeeze(1)
- return output, in_cache
+ return output, cache
class BasicBlock(nn.Sequential):
@@ -124,28 +122,25 @@
self.affine = AffineTransform(proj_dim, linear_dim)
self.relu = RectifiedLinear(linear_dim, linear_dim)
- def forward(self, input: torch.Tensor, in_cache=None):
+ def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
x1 = self.linear(input) # B T D
- if in_cache is not None: # Dict[str, tensor.Tensor]
- cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
- if cache_layer_name not in in_cache:
- in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
- x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
- else:
- x2, _ = self.fsmn_block(x1)
+ cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+ if cache_layer_name not in in_cache:
+ in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+ x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
x3 = self.affine(x2)
x4 = self.relu(x3)
- return x4, in_cache
+ return x4
class FsmnStack(nn.Sequential):
def __init__(self, *args):
super(FsmnStack, self).__init__(*args)
- def forward(self, input: torch.Tensor, in_cache=None):
+ def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
x = input
for module in self._modules.values():
- x, in_cache = module(x, in_cache)
+ x = module(x, in_cache)
return x
@@ -174,8 +169,7 @@
lstride: int,
rstride: int,
output_affine_dim: int,
- output_dim: int,
- streaming=False
+ output_dim: int
):
super(FSMN, self).__init__()
@@ -186,8 +180,6 @@
self.proj_dim = proj_dim
self.output_affine_dim = output_affine_dim
self.output_dim = output_dim
- self.in_cache_original = dict() if streaming else None
- self.in_cache = copy.deepcopy(self.in_cache_original)
self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
@@ -201,12 +193,10 @@
def fuse_modules(self):
pass
- def cache_reset(self):
- self.in_cache = copy.deepcopy(self.in_cache_original)
-
def forward(
self,
input: torch.Tensor,
+ in_cache: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
@@ -218,7 +208,7 @@
x1 = self.in_linear1(input)
x2 = self.in_linear2(x1)
x3 = self.relu(x2)
- x4 = self.fsmn(x3, self.in_cache) # if in_cache is not None, self.fsmn is streaming's format, it will update automatically in self.fsmn
+ x4 = self.fsmn(x3, in_cache) # in_cache is updated in place inside self.fsmn
x5 = self.out_linear1(x4)
x6 = self.out_linear2(x5)
x7 = self.softmax(x6)
@@ -307,4 +297,4 @@
print('input shape: {}'.format(x.shape))
print('output shape: {}'.format(y.shape))
- print(fsmn.to_kaldi_net())
+ print(fsmn.to_kaldi_net())
\ No newline at end of file
--
Gitblit v1.9.1