From fc68b5ffe453235294a561737d8e84bb6c1689a4 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 25 Apr 2024 21:43:47 +0800
Subject: [PATCH] Dev gzf exp (#1661)

---
 funasr/models/conformer_rwkv/decoder.py                         |   46 +
 funasr/models/sense_voice/decoder.py                            |   37 +
 wandb/run-20240425_211446-lkqptn01/logs/debug.log               |   29 +
 wandb/run-20240425_211446-lkqptn01/run-lkqptn01.wandb           |    0 
 funasr/datasets/audio_datasets/espnet_samplers.py               |    9 
 funasr/models/sense_voice/rwkv_v4.py                            |  412 +++++++++++++++
 funasr/models/sense_voice/rwkv_v5.py                            |  597 ++++++++++++++++++++++
 funasr/models/sense_voice/rwkv_v6.py                            |   19 
 funasr/train_utils/trainer.py                                   |   55 +
 wandb/debug.log                                                 |    1 
 examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py |    3 
 funasr/datasets/audio_datasets/index_ds.py                      |    4 
 funasr/datasets/dataloader_entry.py                             |   20 
 funasr/train_utils/model_summary.py                             |    2 
 funasr/models/sense_voice/cuda/wkv_cuda.cu                      |  125 ++++
 wandb/run-20240425_211446-lkqptn01/logs/debug-internal.log      |  146 +++++
 funasr/bin/train.py                                             |   24 
 wandb/run-20240425_211446-lkqptn01/files/config.yaml            |   26 
 wandb/latest-run                                                |    1 
 funasr/models/sense_voice/cuda/wkv_op.cpp                       |   21 
 wandb/debug-internal.log                                        |    1 
 21 files changed, 1,527 insertions(+), 51 deletions(-)

diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
index b21348f..0f30a37 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -9,6 +9,9 @@
 
 model = AutoModel(model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch")
 
+mm = model.model
+for p in mm.parameters():
+    print(f"{p.numel()}")
 res = model.generate(input=wav_file)
 print(res)
 # [[beg1, end1], [beg2, end2], .., [begN, endN]]
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index eb1611a..448e464 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -99,7 +99,7 @@
     if freeze_param is not None:
         if "," in freeze_param:
             freeze_param = eval(freeze_param)
-        if isinstance(freeze_param, Sequence):
+        if not isinstance(freeze_param, Sequence):
             freeze_param = (freeze_param,)
         logging.info("freeze_param is not None: %s", freeze_param)
         for t in freeze_param:
@@ -107,8 +107,9 @@
                 if k.startswith(t + ".") or k == t:
                     logging.info(f"Setting {k}.requires_grad = False")
                     p.requires_grad = False
+    if local_rank == 0:
+        logging.info(f"{model_summary(model)}")
 
-    logging.info(f"model info: {model_summary(model)}")
     if use_ddp:
         model = model.cuda(local_rank)
         model = DDP(
@@ -145,8 +146,6 @@
     else:
         model = model.to(device=kwargs.get("device", "cuda"))
 
-    if local_rank == 0:
-        logging.info(f"{model}")
     kwargs["device"] = next(model.parameters()).device
 
     # optim
@@ -182,7 +181,12 @@
     scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
     scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
 
-    trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+    trainer.resume_checkpoint(
+        model=model,
+        optim=optim,
+        scheduler=scheduler,
+        scaler=scaler,
+    )
 
     tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
     os.makedirs(tensorboard_dir, exist_ok=True)
@@ -197,8 +201,11 @@
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
 
-        for data_split_i in range(dataloader.data_split_num):
-            dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
+        for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
+            dataloader_tr, dataloader_val = dataloader.build_iter(
+                epoch, data_split_i=data_split_i, start_step=trainer.start_step
+            )
+            trainer.start_step = 0
             trainer.train_epoch(
                 model=model,
                 optim=optim,
@@ -211,9 +218,8 @@
                 data_split_i=data_split_i,
                 data_split_num=dataloader.data_split_num,
             )
-            
-            torch.cuda.empty_cache()
 
+            torch.cuda.empty_cache()
 
         trainer.validate_epoch(
             model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index 3f14d09..cb30a28 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -41,6 +41,7 @@
         drop_last=False,
         is_training: bool = True,
         sort_size: int = 1024,
+        start_step: int = 0,
         **kwargs,
     ):
 
@@ -70,7 +71,9 @@
         self.max_token_length = kwargs.get("max_token_length", 2048)
         self.min_token_length = kwargs.get("min_token_length", 0)
         self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-
+        self.start_step = start_step
+        if self.start_step > 0:
+            logging.info(f"Warning, start_step > 0, dataloader start from step: {self.start_step}")
         # super().__init__(dataset, num_replicas=num_replicas, rank=rank,
         #                  shuffle=shuffle, drop_last=drop_last)
 
@@ -92,7 +95,7 @@
         max_len_in_batch = 0  # Tracks the max sample length within the current batch
 
         for idx in sorted_indices:
-          
+
             # original_sample_length = self.dataset.get_source_len(idx)
             # if (
             #     original_sample_length < self.min_token_length
@@ -142,7 +145,7 @@
         # Allocate the batches to the current rank
         start_idx = self.rank * batches_per_rank
         end_idx = start_idx + batches_per_rank
-        rank_batches = buffer_batches[start_idx:end_idx]
+        rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
 
         # Return an iterator over the batches for the current rank
         return iter(rank_batches)
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index da008b4..70581e8 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -104,10 +104,10 @@
                             or target_len > self.max_target_length
                         ):
                             continue
-                            
+
                         if (source_len + target_len) > self.max_token_length:
                             continue
-                            
+
                         contents_i = {
                             "source": source,
                             "prompt": prompt,
diff --git a/funasr/datasets/dataloader_entry.py b/funasr/datasets/dataloader_entry.py
index 9595805..925b1d3 100644
--- a/funasr/datasets/dataloader_entry.py
+++ b/funasr/datasets/dataloader_entry.py
@@ -14,14 +14,14 @@
         frontend=frontend,
         tokenizer=tokenizer,
         is_training=True,
-        **kwargs.get("dataset_conf")
+        **kwargs.get("dataset_conf"),
     )
     dataset_val = dataset_class(
         kwargs.get("valid_data_set_list"),
         frontend=frontend,
         tokenizer=tokenizer,
         is_training=False,
-        **kwargs.get("dataset_conf")
+        **kwargs.get("dataset_conf"),
     )
 
     # dataloader
@@ -55,14 +55,14 @@
             frontend=frontend,
             tokenizer=tokenizer,
             is_training=True,
-            **kwargs.get("dataset_conf")
+            **kwargs.get("dataset_conf"),
         )
         dataset_val = dataset_class(
             kwargs.get("valid_data_set_list"),
             frontend=frontend,
             tokenizer=tokenizer,
             is_training=False,
-            **kwargs.get("dataset_conf")
+            **kwargs.get("dataset_conf"),
         )
 
         self.dataset_tr = dataset_tr
@@ -76,7 +76,7 @@
         self.tokenizer = tokenizer
         self.kwargs = kwargs
 
-    def build_iter(self, epoch=0, data_split_i=0, **kwargs):
+    def build_iter(self, epoch=0, data_split_i=0, start_step=0, **kwargs):
 
         # reload dataset slice
         if self.data_split_num > 1:
@@ -87,7 +87,7 @@
                 tokenizer=self.tokenizer,
                 is_training=True,
                 **self.kwargs.get("dataset_conf"),
-                data_split_i=data_split_i
+                data_split_i=data_split_i,
             )
 
         # dataloader
@@ -95,7 +95,9 @@
         batch_sampler_val = None
         if batch_sampler is not None:
             batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
-            batch_sampler = batch_sampler_class(self.dataset_tr, **self.kwargs.get("dataset_conf"))
+            batch_sampler = batch_sampler_class(
+                self.dataset_tr, start_step=start_step, **self.kwargs.get("dataset_conf")
+            )
             batch_sampler_val = batch_sampler_class(
                 self.dataset_val, is_training=False, **self.kwargs.get("dataset_conf")
             )
@@ -121,14 +123,14 @@
         frontend=frontend,
         tokenizer=tokenizer,
         is_training=True,
-        **kwargs.get("dataset_conf")
+        **kwargs.get("dataset_conf"),
     )
     dataset_val = dataset_class(
         kwargs.get("valid_data_set_list"),
         frontend=frontend,
         tokenizer=tokenizer,
         is_training=False,
-        **kwargs.get("dataset_conf")
+        **kwargs.get("dataset_conf"),
     )
 
     return dataset_tr, dataset_val
diff --git a/funasr/models/conformer_rwkv/decoder.py b/funasr/models/conformer_rwkv/decoder.py
index 90e56e5..5e2ac12 100644
--- a/funasr/models/conformer_rwkv/decoder.py
+++ b/funasr/models/conformer_rwkv/decoder.py
@@ -29,6 +29,11 @@
 from funasr.register import tables
 
 
