From f022d1b84455ffedd846882b98c215d9d6cf4369 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 16 Feb 2023 15:22:11 +0800
Subject: [PATCH] Merge pull request #123 from alibaba-damo-academy/dev_zly
---
funasr/models/encoder/fsmn_encoder.py | 168 +++++++++++++++++++++++++++-----------------------------
 1 file changed, 81 insertions(+), 87 deletions(-)
diff --git a/funasr/models/encoder/fsmn_encoder.py b/funasr/models/encoder/fsmn_encoder.py
index 643cefc..54a113d 100755
--- a/funasr/models/encoder/fsmn_encoder.py
+++ b/funasr/models/encoder/fsmn_encoder.py
@@ -1,55 +1,50 @@
+from typing import Tuple, Dict
+import copy
+
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-from typing import Tuple
-
-
class LinearTransform(nn.Module):
- def __init__(self, input_dim, output_dim, quantize=0):
+ def __init__(self, input_dim, output_dim):
super(LinearTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim, bias=False)
- self.quantize = quantize
- self.quant = torch.quantization.QuantStub()
- self.dequant = torch.quantization.DeQuantStub()
def forward(self, input):
- if self.quantize:
- output = self.quant(input)
- else:
- output = input
- output = self.linear(output)
- if self.quantize:
- output = self.dequant(output)
+ output = self.linear(input)
return output
class AffineTransform(nn.Module):
- def __init__(self, input_dim, output_dim, quantize=0):
+ def __init__(self, input_dim, output_dim):
super(AffineTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
- self.quantize = quantize
self.linear = nn.Linear(input_dim, output_dim)
- self.quant = torch.quantization.QuantStub()
- self.dequant = torch.quantization.DeQuantStub()
def forward(self, input):
- if self.quantize:
- output = self.quant(input)
- else:
- output = input
- output = self.linear(output)
- if self.quantize:
- output = self.dequant(output)
+ output = self.linear(input)
return output
+
+
+class RectifiedLinear(nn.Module):
+
+ def __init__(self, input_dim, output_dim):
+ super(RectifiedLinear, self).__init__()
+ self.dim = input_dim
+ self.relu = nn.ReLU()
+ self.dropout = nn.Dropout(0.1)
+
+ def forward(self, input):
+ out = self.relu(input)
+ return out
class FSMNBlock(nn.Module):
@@ -62,7 +57,6 @@
rorder=None,
lstride=1,
rstride=1,
- quantize=0
):
super(FSMNBlock, self).__init__()
@@ -84,71 +78,75 @@
self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
else:
self.conv_right = None
- self.quantize = quantize
- self.quant = torch.quantization.QuantStub()
- self.dequant = torch.quantization.DeQuantStub()
- def forward(self, input):
+ def forward(self, input: torch.Tensor, in_cache=None):
x = torch.unsqueeze(input, 1)
- x_per = x.permute(0, 3, 2, 1)
-
- y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
- if self.quantize:
- y_left = self.quant(y_left)
+ x_per = x.permute(0, 3, 2, 1) # B D T C
+ if in_cache is None: # offline
+ y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
+ else:
+ y_left = torch.cat((in_cache, x_per), dim=2)
+ in_cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
y_left = self.conv_left(y_left)
- if self.quantize:
- y_left = self.dequant(y_left)
out = x_per + y_left
if self.conv_right is not None:
+ # maybe need to check
y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
y_right = y_right[:, :, self.rstride:, :]
- if self.quantize:
- y_right = self.quant(y_right)
y_right = self.conv_right(y_right)
- if self.quantize:
- y_right = self.dequant(y_right)
out += y_right
out_per = out.permute(0, 3, 2, 1)
output = out_per.squeeze(1)
- return output
+ return output, in_cache
-class RectifiedLinear(nn.Module):
+class BasicBlock(nn.Sequential):
+ def __init__(self,
+ linear_dim: int,
+ proj_dim: int,
+ lorder: int,
+ rorder: int,
+ lstride: int,
+ rstride: int,
+ stack_layer: int
+ ):
+ super(BasicBlock, self).__init__()
+ self.lorder = lorder
+ self.rorder = rorder
+ self.lstride = lstride
+ self.rstride = rstride
+ self.stack_layer = stack_layer
+ self.linear = LinearTransform(linear_dim, proj_dim)
+ self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
+ self.affine = AffineTransform(proj_dim, linear_dim)
+ self.relu = RectifiedLinear(linear_dim, linear_dim)
- def __init__(self, input_dim, output_dim):
- super(RectifiedLinear, self).__init__()
- self.dim = input_dim
- self.relu = nn.ReLU()
- self.dropout = nn.Dropout(0.1)
-
- def forward(self, input):
- out = self.relu(input)
- # out = self.dropout(out)
- return out
+ def forward(self, input: torch.Tensor, in_cache=None):
+ x1 = self.linear(input) # B T D
+ if in_cache is not None: # Dict[str, tensor.Tensor]
+ cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+ if cache_layer_name not in in_cache:
+ in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+ x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
+ else:
+ x2, _ = self.fsmn_block(x1)
+ x3 = self.affine(x2)
+ x4 = self.relu(x3)
+ return x4, in_cache
-def _build_repeats(
- fsmn_layers: int,
- linear_dim: int,
- proj_dim: int,
- lorder: int,
- rorder: int,
- lstride=1,
- rstride=1,
-):
- repeats = [
- nn.Sequential(
- LinearTransform(linear_dim, proj_dim),
- FSMNBlock(proj_dim, proj_dim, lorder, rorder, 1, 1),
- AffineTransform(proj_dim, linear_dim),
- RectifiedLinear(linear_dim, linear_dim))
- for i in range(fsmn_layers)
- ]
+class FsmnStack(nn.Sequential):
+ def __init__(self, *args):
+ super(FsmnStack, self).__init__(*args)
- return nn.Sequential(*repeats)
+ def forward(self, input: torch.Tensor, in_cache=None):
+ x = input
+ for module in self._modules.values():
+ x, in_cache = module(x, in_cache)
+ return x
'''
@@ -177,6 +175,7 @@
rstride: int,
output_affine_dim: int,
output_dim: int,
+ streaming=False
):
super(FSMN, self).__init__()
@@ -185,23 +184,16 @@
self.fsmn_layers = fsmn_layers
self.linear_dim = linear_dim
self.proj_dim = proj_dim
- self.lorder = lorder
- self.rorder = rorder
- self.lstride = lstride
- self.rstride = rstride
self.output_affine_dim = output_affine_dim
self.output_dim = output_dim
+ self.in_cache_original = dict() if streaming else None
+ self.in_cache = copy.deepcopy(self.in_cache_original)
self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
self.relu = RectifiedLinear(linear_dim, linear_dim)
-
- self.fsmn = _build_repeats(fsmn_layers,
- linear_dim,
- proj_dim,
- lorder, rorder,
- lstride, rstride)
-
+ self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
+ range(fsmn_layers)])
self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
self.softmax = nn.Softmax(dim=-1)
@@ -209,27 +201,29 @@
def fuse_modules(self):
pass
+ def cache_reset(self):
+ self.in_cache = copy.deepcopy(self.in_cache_original)
+
def forward(
self,
input: torch.Tensor,
- in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
input (torch.Tensor): Input tensor (B, T, D)
- in_cache(torhc.Tensor): (B, D, C), C is the accumulated cache size
+ in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
+ {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
"""
x1 = self.in_linear1(input)
x2 = self.in_linear2(x1)
x3 = self.relu(x2)
- x4 = self.fsmn(x3)
+ x4 = self.fsmn(x3, self.in_cache) # if in_cache is not None, self.fsmn is streaming's format, it will update automatically in self.fsmn
x5 = self.out_linear1(x4)
x6 = self.out_linear2(x5)
x7 = self.softmax(x6)
return x7
- # return x6, in_cache
'''
--
Gitblit v1.9.1