From 49e8e9d8fc1209c347aa2c2c65c6eb067b9f79d4 Mon Sep 17 00:00:00 2001
From: zhu-gu-an <76513567+zhu-gu-an@users.noreply.github.com>
Date: 星期六, 13 一月 2024 13:54:00 +0800
Subject: [PATCH] add triton paraformer large online (#1242)

---
 funasr/models/encoder/rwkv_encoder.py |   12 ++++++------
 1 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/funasr/models/encoder/rwkv_encoder.py b/funasr/models/encoder/rwkv_encoder.py
index 40151bf..dc1f207 100644
--- a/funasr/models/encoder/rwkv_encoder.py
+++ b/funasr/models/encoder/rwkv_encoder.py
@@ -113,12 +113,12 @@
         x = self.embed_norm(x)
         olens = mask.eq(0).sum(1)
 
-        # for training
-        # for block in self.rwkv_blocks:
-        #     x, _ = block(x)
-
-        # for streaming inference
-        x = self.rwkv_infer(x)
+        if self.training:
+            for block in self.rwkv_blocks:
+                x, _ = block(x)
+        else:
+            x = self.rwkv_infer(x)
+            
         x = self.final_norm(x)
 
         if self.time_reduction_factor > 1:

--
Gitblit v1.9.1