+class LayerNorm(nn.LayerNorm):
+    def forward(self, x):
+        return super().forward(x.float()).type(x.dtype)
+
+
 class DecoderLayer(nn.Module):
     """Single decoder layer module.
 
@@ -54,7 +59,7 @@
     def __init__(
         self,
         size,
-        self_attn,
+        # self_attn,
         src_attn,
         feed_forward,
         dropout_rate,
@@ -62,11 +67,12 @@
         concat_after=False,
         layer_id=None,
         args={},
+        **kwargs,
     ):
         """Construct an DecoderLayer object."""
         super(DecoderLayer, self).__init__()
         self.size = size
-        self.self_attn = self_attn.to(torch.bfloat16)
+        # self.self_attn = self_attn.to(torch.bfloat16)
         self.src_attn = src_attn
         self.feed_forward = feed_forward
         self.norm1 = LayerNorm(size)
@@ -79,6 +85,22 @@
             self.concat_linear1 = nn.Linear(size + size, size)
             self.concat_linear2 = nn.Linear(size + size, size)
         self.layer_id = layer_id
+
+        if args.get("version", "v4") == "v4":
+            from funasr.models.sense_voice.rwkv_v4 import RWKVLayer
+            from funasr.models.sense_voice.rwkv_v4 import RWKV_TimeMix as RWKV_Tmix
+        elif args.get("version", "v5") == "v5":
+            from funasr.models.sense_voice.rwkv_v5 import RWKVLayer
+            from funasr.models.sense_voice.rwkv_v5 import RWKV_Tmix_x052 as RWKV_Tmix
+        else:
+            from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+            from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
+        # self.attn = RWKVLayer(args=args, layer_id=layer_id)
+        self.self_attn = RWKV_Tmix(args, layer_id=layer_id)
+        if args.get("datatype", "bf16") == "bf16":
+            self.self_attn.to(torch.bfloat16)
+            # self.norm1.to(torch.bfloat16)
+        self.args = args
         self.ln0 = None
         if self.layer_id == 0 and not args.get("ln0", True):
             self.ln0 = LayerNorm(args.n_embd)
@@ -93,7 +115,15 @@
             print("init_rwkv")
             scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
             nn.init.constant_(self.norm1.weight, scale)
-            nn.init.constant_(self.self_attn.ln2.weight, scale)
+            # nn.init.constant_(self.self_attn.ln2.weight, scale)
+
+        if args.get("init_rwkv", True):
+            print("init_rwkv")
+            nn.init.orthogonal_(self.self_attn.receptance.weight, gain=1)
+            nn.init.orthogonal_(self.self_attn.key.weight, gain=0.1)
+            nn.init.orthogonal_(self.self_attn.value.weight, gain=1)
+            nn.init.orthogonal_(self.self_attn.gate.weight, gain=0.1)
+            nn.init.zeros_(self.self_attn.output.weight)
 
     def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
         """Compute decoded features.
@@ -117,6 +147,8 @@
         if self.layer_id == 0 and self.ln0 is not None:
             tgt = self.ln0(tgt)
 
+        if self.args.get("datatype", "bf16") == "bf16":
+            tgt = tgt.bfloat16()
         residual = tgt
 
         tgt = self.norm1(tgt)
@@ -132,7 +164,8 @@
 
             x = residual + self.dropout(self.self_attn(tgt, mask=tgt_q_mask))
             x = x[:, -1, :]
-
+        if self.args.get("datatype", "bf16") == "bf16":
+            x = x.to(torch.float32)
         # x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
 
         residual = x
@@ -370,17 +403,16 @@
             pos_enc_class=pos_enc_class,
             normalize_before=normalize_before,
         )
-        from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+        # from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
 
         rwkv_cfg = kwargs.get("rwkv_cfg", {})
         args = OmegaConf.create(rwkv_cfg)
-        # self.attn = RWKVLayer(args=args, layer_id=layer_id)
+
         attention_dim = encoder_output_size
         self.decoders = repeat(
             num_blocks,
             lambda lnum: DecoderLayer(
                 attention_dim,
-                RWKVLayer(args=args, layer_id=lnum),
                 MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
                 PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                 dropout_rate,
diff --git a/funasr/models/sense_voice/cuda/wkv_cuda.cu b/funasr/models/sense_voice/cuda/wkv_cuda.cu
new file mode 100644
index 0000000..6acd0f3
--- /dev/null
+++ b/funasr/models/sense_voice/cuda/wkv_cuda.cu
@@ -0,0 +1,125 @@
+#include <stdio.h>
+#include <assert.h>
+
+#define MIN_VALUE (-1e38)
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C,
+                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+                               F *__restrict__ const _y) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    F u = _u[_c];
+    F w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    F *__restrict__ const y = _y + _offset;
+
+    F p = 0, q = 0, o = MIN_VALUE;
+    // p and q are running sums divided by exp(o) (to avoid overflows)
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+
+        F no = max(o, u + k[ii]);
+        F A = exp(o - no);
+        F B = exp(u + k[ii] - no);
+        y[ii] = (A * p + B * v[ii]) / (A * q + B);
+
+        no = max(w + o, k[ii]);
+        A = exp(w + o - no);
+        B = exp(k[ii] - no);
+        p = A * p + B * v[ii];
+        q = A * q + B;
+        o = no;
+    }
+}
+
+template <typename F>
+__global__ void kernel_backward(const int B, const int T, const int C,
+                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
+                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    F u = _u[_c];
+    F w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    const F *__restrict__ const gy = _gy + _offset;
+
+    F *__restrict__ const gk = _gk + _offset;
+    F *__restrict__ const gv = _gv + _offset;
+
+    F y[Tmax], z[Tmax], zexp[Tmax];
+
+    F gw = 0, gu = 0;
+    F p = 0, q = 0;
+    F dpdw = 0, dqdw = 0;
+    F o = MIN_VALUE;
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        F no = max(o, k[ii] + u);
+        F A = exp(o - no);
+        F B = exp(k[ii] + u - no);
+
+        F num = A * p + B * v[ii];
+        F iden = 1 / (A * q + B);
+
+        y[i] = num * iden;
+        z[i] = iden;
+        zexp[i] = k[ii] + u - no;
+
+        gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
+        gu += gy[ii] * (v[ii] - y[i]) * B * iden;
+
+        no = max(w + o, k[ii]);
+        A = exp(w + o - no);
+        B = exp(k[ii] - no);
+        dpdw = A * (p + dpdw);
+        dqdw = A * (q + dqdw);
+        p = A * p + B * v[ii];
+        q = A * q + B;
+        o = no;
+    }
+
+    F gp = 0, gq = 0;
+    o = MIN_VALUE;
+    for (int i = T - 1; i >= 0; i--) {
+        const int ii = i * C;
+        F A = gy[ii] * z[i] * exp(zexp[i]);
+        F B = exp(k[ii] + o);
+        gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
+        gv[ii] = A + B * gp;
+
+        F no = max(w + o, zexp[i] - k[ii] - u);
+        A = exp(w + o - no);
+        B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
+        gp = A * gp + B;
+        gq = A * gq - B * y[i];
+        o = no;
+    }
+
+    // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
+    const int _offsetBC = _b * C + _c;
+    _gw[_offsetBC] += gw * _w[_c];
+    _gu[_offsetBC] += gu;
+}
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
+}
+
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
+}
diff --git a/funasr/models/sense_voice/cuda/wkv_op.cpp b/funasr/models/sense_voice/cuda/wkv_op.cpp
new file mode 100644
index 0000000..efe56d8
--- /dev/null
+++ b/funasr/models/sense_voice/cuda/wkv_op.cpp
@@ -0,0 +1,21 @@
+#include <torch/extension.h>
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
+
+void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
+}
+void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &forward, "wkv forward");
+    m.def("backward", &backward, "wkv backward");
+}
+
+TORCH_LIBRARY(wkv, m) {
+    m.def("forward", forward);
+    m.def("backward", backward);
+}
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index 5d34ff3..133508f 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -156,7 +156,6 @@
         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
 
 
-from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
 from omegaconf import OmegaConf
 
 
@@ -168,9 +167,25 @@
 
         rwkv_cfg = kwargs.get("rwkv_cfg", {})
         args = OmegaConf.create(rwkv_cfg)
-        self.attn = RWKVLayer(args=args, layer_id=layer_id)
-        if args.get("datatype", "bf16") == "bf16":
-            self.attn.to(torch.bfloat16)
+        if args.get("version", "v4") == "v4":
+            from funasr.models.sense_voice.rwkv_v4 import RWKVLayer
+            from funasr.models.sense_voice.rwkv_v4 import RWKV_TimeMix as RWKV_Tmix
+        elif args.get("version", "v5") == "v5":
+            from funasr.models.sense_voice.rwkv_v5 import RWKVLayer
+            from funasr.models.sense_voice.rwkv_v5 import RWKV_Tmix_x052 as RWKV_Tmix
+        else:
+            from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+            from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
+        # self.att = RWKVLayer(args=args, layer_id=layer_id)
+        self.att = RWKV_Tmix(args, layer_id=layer_id)
+
+        if args.get("init_rwkv", True):
+            print("init_rwkv")
+            nn.init.orthogonal_(self.att.receptance.weight, gain=1)
+            nn.init.orthogonal_(self.att.key.weight, gain=0.1)
+            nn.init.orthogonal_(self.att.value.weight, gain=1)
+            nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
+            nn.init.zeros_(self.att.output.weight)
 
         self.ln0 = None
         if layer_id == 0 and not args.get("ln0", True):
@@ -180,6 +195,7 @@
                 layer_id = 0
                 scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
                 nn.init.constant_(self.ln0.weight, scale)
+
         self.layer_id = layer_id
         self.args = args
 
@@ -191,6 +207,11 @@
                 print("init_rwkv")
                 scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
                 nn.init.constant_(self.ln1.weight, scale)
+
+        if args.get("datatype", "bf16") == "bf16":
+            self.att.to(torch.bfloat16)
+            # if self.ln1 is not None:
+            #     self.ln1.to(torch.bfloat16)
 
         self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
         self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
@@ -213,10 +234,14 @@
         if self.layer_id == 0 and self.ln0 is not None:
             x = self.ln0(x)
 
+        if self.args.get("datatype", "bf16") == "bf16":
+            x = x.bfloat16()
         if self.ln1 is None:
-            x = x + self.attn(x, mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+            x = x + self.att(x, mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
         else:
-            x = x + self.attn(self.ln1(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+            x = x + self.att(self.ln1(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+        if self.args.get("datatype", "bf16") == "bf16":
+            x = x.to(torch.float32)
 
         if self.cross_attn:
             x = (
diff --git a/funasr/models/sense_voice/rwkv_v4.py b/funasr/models/sense_voice/rwkv_v4.py
new file mode 100644
index 0000000..c154ac0
--- /dev/null
+++ b/funasr/models/sense_voice/rwkv_v4.py
@@ -0,0 +1,412 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+import os, math, gc, importlib
+import torch
+
+# torch._C._jit_set_profiling_executor(True)
+# torch._C._jit_set_profiling_mode(True)
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+def __nop(ob):
+    return ob
+
+
+MyModule = nn.Module
+MyFunction = __nop
+if "RWKV_JIT_ON" in os.environ and os.environ["RWKV_JIT_ON"] == "1":
+    MyModule = torch.jit.ScriptModule
+    MyFunction = torch.jit.script_method
+
+########################################################################################################
+# CUDA Kernel
+########################################################################################################
+
+wkv_cuda = None
+
+
+def load_rwkv_kernel(
+    HEAD_SIZE: int = 64,
+    RWKV_CTXLEN: int = 512,
+    T_MAX: int = 512,
+):
+    from torch.utils.cpp_extension import load
+
+    global wkv_cuda
+
+    if wkv_cuda is not None:
+        return
+
+    absolute_file_path = os.path.abspath(__file__)
+    cur_dir = os.path.dirname(absolute_file_path)
+    wkv_cuda = load(
+        name="wkv",
+        sources=[f"{cur_dir}/cuda/wkv_op.cpp", f"{cur_dir}/cuda/wkv_cuda.cu"],
+        verbose=True,
+        extra_cuda_cflags=[
+            "-res-usage",
+            "--maxrregcount 60",
+            "--use_fast_math",
+            "-O3",
+            "-Xptxas -O3",
+            f"-DTmax={T_MAX}",
+        ],
+    )
+
+
+class WKV(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, B, T, C, w, u, k, v):
+        ctx.B = B
+        ctx.T = T
+        ctx.C = C
+        # assert T <= T_MAX
+        assert B * C % min(C, 1024) == 0
+        if "32" in os.environ["RWKV_FLOAT_MODE"]:
+            w = -torch.exp(w.contiguous())
+            u = u.contiguous()
+            k = k.contiguous()
+            v = v.contiguous()
+        else:
+            w = -torch.exp(w.float().contiguous())
+            u = u.float().contiguous()
+            k = k.float().contiguous()
+            v = v.float().contiguous()
+        ctx.save_for_backward(w, u, k, v)
+        y = torch.empty((B, T, C), device="cuda", memory_format=torch.contiguous_format)
+        wkv_cuda.forward(B, T, C, w, u, k, v, y)
+        if "32" in os.environ["RWKV_FLOAT_MODE"]:
+            return y
+        elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
+            return y.half()
+        elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+            return y.bfloat16()
+
+    @staticmethod
+    def backward(ctx, gy):
+        B = ctx.B
+        T = ctx.T
+        C = ctx.C
+        # assert T <= T_MAX  # T_MAX is a load_rwkv_kernel() arg, not in scope here; disabled as in forward()
+        assert B * C % min(C, 1024) == 0
+        w, u, k, v = ctx.saved_tensors
+        gw = torch.zeros((B, C), device="cuda").contiguous()
+        gu = torch.zeros((B, C), device="cuda").contiguous()
+        gk = torch.zeros((B, T, C), device="cuda").contiguous()
+        gv = torch.zeros((B, T, C), device="cuda").contiguous()
+        if "32" in os.environ["RWKV_FLOAT_MODE"]:
+            wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
+        else:
+            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
+        gw = torch.sum(gw, dim=0)
+        gu = torch.sum(gu, dim=0)
+        if "32" in os.environ["RWKV_FLOAT_MODE"]:
+            return (None, None, None, gw, gu, gk, gv)
+        elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
+            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
+        elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
+
+
+def RUN_CUDA(B, T, C, w, u, k, v):
+    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
+
+
+class RWKV_TimeMix(torch.jit.ScriptModule):
+    def __init__(self, config, layer_id):
+        super().__init__()
+        load_rwkv_kernel()
+        self.layer_id = layer_id
+        self.ctx_len = config.ctx_len
+        self.n_embd = config.n_embd
+
+        attn_sz = config.n_embd
+
+        with torch.no_grad():  # fancy init
+            ratio_0_to_1 = layer_id / (config.n_layer - 1)  # 0 to 1
+            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer)  # 1 to ~0
+
+            # fancy time_decay
+            decay_speed = torch.ones(attn_sz)
+            for h in range(attn_sz):
+                decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+            self.time_decay = nn.Parameter(decay_speed)
+            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
+
+            # fancy time_first
+            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
+            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
+
+            # fancy time_mix
+            x = torch.ones(1, 1, config.n_embd)
+            for i in range(config.n_embd):
+                x[0, 0, i] = i / config.n_embd
+            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
+
+        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+
+        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
+        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
+        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)
+
+        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
+
+        self.key.scale_init = 0
+        self.receptance.scale_init = 0
+        self.output.scale_init = 0
+
+    @torch.jit.script_method
+    def jit_func(self, x):
+
+        # Mix x with the previous timestep to produce xk, xv, xr
+        xx = self.time_shift(x)
+        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
+        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+
+        # Use xk, xv, xr to produce k, v, r
+        k = self.key(xk)
+        v = self.value(xv)
+        r = self.receptance(xr)
+        sr = torch.sigmoid(r)
+
+        return sr, k, v
+
+    def forward(self, x):
+        B, T, C = x.size()  # x = (Batch,Time,Channel)
+
+        sr, k, v = self.jit_func(x)
+
+        rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
+        rwkv = self.output(rwkv)
+        return rwkv
+
+
+class RWKV_ChannelMix(torch.jit.ScriptModule):
+    def __init__(self, config, layer_id):
+        super().__init__()
+        self.layer_id = layer_id
+
+        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+
+        with torch.no_grad():  # fancy init of time_mix
+            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer)  # 1 to ~0
+
+            x = torch.ones(1, 1, config.n_embd)
+            for i in range(config.n_embd):
+                x[0, 0, i] = i / config.n_embd
+
+            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+
+        hidden_sz = 4 * config.n_embd
+        self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
+        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
+        self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
+
+        self.value.scale_init = 0
+        self.receptance.scale_init = 0
+
+    @torch.jit.script_method
+    def forward(self, x):
+        xx = self.time_shift(x)
+        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
+        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
+
+        k = self.key(xk)
+        k = torch.square(torch.relu(k))
+        kv = self.value(k)
+
+        rkv = torch.sigmoid(self.receptance(xr)) * kv
+        return rkv
+
+
+# class Block(nn.Module):
+#     def __init__(self, args, layer_id):
+#         super().__init__()
+#         self.args = args
+#         self.layer_id = layer_id
+#
+#         self.ln1 = nn.LayerNorm(args.n_embd)
+#         self.ln2 = nn.LayerNorm(args.n_embd)
+#
+#         if self.layer_id == 0:
+#             self.ln0 = nn.LayerNorm(args.n_embd)
+#
+#         self.att = RWKV_Tmix_x060(args, layer_id)
+#
+#         self.ffn = RWKV_CMix_x060(args, layer_id)
+#
+#         if args.dropout > 0:
+#             self.drop0 = nn.Dropout(p=args.dropout)
+#             self.drop1 = nn.Dropout(p=args.dropout)
+#
+#     def forward(self, x, x_emb=None):
+#         args = self.args
+#         B, T, C = x.size()
+#         if self.layer_id == 0:
+#             x = self.ln0(x)
+#
+#         if self.args.dropout == 0:
+#             if self.layer_id == 0 and args.pre_ffn > 0:
+#                 x = x + self.ffnPre(self.ln1(x))
+#             else:
+#                 x = x + self.att(self.ln1(x))
+#             x = x + self.ffn(self.ln2(x))
+#         else:
+#             if self.layer_id == 0 and args.pre_ffn > 0:
+#                 x = self.drop0(x + self.ffnPre(self.ln1(x)))
+#             else:
+#                 x = self.drop0(x + self.att(self.ln1(x)))
+#             x = self.drop1(x + self.ffn(self.ln2(x)))
+#
+#         return x
+
+
class RWKVLayer(nn.Module):
    """One RWKV (v4) layer: pre-norm time-mix attention followed by a pre-norm
    channel-mix FFN, with optional dropout and an optional input LayerNorm
    (ln0) applied on layer 0 only.

    `args` is a dict-like config (attribute access plus .get) with at least:
    n_embd, n_layer, dim_ffn (may be None), dropout; optional flags:
    ln0, ln1, init_rwkv, datatype.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        if args.dim_ffn is None:
            # default FFN width: 3.5x embedding, rounded down to a multiple of 32
            args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)

        # ln0 normalizes the raw layer input once, on the first layer only.
        self.ln0 = None
        if self.layer_id == 0 and args.get("ln0", True):
            self.ln0 = nn.LayerNorm(args.n_embd)

        # ln1/ln2 are the usual pre-attention / pre-FFN norms.
        self.ln1 = None
        if args.get("ln1", True):
            self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        self.att = RWKV_TimeMix(args, layer_id)

        self.ffn = RWKV_ChannelMix(args, layer_id)

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
            self.drop1 = nn.Dropout(p=args.dropout)

        # RWKV-style init: orthogonal projections, zeroed output paths,
        # depth-scaled LayerNorm gains.
        if args.get("init_rwkv", True):
            print("init_rwkv")
            nn.init.orthogonal_(self.att.receptance.weight, gain=1)
            nn.init.orthogonal_(self.att.key.weight, gain=0.1)
            nn.init.orthogonal_(self.att.value.weight, gain=1)
            # BUGFIX: the v4 RWKV_TimeMix in this module defines no `gate`
            # projection (gating arrived in v5/v6), so the previously
            # unconditional `self.att.gate` access raised AttributeError.
            # Guard it so gate-less attention modules work too.
            if hasattr(self.att, "gate"):
                nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
            nn.init.zeros_(self.att.output.weight)

            nn.init.orthogonal_(self.ffn.key.weight, gain=1)
            nn.init.zeros_(self.ffn.value.weight)
            nn.init.zeros_(self.ffn.receptance.weight)
            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
            nn.init.constant_(self.ln2.weight, scale)
            if self.ln0 is not None:
                nn.init.constant_(self.ln0.weight, scale)
            if self.ln1 is not None:
                nn.init.constant_(self.ln1.weight, scale)

    def forward(self, x, x_emb=None, mask=None, **kwargs):
        """Residual attention + FFN pass over x of shape (B, T, C).

        x_emb and mask are accepted for interface compatibility but unused.
        Runs in bfloat16 when args.datatype == "bf16" and restores float32
        on the way out.
        """
        args = self.args
        if args.get("datatype", "bf16") == "bf16":
            x = x.bfloat16()
        B, T, C = x.size()
        if self.layer_id == 0 and self.ln0 is not None:
            x = self.ln0(x)

        if self.args.dropout == 0:
            if self.ln1 is None:
                x = x + self.att(x)
            else:
                x = x + self.att(self.ln1(x))
            x = x + self.ffn(self.ln2(x))
        else:
            if self.ln1 is None:
                x = self.drop0(x + self.att(x))
            else:
                x = self.drop0(x + self.att(self.ln1(x)))
            x = self.drop1(x + self.ffn(self.ln2(x)))

        if args.get("datatype", "bf16") == "bf16":
            x = x.to(torch.float32)
        return x
+
+
class RWKV(nn.Module):
    """Full RWKV (v4) language model: embedding -> stacked RWKV layers ->
    LayerNorm -> vocabulary head, with optional gradient checkpointing,
    tiny-attention routing, and the head-QK copy mechanism."""

    def __init__(self, args):
        super().__init__()
        self.args = args
        if not hasattr(args, "dim_att"):
            args.dim_att = args.n_embd
        if not hasattr(args, "dim_ffn"):
            if "-f4" in os.environ["RWKV_MY_TESTING"]:
                args.dim_ffn = int((args.n_embd * 4) // 32 * 32)
            else:
                args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)  # default = 3.5x emb size
        if not hasattr(args, "tiny_att_layer"):
            args.tiny_att_layer = -1
        if not hasattr(args, "tiny_att_dim"):
            args.tiny_att_dim = -1
        assert args.n_embd % 32 == 0
        assert args.dim_att % 32 == 0
        assert args.dim_ffn % 32 == 0

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)

        # BUGFIX: `Block` is commented out in this module and not visibly
        # imported, so instantiating it raised NameError. RWKVLayer is the
        # live equivalent here (same (args, layer_id) constructor and the
        # same forward(x, x_emb=None, ...) contract).
        self.blocks = nn.ModuleList([RWKVLayer(args, i) for i in range(args.n_layer)])

        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)

    def forward(self, idx):
        """Map token ids idx of shape (B, T) to logits (B, T, vocab_size)."""
        args = self.args
        B, T = idx.size()
        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

        x = self.emb(idx)
        x_emb = x

        if args.dropout > 0:
            x = self.drop0(x)
        # NOTE(review): `deepspeed` is referenced for checkpointing but its
        # import is not visible in this chunk — confirm the module-level import.
        if args.tiny_att_dim > 0:
            for block in self.blocks:
                if args.grad_cp == 1:
                    x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
                else:
                    x = block(x, x_emb)
        else:
            for block in self.blocks:
                if args.grad_cp == 1:
                    x = deepspeed.checkpointing.checkpoint(block, x)
                else:
                    x = block(x)

        x = self.ln_out(x)

        if args.head_qk > 0:
            # head-QK copy mechanism: masked QK^T map combined with one-hot
            # token identities. NOTE(review): head_q/head_k/copy_mask are not
            # created in this class — presumably attached externally when
            # head_qk > 0; confirm before enabling.
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                c = c @ F.one_hot(idx, num_classes=args.vocab_size)
            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        return x
diff --git a/funasr/models/sense_voice/rwkv_v5.py b/funasr/models/sense_voice/rwkv_v5.py
new file mode 100644
index 0000000..f19ca79
--- /dev/null
+++ b/funasr/models/sense_voice/rwkv_v5.py
@@ -0,0 +1,597 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+import os, math, gc, importlib
+import torch
+
+# torch._C._jit_set_profiling_executor(True)
+# torch._C._jit_set_profiling_mode(True)
+import torch.nn as nn
+from torch.nn import functional as F
+
+
def __nop(ob):
    """Identity passthrough used when the TorchScript JIT is disabled."""
    return ob


# Default to eager nn.Module / plain functions; switch to TorchScript
# ScriptModule / script_method when RWKV_JIT_ON=1 in the environment.
MyModule = nn.Module
MyFunction = __nop
if os.environ.get("RWKV_JIT_ON") == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method
+
+########################################################################################################
+# CUDA Kernel
+########################################################################################################
+
# Lazily-compiled CUDA extension handle; populated by load_rwkv_kernel().
wkv5_cuda = None


def load_rwkv_kernel(
    HEAD_SIZE: int = 64,
    RWKV_CTXLEN: int = 512,
):
    """Compile and cache the wkv5 CUDA kernel; no-op if already loaded.

    NOTE(review): RWKV_CTXLEN is accepted for interface parity but is not
    used by this loader.
    """
    from torch.utils.cpp_extension import load

    global wkv5_cuda
    if wkv5_cuda is not None:
        return

    cur_dir = os.path.dirname(os.path.abspath(__file__))
    wkv5_cuda = load(
        name="wkv5",
        sources=[f"{cur_dir}/cuda/wkv5_op.cpp", f"{cur_dir}/cuda/wkv5_cuda.cu"],
        verbose=True,
        extra_cuda_cflags=[
            "-res-usage",
            "--use_fast_math",
            "-O3",
            "-Xptxas -O3",
            "--extra-device-vectorization",
            f"-D_N_={HEAD_SIZE}",  # bakes the head size into the kernel
        ],
    )


# Working dtype for the kernel wrappers (torch.float was the alternative).
dtype = torch.bfloat16
+
+
class WKV_5(torch.autograd.Function):
    """Autograd wrapper around the fused wkv5 CUDA kernel (bfloat16 I/O)."""

    @staticmethod
    def forward(ctx, B, T, C, H, r, k, v, w, u):
        with torch.no_grad():
            ctx.B = B
            ctx.T = T
            ctx.C = C
            ctx.H = H
            for t in (r, k, v, w, u):
                assert t.is_contiguous()
            # The kernel consumes exp(-exp(w)); keep both forms for backward.
            ew = (-torch.exp(w.float())).contiguous()
            eew = torch.exp(ew).contiguous()
            ctx.save_for_backward(r, k, v, eew, ew, u)
            y = torch.empty(
                (B, T, C),
                device=r.device,
                dtype=torch.bfloat16,
                memory_format=torch.contiguous_format,
            )
            wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
            return y

    @staticmethod
    def backward(ctx, gy):
        with torch.no_grad():
            assert gy.dtype == torch.bfloat16
            assert gy.is_contiguous()
            B = ctx.B
            T = ctx.T
            C = ctx.C
            H = ctx.H
            r, k, v, eew, ew, u = ctx.saved_tensors

            def grad_buf(shape):
                # Uninitialized bf16 buffer the kernel fills in place.
                return torch.empty(
                    shape,
                    device=gy.device,
                    requires_grad=False,
                    dtype=torch.bfloat16,
                    memory_format=torch.contiguous_format,
                )

            gr = grad_buf((B, T, C))
            gk = grad_buf((B, T, C))
            gv = grad_buf((B, T, C))
            gw = grad_buf((B, C))
            gu = grad_buf((B, C))
            wkv5_cuda.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu)
            # w/u grads are shared across the batch: reduce then fold per head.
            gw = gw.sum(0).view(H, C // H)
            gu = gu.sum(0).view(H, C // H)
            # No grads for the plain-int B/T/C/H arguments.
            return (None, None, None, None, gr, gk, gv, gw, gu)
+
+
def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
    """Run one wkv5 attention pass through the custom autograd Function."""
    return WKV_5.apply(B, T, C, H, r, k, v, w, u)
+
+
class RWKV_Tmix_x052(MyModule):
    """RWKV v5.2 time-mixing block: token-shift mixing, multi-head WKV
    recurrence via the fused CUDA kernel, SiLU gating, and a per-head
    GroupNorm before the output projection."""

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        # Builds/caches the wkv5 CUDA extension on first construction.
        load_rwkv_kernel(args.head_size_a, args.ctx_len)
        self.layer_id = layer_id

        self.head_size = args.head_size_a
        self.n_head = args.dim_att // self.head_size
        assert args.dim_att % self.n_head == 0
        self.head_size_divisor = args.head_size_divisor

        with torch.no_grad():
            depth_hi = layer_id / (args.n_layer - 1)  # 0 -> 1 with depth
            depth_lo = 1.0 - (layer_id / args.n_layer)  # 1 -> ~0 with depth
            curve = torch.ones(1, 1, args.n_embd)
            for ch in range(args.n_embd):
                curve[0, 0, ch] = ch / args.n_embd

            # depth-dependent token-shift mixing coefficients
            self.time_mix_k = nn.Parameter(torch.pow(curve, depth_lo))
            self.time_mix_v = nn.Parameter(torch.pow(curve, depth_lo) + 0.3 * depth_hi)
            self.time_mix_r = nn.Parameter(torch.pow(curve, 0.5 * depth_lo))
            self.time_mix_g = nn.Parameter(torch.pow(curve, 0.5 * depth_lo))

            # per-channel decay exponents ramping from -6 to -1 across channels
            decay_speed = torch.ones(args.dim_att)
            for ch in range(args.dim_att):
                decay_speed[ch] = -6 + 5 * (ch / (args.dim_att - 1)) ** (0.7 + 1.3 * depth_hi)
            self.time_decay = nn.Parameter(decay_speed.reshape(self.n_head, self.head_size))

            # per-head 'u' bonus term passed to the kernel (see forward: u=...)
            u_vals = torch.zeros(args.dim_att)
            for ch in range(args.dim_att):
                zigzag = ((ch + 1) % 3 - 1) * 0.1
                u_vals[ch] = depth_hi * (1 - (ch / (args.dim_att - 1))) + zigzag
            self.time_faaaa = nn.Parameter(u_vals.reshape(self.n_head, self.head_size))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)

        self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
        self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.ln_x = nn.GroupNorm(self.n_head, args.dim_att)

    @MyFunction
    def jit_func(self, x):
        """Token-shift mix and project x (B, T, C) to r, k, v and SiLU gate g."""
        B, T, C = x.size()

        shifted = self.time_shift(x)
        xk = x * self.time_mix_k + shifted * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + shifted * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + shifted * (1 - self.time_mix_r)
        xg = x * self.time_mix_g + shifted * (1 - self.time_mix_g)

        r = self.receptance(xr)
        k = self.key(xk)
        v = self.value(xv)
        g = F.silu(self.gate(xg))

        return r, k, v, g

    @MyFunction
    def jit_func_2(self, x, g):
        """Per-head GroupNorm (scaled), gating, and output projection."""
        B, T, C = x.size()
        x = self.ln_x(x.view(B * T, C) / self.head_size_divisor).view(B, T, C)
        return self.output(x * g)

    def forward(self, x, **kwargs):
        B, T, C = x.size()
        r, k, v, g = self.jit_func(x)
        x = RUN_CUDA_RWKV5(B, T, C, self.n_head, r, k, v, w=self.time_decay, u=self.time_faaaa)
        return self.jit_func_2(x, g)
+
+
+#
+# class RWKV_Tmix_x060(MyModule):
+#     def __init__(self, args, layer_id):
+#         super().__init__()
+#         self.args = args
+#
+#         load_rwkv_kernel(args.head_size_a, args.ctx_len)
+#
+#         self.layer_id = layer_id
+#
+#         self.head_size = args.head_size_a
+#         self.n_head = args.dim_att // self.head_size
+#         assert args.dim_att % self.n_head == 0
+#
+#         with torch.no_grad():
+#             ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
+#             ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
+#             ddd = torch.ones(1, 1, args.n_embd)
+#             for i in range(args.n_embd):
+#                 ddd[0, 0, i] = i / args.n_embd
+#
+#             # fancy time_mix
+#             self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+#             self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+#             self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+#             self.time_maa_v = nn.Parameter(
+#                 1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+#             )
+#             self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
+#             self.time_maa_g = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
+#
+#             D_MIX_LORA = 32  # generate TIME_MIX for w,k,v,r,g
+#             self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5))
+#             self.time_maa_w2 = nn.Parameter(
+#                 torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01)
+#             )
+#
+#             # fancy time_decay
+#             decay_speed = torch.ones(args.dim_att)
+#             for n in range(args.dim_att):
+#                 decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+#             self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))
+#
+#             D_DECAY_LORA = 64
+#             self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
+#             self.time_decay_w2 = nn.Parameter(
+#                 torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01)
+#             )
+#
+#             tmp = torch.zeros(args.dim_att)
+#             for n in range(args.dim_att):
+#                 zigzag = ((n + 1) % 3 - 1) * 0.1
+#                 tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
+#
+#             self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
+#
+#         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+#         self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
+#         self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
+#
+#         self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
+#         self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
+#         self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
+#         self.ln_x = nn.GroupNorm(
+#             self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
+#         )
+#
+#     @MyFunction
+#     def jit_func(self, x):
+#         B, T, C = x.size()
+#
+#         xx = self.time_shift(x) - x
+#
+#         xxx = x + xx * self.time_maa_x
+#         xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
+#         xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
+#         mw, mk, mv, mr, mg = xxx.unbind(dim=0)
+#
+#         xw = x + xx * (self.time_maa_w + mw)
+#         xk = x + xx * (self.time_maa_k + mk)
+#         xv = x + xx * (self.time_maa_v + mv)
+#         xr = x + xx * (self.time_maa_r + mr)
+#         xg = x + xx * (self.time_maa_g + mg)
+#
+#         r = self.receptance(xr)
+#         k = self.key(xk)
+#         v = self.value(xv)
+#         g = F.silu(self.gate(xg))
+#
+#         ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
+#         w = self.time_decay + ww
+#
+#         return r, k, v, g, w
+#
+#     @MyFunction
+#     def jit_func_2(self, x, g):
+#         B, T, C = x.size()
+#         x = x.view(B * T, C)
+#
+#         x = self.ln_x(x).view(B, T, C)
+#         x = self.output(x * g)
+#         return x
+#
+#     def forward(self, x):
+#         B, T, C = x.size()
+#         H = self.n_head
+#
+#         r, k, v, g, w = self.jit_func(x)
+#         x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)
+#
+#         return self.jit_func_2(x, g)
+
+
class RWKV_CMix_x052(MyModule):
    """RWKV v5.2 channel-mixing (FFN) block: token-shift mixing, squared-ReLU
    key expansion, and a sigmoid receptance gate over the value projection."""

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        # Shifts the sequence one step back in time (zero-pads the first step).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # depth-dependent init of the time-mix curves
            decay = 1.0 - (layer_id / args.n_layer)  # 1 -> ~0 with depth
            curve = torch.ones(1, 1, args.n_embd)
            for ch in range(args.n_embd):
                curve[0, 0, ch] = ch / args.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(curve, decay))
            self.time_mix_r = nn.Parameter(torch.pow(curve, decay))

        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        """Map x of shape (B, T, C) to an output of the same shape."""
        shifted = self.time_shift(x)
        mixed_k = x * self.time_mix_k + shifted * (1 - self.time_mix_k)
        mixed_r = x * self.time_mix_r + shifted * (1 - self.time_mix_r)

        hidden = torch.relu(self.key(mixed_k)) ** 2
        return torch.sigmoid(self.receptance(mixed_r)) * self.value(hidden)
+
+
+# class RWKV_CMix_x060(MyModule):
+#     def __init__(self, args, layer_id):
+#         super().__init__()
+#         self.args = args
+#         self.layer_id = layer_id
+#         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+#
+#         with torch.no_grad():  # fancy init of time_mix
+#             ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
+#             ddd = torch.ones(1, 1, args.n_embd)
+#             for i in range(args.n_embd):
+#                 ddd[0, 0, i] = i / args.n_embd
+#             self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+#             self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+#
+#         self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
+#         self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
+#         self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
+#
+#     @MyFunction
+#     def forward(self, x):
+#         xx = self.time_shift(x) - x
+#         xk = x + xx * self.time_maa_k
+#         xr = x + xx * self.time_maa_r
+#
+#         k = self.key(xk)
+#         k = torch.relu(k) ** 2
+#         kv = self.value(k)
+#         return torch.sigmoid(self.receptance(xr)) * kv
+
+
+# class Block(nn.Module):
+#     def __init__(self, args, layer_id):
+#         super().__init__()
+#         self.args = args
+#         self.layer_id = layer_id
+#
+#         self.ln1 = nn.LayerNorm(args.n_embd)
+#         self.ln2 = nn.LayerNorm(args.n_embd)
+#
+#         if self.layer_id == 0:
+#             self.ln0 = nn.LayerNorm(args.n_embd)
+#
+#         self.att = RWKV_Tmix_x060(args, layer_id)
+#
+#         self.ffn = RWKV_CMix_x060(args, layer_id)
+#
+#         if args.dropout > 0:
+#             self.drop0 = nn.Dropout(p=args.dropout)
+#             self.drop1 = nn.Dropout(p=args.dropout)
+#
+#     def forward(self, x, x_emb=None):
+#         args = self.args
+#         B, T, C = x.size()
+#         if self.layer_id == 0:
+#             x = self.ln0(x)
+#
+#         if self.args.dropout == 0:
+#             if self.layer_id == 0 and args.pre_ffn > 0:
+#                 x = x + self.ffnPre(self.ln1(x))
+#             else:
+#                 x = x + self.att(self.ln1(x))
+#             x = x + self.ffn(self.ln2(x))
+#         else:
+#             if self.layer_id == 0 and args.pre_ffn > 0:
+#                 x = self.drop0(x + self.ffnPre(self.ln1(x)))
+#             else:
+#                 x = self.drop0(x + self.att(self.ln1(x)))
+#             x = self.drop1(x + self.ffn(self.ln2(x)))
+#
+#         return x
+
+
class RWKVLayer(nn.Module):
    """One RWKV (v5) layer: pre-norm x052 time-mixing plus an optional
    pre-norm x052 channel-mix FFN, with optional dropout and an optional
    input LayerNorm (ln0) on layer 0.

    `args` is a dict-like config (attribute access plus .get) with at least:
    n_embd, n_layer, dim_ffn (may be None), dropout, and the fields consumed
    by RWKV_Tmix_x052 / RWKV_CMix_x052; optional flags: ln0, ln1,
    use_rwkv_ffn, init_rwkv, datatype.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        if args.dim_ffn is None:
            # default FFN width: 3.5x embedding, rounded down to a multiple of 32
            args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)
        # ln0 normalizes the raw layer input once, on the first layer only.
        self.ln0 = None
        if self.layer_id == 0 and args.get("ln0", True):
            self.ln0 = nn.LayerNorm(args.n_embd)

        self.ln1 = None
        if args.get("ln1", True):
            self.ln1 = nn.LayerNorm(args.n_embd)

        self.att = RWKV_Tmix_x052(args, layer_id)
        # FFN (and its pre-norm) are optional.
        self.ln2 = None
        self.ffn = None
        if args.get("use_rwkv_ffn", True):
            self.ln2 = nn.LayerNorm(args.n_embd)
            self.ffn = RWKV_CMix_x052(args, layer_id)

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
            self.drop1 = nn.Dropout(p=args.dropout)

        # RWKV-style init: orthogonal projections, zeroed output paths,
        # depth-scaled LayerNorm gains.
        if args.get("init_rwkv", True):
            print("init_rwkv")
            nn.init.orthogonal_(self.att.receptance.weight, gain=1)
            nn.init.orthogonal_(self.att.key.weight, gain=0.1)
            nn.init.orthogonal_(self.att.value.weight, gain=1)
            nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
            nn.init.zeros_(self.att.output.weight)

            # BUGFIX: when use_rwkv_ffn is False, self.ffn and self.ln2 are
            # None and the previously unconditional accesses raised
            # AttributeError. Guard them (mirrors the fix applied to
            # rwkv_v6.RWKVLayer in this same change set).
            if self.ffn is not None:
                nn.init.orthogonal_(self.ffn.key.weight, gain=1)
                nn.init.zeros_(self.ffn.value.weight)
                nn.init.zeros_(self.ffn.receptance.weight)
            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
            if self.ln2 is not None:
                nn.init.constant_(self.ln2.weight, scale)
            if self.ln0 is not None:
                nn.init.constant_(self.ln0.weight, scale)
            if self.ln1 is not None:
                nn.init.constant_(self.ln1.weight, scale)

    def forward(self, x, x_emb=None, mask=None, **kwargs):
        """Residual attention (+ optional FFN) pass over x of shape (B, T, C).

        x_emb and mask are accepted for interface compatibility but unused.
        Runs in bfloat16 when args.datatype == "bf16" and restores float32
        on the way out.
        """
        args = self.args
        if args.get("datatype", "bf16") == "bf16":
            x = x.bfloat16()
        B, T, C = x.size()
        if self.layer_id == 0 and self.ln0 is not None:
            x = self.ln0(x)

        if self.args.dropout == 0:
            if self.ln1 is None:
                x = x + self.att(x)
            else:
                x = x + self.att(self.ln1(x))
            if self.ffn is not None:
                x = x + self.ffn(self.ln2(x))
        else:
            if self.ln1 is None:
                x = self.drop0(x + self.att(x))
            else:
                x = self.drop0(x + self.att(self.ln1(x)))
            if self.ffn is not None:
                x = self.drop1(x + self.ffn(self.ln2(x)))

        if args.get("datatype", "bf16") == "bf16":
            x = x.to(torch.float32)
        return x
+
+
+# class RWKV(nn.Module):
+#     def __init__(self, args):
+#         super().__init__()
+#         self.args = args
+#         if not hasattr(args, "dim_att"):
+#             args.dim_att = args.n_embd
+#         if not hasattr(args, "dim_ffn"):
+#             if "-f4" in os.environ["RWKV_MY_TESTING"]:
+#                 args.dim_ffn = int((args.n_embd * 4) // 32 * 32)
+#             else:
+#                 args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)  # default = 3.5x emb size
+#         if not hasattr(args, "tiny_att_layer"):
+#             args.tiny_att_layer = -1
+#         if not hasattr(args, "tiny_att_dim"):
+#             args.tiny_att_dim = -1
+#         assert args.n_embd % 32 == 0
+#         assert args.dim_att % 32 == 0
+#         assert args.dim_ffn % 32 == 0
+#
+#         self.emb = nn.Embedding(args.vocab_size, args.n_embd)
+#
+#         self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
+#
+#         self.ln_out = nn.LayerNorm(args.n_embd)
+#         self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
+#
+#         if args.dropout > 0:
+#             self.drop0 = nn.Dropout(p=args.dropout)
+#
+#     def forward(self, idx):
+#         args = self.args
+#         B, T = idx.size()
+#         assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
+#
+#         x = self.emb(idx)
+#         x_emb = x
+#
+#         if args.dropout > 0:
+#             x = self.drop0(x)
+#         if args.tiny_att_dim > 0:
+#             for block in self.blocks:
+#                 if args.grad_cp == 1:
+#                     x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
+#                 else:
+#                     x = block(x, x_emb)
+#         else:
+#             for block in self.blocks:
+#                 if args.grad_cp == 1:
+#                     x = deepspeed.checkpointing.checkpoint(block, x)
+#                 else:
+#                     x = block(x)
+#
+#         x = self.ln_out(x)
+#
+#         if args.head_qk > 0:
+#             q = self.head_q(x)[:, :T, :]
+#             k = self.head_k(x)[:, :T, :]
+#             c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
+#             c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
+#
+#             if "32" in os.environ["RWKV_FLOAT_MODE"]:
+#                 c = c @ F.one_hot(idx, num_classes=args.vocab_size)
+#             elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
+#                 c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
+#             elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+#                 c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
+#
+#             x = self.head(x) + c
+#         else:
+#             x = self.head(x)
+#
+#         return x
diff --git a/funasr/models/sense_voice/rwkv_v6.py b/funasr/models/sense_voice/rwkv_v6.py
index 36269a1..b91d47a 100644
--- a/funasr/models/sense_voice/rwkv_v6.py
+++ b/funasr/models/sense_voice/rwkv_v6.py
@@ -244,7 +244,7 @@
         x = self.output(x * g)
         return x
 
-    def forward(self, x):
+    def forward(self, x, **kwargs):
         B, T, C = x.size()
         H = self.n_head
 
@@ -341,11 +341,14 @@
         self.ln1 = None
         if args.get("ln1", True):
             self.ln1 = nn.LayerNorm(args.n_embd)
-        self.ln2 = nn.LayerNorm(args.n_embd)
 
         self.att = RWKV_Tmix_x060(args, layer_id)
 
-        self.ffn = RWKV_CMix_x060(args, layer_id)
+        self.ln2 = None
+        self.ffn = None
+        if args.get("use_rwkv_ffn", True):
+            self.ln2 = nn.LayerNorm(args.n_embd)
+            self.ffn = RWKV_CMix_x060(args, layer_id)
 
         if args.dropout > 0:
             self.drop0 = nn.Dropout(p=args.dropout)
@@ -364,11 +367,13 @@
             nn.init.zeros_(self.ffn.value.weight)
             nn.init.zeros_(self.ffn.receptance.weight)
             scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
-            nn.init.constant_(self.ln2.weight, scale)
+
             if self.ln0 is not None:
                 nn.init.constant_(self.ln0.weight, scale)
             if self.ln1 is not None:
                 nn.init.constant_(self.ln1.weight, scale)
+            if self.ln2 is not None:
+                nn.init.constant_(self.ln2.weight, scale)
 
     def forward(self, x, x_emb=None, mask=None, **kwargs):
 
@@ -384,13 +389,15 @@
                 x = x + self.att(x)
             else:
                 x = x + self.att(self.ln1(x))
-            x = x + self.ffn(self.ln2(x))
+            if self.ffn is not None:
+                x = x + self.ffn(self.ln2(x))
         else:
             if self.ln1 is None:
                 x = self.drop0(x + self.att(x))
             else:
                 x = self.drop0(x + self.att(self.ln1(x)))
-            x = self.drop1(x + self.ffn(self.ln2(x)))
+            if self.ffn is not None:
+                x = self.drop1(x + self.ffn(self.ln2(x)))
 
         if args.get("datatype", "bf16") == "bf16":
             x = x.to(torch.float32)
diff --git a/funasr/train_utils/model_summary.py b/funasr/train_utils/model_summary.py
index 2aef88a..4e92a33 100644
--- a/funasr/train_utils/model_summary.py
+++ b/funasr/train_utils/model_summary.py
@@ -47,6 +47,8 @@
 def model_summary(model: torch.nn.Module) -> str:
     message = "Model structure:\n"
     message += str(model)
+    # for p in model.parameters():
+    #     print(f"{p.numel()}")
     tot_params = sum(p.numel() for p in model.parameters())
     num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
     percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 8f20ba4..66f8778 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -15,6 +15,11 @@
 from funasr.train_utils.average_nbest_models import average_checkpoints
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
+try:
+    import wandb
+except ImportError:
+    wandb = None
+
 
 @contextmanager
 def maybe_autocast(enabled):
@@ -107,9 +112,22 @@
         self.best_step_or_epoch = ""
         self.val_acc_step_or_eoch = {}
         self.val_loss_step_or_eoch = {}
-        
-        self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
 
+        self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
+        self.start_data_split_i = 0
+        self.start_step = 0
+        self.use_wandb = kwargs.get("use_wandb", False)
+        if self.use_wandb and wandb is not None:
+            wandb.login(key=kwargs.get("wandb_token"))
+            wandb.init(
+                config=kwargs,
+                project=kwargs.get("wandb_project", "my_project"),
+                entity=kwargs.get("wandb_team", "my_team"),
+                name=kwargs.get("wandb_exp_name", "my_exp"),
+                dir=output_dir,
+                job_type="training",
+                reinit=True,
+            )
 
     def save_checkpoint(
         self,
@@ -142,6 +160,7 @@
                 "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
                 "best_step_or_epoch": self.best_step_or_epoch,
                 "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
+                "step": step,
             }
             if hasattr(model, "module"):
                 state["state_dict"] = model.module.state_dict()
@@ -268,6 +287,13 @@
                 self.best_step_or_epoch = (
                     checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
                 )
+                self.start_data_split_i = (
+                    checkpoint["start_data_split_i"] if "start_data_split_i" in checkpoint else 0
+                )
+                self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
+                self.start_step = checkpoint["step"] if "step" in checkpoint else 0
+                self.start_step = 0 if self.start_step is None else self.start_step
+
                 model.to(self.device)
                 print(f"Checkpoint loaded successfully from '{ckpt}'")
             else:
@@ -327,7 +353,7 @@
                 time2 = time.perf_counter()
                 with maybe_autocast(self.use_fp16):
                     retval = model(**batch)
-                    
+
                     if (
                         self.reset_gpu_cache
                         and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
@@ -598,7 +624,7 @@
             acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
             description = (
                 f"{tag}, "
-                f"rank: {self.local_rank}, "
+                f"rank: {self.rank}, "
                 f"epoch: {epoch}/{self.max_epoch}, "
                 f"data_slice: {data_split_i}/{data_split_num}, "
                 f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
@@ -613,18 +639,29 @@
             )
             logging.info(description)
 
+            description_dict = {
+                f"rank{self.rank}_loss/{tag}": loss,
+                f"rank{self.rank}_lr/{tag}": lr,
+            }
+
             if writer is not None:
-                writer.add_scalar(f"rank{self.local_rank}_loss/{tag}", loss, self.batch_total)
-                writer.add_scalar(f"rank{self.local_rank}_lr/{tag}", lr, self.batch_total)
-                writer.add_scalar(f"rank{self.local_rank}_lr/{tag}", lr, self.batch_total)
+                writer.add_scalar(f"rank{self.rank}_loss/{tag}", loss, self.batch_total)
+                writer.add_scalar(f"rank{self.rank}_lr/{tag}", lr, self.batch_total)
                 for key, var in stats.items():
                     writer.add_scalar(
-                        f"stats_rank{self.local_rank}_{key}/{tag}", var.item(), self.batch_total
+                        f"stats_rank{self.rank}_{key}/{tag}", var.item(), self.batch_total
                     )
+                    description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = var.item()
                 for key, var in speed_stats.items():
                     writer.add_scalar(
-                        f"stats_rank{self.local_rank}_{key}/{tag}", eval(var), self.batch_total
+                        f"stats_rank{self.rank}_{key}/{tag}", eval(var), self.batch_total
                     )
+                    description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = eval(var)
+            if self.use_wandb and wandb is not None:
+                wandb.log(
+                    description_dict,
+                    step=self.batch_total,
+                )
 
     def close(self, writer=None):
 
diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log
new file mode 120000
index 0000000..e3e213d
--- /dev/null
+++ b/wandb/debug-internal.log
@@ -0,0 +1 @@
+run-20240425_211446-lkqptn01/logs/debug-internal.log
\ No newline at end of file
diff --git a/wandb/debug.log b/wandb/debug.log
new file mode 120000
index 0000000..826b402
--- /dev/null
+++ b/wandb/debug.log
@@ -0,0 +1 @@
+run-20240425_211446-lkqptn01/logs/debug.log
\ No newline at end of file
diff --git a/wandb/latest-run b/wandb/latest-run
new file mode 120000
index 0000000..7e88449
--- /dev/null
+++ b/wandb/latest-run
@@ -0,0 +1 @@
+run-20240425_211446-lkqptn01
\ No newline at end of file
diff --git a/wandb/run-20240425_211446-lkqptn01/files/config.yaml b/wandb/run-20240425_211446-lkqptn01/files/config.yaml
new file mode 100644
index 0000000..b131e2e
--- /dev/null
+++ b/wandb/run-20240425_211446-lkqptn01/files/config.yaml
@@ -0,0 +1,26 @@
+wandb_version: 1
+
+a:
+  desc: null
+  value: 1
+_wandb:
+  desc: null
+  value:
+    python_version: 3.8.15
+    cli_version: 0.16.6
+    is_jupyter_run: false
+    is_kaggle_kernel: false
+    start_time: 1714050886.0
+    t:
+      1:
+      - 55
+      - 105
+      3:
+      - 13
+      - 16
+      - 23
+      4: 3.8.15
+      5: 0.16.6
+      8:
+      - 5
+      13: darwin-arm64
diff --git a/wandb/run-20240425_211446-lkqptn01/logs/debug-internal.log b/wandb/run-20240425_211446-lkqptn01/logs/debug-internal.log
new file mode 100644
index 0000000..4d63ace
--- /dev/null
+++ b/wandb/run-20240425_211446-lkqptn01/logs/debug-internal.log
@@ -0,0 +1,146 @@
+2024-04-25 21:14:46,163 INFO    StreamThr :90024 [internal.py:wandb_internal():86] W&B internal server running at pid: 90024, started at: 2024-04-25 21:14:46.156788
+2024-04-25 21:14:46,164 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status
+2024-04-25 21:14:46,172 INFO    WriterThread:90024 [datastore.py:open_for_write():87] open: ./wandb/run-20240425_211446-lkqptn01/run-lkqptn01.wandb
+2024-04-25 21:14:46,173 DEBUG   SenderThread:90024 [sender.py:send():379] send: header
+2024-04-25 21:14:46,284 DEBUG   SenderThread:90024 [sender.py:send():379] send: run
+2024-04-25 21:14:46,828 ERROR   SenderThread:90024 [internal_api.py:execute():373] 403 response executing GraphQL.
+2024-04-25 21:14:46,828 ERROR   SenderThread:90024 [internal_api.py:execute():374] {"errors":[{"message":"permission denied","path":["upsertBucket"],"extensions":{"code":"PERMISSION_ERROR"}}],"data":{"upsertBucket":null}}
+2024-04-25 21:14:46,828 ERROR   SenderThread:90024 [sender.py:send_run():971] It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 403: Forbidden)
+Traceback (most recent call last):
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/lib/retry.py", line 131, in __call__
+    result = self._call_fn(*args, **kwargs)
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/internal/internal_api.py", line 369, in execute
+    return self.client.execute(*args, **kwargs)  # type: ignore
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py", line 52, in execute
+    result = self._get_result(document, *args, **kwargs)
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py", line 60, in _get_result
+    return self.transport.execute(document, *args, **kwargs)
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/lib/gql_request.py", line 59, in execute
+    request.raise_for_status()
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/requests/models.py", line 1021, in raise_for_status
+    raise HTTPError(http_error_msg, response=self)
+requests.exceptions.HTTPError: 403 Client Error: Forbidden for url: https://api.wandb.ai/graphql
+
+During handling of the above exception, another exception occurred:
+
+Traceback (most recent call last):
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/internal/sender.py", line 969, in send_run
+    server_run = self._init_run(run, config_value_dict)
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/internal/sender.py", line 1014, in _init_run
+    server_run, inserted, server_messages = self._api.upsert_run(
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/apis/normalize.py", line 73, in wrapper
+    raise err
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/apis/normalize.py", line 41, in wrapper
+    return func(*args, **kwargs)
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/internal/internal_api.py", line 2217, in upsert_run
+    response = self.gql(
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/internal/internal_api.py", line 341, in gql
+    ret = self._retry_gql(
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/lib/retry.py", line 147, in __call__
+    retry_timedelta_triggered = check_retry_fn(e)
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/util.py", line 968, in check_retry_fn
+    return fallback_retry_fn(e)
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/util.py", line 910, in no_retry_auth
+    raise CommError(
+wandb.errors.CommError: It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 403: Forbidden)
+2024-04-25 21:14:51,848 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:14:56,865 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:01,882 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:06,910 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:11,930 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:16,954 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:21,968 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:26,992 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:32,007 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:37,026 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:42,047 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:47,060 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:52,088 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:15:57,107 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:02,125 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:07,154 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:12,184 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:17,193 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:22,218 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:27,236 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:32,248 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:37,267 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:42,285 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:47,302 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:52,323 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:16:57,345 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:02,359 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:07,370 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:12,391 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:17,410 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:22,427 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:27,449 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:32,472 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:37,492 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:42,514 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:47,529 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:52,547 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:17:57,565 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:02,584 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:07,605 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:12,630 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:17,649 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:22,670 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:27,689 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:32,706 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:37,721 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:42,738 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:47,760 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:52,776 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:18:57,797 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:02,816 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:07,831 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:12,849 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:17,886 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:22,903 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:27,914 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:32,935 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:37,956 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:42,977 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:47,997 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:53,013 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:19:58,037 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:03,054 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:08,074 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:13,095 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:18,115 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:23,136 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:28,158 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:33,175 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:38,195 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:43,216 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:48,237 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:53,254 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:20:58,276 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:03,293 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:08,306 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:13,328 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:18,349 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:23,372 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:28,393 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:33,418 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:38,428 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:43,452 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:48,475 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:53,493 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:21:58,512 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:03,530 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:08,546 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:13,563 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:18,586 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:23,604 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:28,621 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:33,641 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:38,688 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:43,704 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:48,722 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:53,742 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:22:58,756 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:23:03,775 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:23:08,798 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
+2024-04-25 21:23:13,815 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
diff --git a/wandb/run-20240425_211446-lkqptn01/logs/debug.log b/wandb/run-20240425_211446-lkqptn01/logs/debug.log
new file mode 100644
index 0000000..0f4231c
--- /dev/null
+++ b/wandb/run-20240425_211446-lkqptn01/logs/debug.log
@@ -0,0 +1,29 @@
+2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Current SDK version is 0.16.6
+2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Configure stats pid to 89997
+2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Loading settings from /Users/zhifu/.config/wandb/settings
+2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Loading settings from /Users/zhifu/funasr1.0/wandb/settings
+2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Loading settings from environment variables: {}
+2024-04-25 21:14:46,147 WARNING MainThread:89997 [wandb_setup.py:_flush():76] Could not find program at <input>
+2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Inferring run settings from compute environment: {'program_relpath': None, 'program': '<input>'}
+2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Applying login settings: {'api_key': '***REDACTED***'}
+2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Applying login settings: {'api_key': '***REDACTED***'}
+2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Applying login settings: {}
+2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_init.py:_log_setup():521] Logging user logs to ./wandb/run-20240425_211446-lkqptn01/logs/debug.log
+2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_init.py:_log_setup():522] Logging internal logs to ./wandb/run-20240425_211446-lkqptn01/logs/debug-internal.log
+2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_init.py:init():561] calling init triggers
+2024-04-25 21:14:46,148 INFO    MainThread:89997 [wandb_init.py:init():568] wandb.init called with sweep_config: {}
+config: {'a': 1}
+2024-04-25 21:14:46,148 INFO    MainThread:89997 [wandb_init.py:init():611] starting backend
+2024-04-25 21:14:46,148 INFO    MainThread:89997 [wandb_init.py:init():615] setting up manager
+2024-04-25 21:14:46,153 INFO    MainThread:89997 [backend.py:_multiprocessing_setup():105] multiprocessing start_methods=spawn,fork,forkserver, using: spawn
+2024-04-25 21:14:46,156 INFO    MainThread:89997 [wandb_init.py:init():623] backend started and connected
+2024-04-25 21:14:46,171 INFO    MainThread:89997 [wandb_init.py:init():715] updated telemetry
+2024-04-25 21:14:46,282 INFO    MainThread:89997 [wandb_init.py:init():748] communicating run to backend with 90.0 second timeout
+2024-04-25 21:14:46,836 ERROR   MainThread:89997 [wandb_init.py:init():774] encountered error: It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 403: Forbidden)
+2024-04-25 21:14:48,147 ERROR   MainThread:89997 [wandb_init.py:init():1199] It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 403: Forbidden)
+Traceback (most recent call last):
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/wandb_init.py", line 1181, in init
+    run = wi.init()
+  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/wandb_init.py", line 780, in init
+    raise error
+wandb.errors.CommError: It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 403: Forbidden)
diff --git a/wandb/run-20240425_211446-lkqptn01/run-lkqptn01.wandb b/wandb/run-20240425_211446-lkqptn01/run-lkqptn01.wandb
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/wandb/run-20240425_211446-lkqptn01/run-lkqptn01.wandb

--
Gitblit v1.9.1