zhifu gao
2024-04-25 fc68b5ffe453235294a561737d8e84bb6c1689a4
Dev gzf exp (#1661)

* rwkv 5

* rwkv v4

* rwkv v4

* rwkv

* rwkv

* update

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step

* resume from step
10个文件已修改
11个文件已添加
1560 ■■■■■ 已修改文件
examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py 22 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/espnet_samplers.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/index_ds.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/dataloader_entry.py 20 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/conformer_rwkv/decoder.py 46 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/cuda/wkv_cuda.cu 125 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/cuda/wkv_op.cpp 21 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/decoder.py 37 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/rwkv_v4.py 412 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/rwkv_v5.py 597 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/rwkv_v6.py 13 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/model_summary.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer.py 51 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
wandb/debug-internal.log 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
wandb/debug.log 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
wandb/latest-run 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
wandb/run-20240425_211446-lkqptn01/files/config.yaml 26 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
wandb/run-20240425_211446-lkqptn01/logs/debug-internal.log 146 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
wandb/run-20240425_211446-lkqptn01/logs/debug.log 29 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
wandb/run-20240425_211446-lkqptn01/run-lkqptn01.wandb 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -9,6 +9,9 @@
model = AutoModel(model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch")
mm = model.model
for p in mm.parameters():
    print(f"{p.numel()}")
res = model.generate(input=wav_file)
print(res)
# [[beg1, end1], [beg2, end2], .., [begN, endN]]
funasr/bin/train.py
@@ -99,7 +99,7 @@
    if freeze_param is not None:
        if "," in freeze_param:
            freeze_param = eval(freeze_param)
        if isinstance(freeze_param, Sequence):
        if not isinstance(freeze_param, Sequence):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
        for t in freeze_param:
@@ -107,8 +107,9 @@
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False
    if local_rank == 0:
        logging.info(f"{model_summary(model)}")
    logging.info(f"model info: {model_summary(model)}")
    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(
@@ -145,8 +146,6 @@
    else:
        model = model.to(device=kwargs.get("device", "cuda"))
    if local_rank == 0:
        logging.info(f"{model}")
    kwargs["device"] = next(model.parameters()).device
    # optim
@@ -182,7 +181,12 @@
    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
    scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
    trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
    trainer.resume_checkpoint(
        model=model,
        optim=optim,
        scheduler=scheduler,
        scaler=scaler,
    )
    tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
    os.makedirs(tensorboard_dir, exist_ok=True)
@@ -197,8 +201,11 @@
    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
        time1 = time.perf_counter()
        for data_split_i in range(dataloader.data_split_num):
            dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
        for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
            dataloader_tr, dataloader_val = dataloader.build_iter(
                epoch, data_split_i=data_split_i, start_step=trainer.start_step
            )
            trainer.start_step = 0
            trainer.train_epoch(
                model=model,
                optim=optim,
@@ -213,7 +220,6 @@
            )
            
            torch.cuda.empty_cache()
        trainer.validate_epoch(
            model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
funasr/datasets/audio_datasets/espnet_samplers.py
@@ -41,6 +41,7 @@
        drop_last=False,
        is_training: bool = True,
        sort_size: int = 1024,
        start_step: int = 0,
        **kwargs,
    ):
@@ -70,7 +71,9 @@
        self.max_token_length = kwargs.get("max_token_length", 2048)
        self.min_token_length = kwargs.get("min_token_length", 0)
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
        self.start_step = 0
        if self.start_step > 0:
            logging.info(f"Warning, start_step > 0, dataloader start from step: {self.start_step}")
        # super().__init__(dataset, num_replicas=num_replicas, rank=rank,
        #                  shuffle=shuffle, drop_last=drop_last)
@@ -142,7 +145,7 @@
        # Allocate the batches to the current rank
        start_idx = self.rank * batches_per_rank
        end_idx = start_idx + batches_per_rank
        rank_batches = buffer_batches[start_idx:end_idx]
        rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
        # Return an iterator over the batches for the current rank
        return iter(rank_batches)
funasr/datasets/audio_datasets/index_ds.py
funasr/datasets/dataloader_entry.py
@@ -14,14 +14,14 @@
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=True,
        **kwargs.get("dataset_conf")
        **kwargs.get("dataset_conf"),
    )
    dataset_val = dataset_class(
        kwargs.get("valid_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=False,
        **kwargs.get("dataset_conf")
        **kwargs.get("dataset_conf"),
    )
    # dataloader
@@ -55,14 +55,14 @@
            frontend=frontend,
            tokenizer=tokenizer,
            is_training=True,
            **kwargs.get("dataset_conf")
            **kwargs.get("dataset_conf"),
        )
        dataset_val = dataset_class(
            kwargs.get("valid_data_set_list"),
            frontend=frontend,
            tokenizer=tokenizer,
            is_training=False,
            **kwargs.get("dataset_conf")
            **kwargs.get("dataset_conf"),
        )
        self.dataset_tr = dataset_tr
@@ -76,7 +76,7 @@
        self.tokenizer = tokenizer
        self.kwargs = kwargs
    def build_iter(self, epoch=0, data_split_i=0, **kwargs):
    def build_iter(self, epoch=0, data_split_i=0, start_step=0, **kwargs):
        # reload dataset slice
        if self.data_split_num > 1:
@@ -87,7 +87,7 @@
                tokenizer=self.tokenizer,
                is_training=True,
                **self.kwargs.get("dataset_conf"),
                data_split_i=data_split_i
                data_split_i=data_split_i,
            )
        # dataloader
@@ -95,7 +95,9 @@
        batch_sampler_val = None
        if batch_sampler is not None:
            batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
            batch_sampler = batch_sampler_class(self.dataset_tr, **self.kwargs.get("dataset_conf"))
            batch_sampler = batch_sampler_class(
                self.dataset_tr, start_step=start_step, **self.kwargs.get("dataset_conf")
            )
            batch_sampler_val = batch_sampler_class(
                self.dataset_val, is_training=False, **self.kwargs.get("dataset_conf")
            )
@@ -121,14 +123,14 @@
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=True,
        **kwargs.get("dataset_conf")
        **kwargs.get("dataset_conf"),
    )
    dataset_val = dataset_class(
        kwargs.get("valid_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=False,
        **kwargs.get("dataset_conf")
        **kwargs.get("dataset_conf"),
    )
    return dataset_tr, dataset_val
funasr/models/conformer_rwkv/decoder.py
@@ -29,6 +29,11 @@
from funasr.register import tables
class LayerNorm(nn.LayerNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)
class DecoderLayer(nn.Module):
    """Single decoder layer module.
@@ -54,7 +59,7 @@
    def __init__(
        self,
        size,
        self_attn,
        # self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
@@ -62,11 +67,12 @@
        concat_after=False,
        layer_id=None,
        args={},
        **kwargs,
    ):
        """Construct an DecoderLayer object."""
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn.to(torch.bfloat16)
        # self.self_attn = self_attn.to(torch.bfloat16)
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
@@ -79,6 +85,22 @@
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)
        self.layer_id = layer_id
        if args.get("version", "v4") == "v4":
            from funasr.models.sense_voice.rwkv_v4 import RWKVLayer
            from funasr.models.sense_voice.rwkv_v4 import RWKV_TimeMix as RWKV_Tmix
        elif args.get("version", "v5") == "v5":
            from funasr.models.sense_voice.rwkv_v5 import RWKVLayer
            from funasr.models.sense_voice.rwkv_v5 import RWKV_Tmix_x052 as RWKV_Tmix
        else:
            from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
            from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
        # self.attn = RWKVLayer(args=args, layer_id=layer_id)
        self.self_attn = RWKV_Tmix(args, layer_id=layer_id)
        if args.get("datatype", "bf16") == "bf16":
            self.self_attn.to(torch.bfloat16)
            # self.norm1.to(torch.bfloat16)
        self.args = args
        self.ln0 = None
        if self.layer_id == 0 and not args.get("ln0", True):
            self.ln0 = LayerNorm(args.n_embd)
@@ -93,7 +115,15 @@
            print("init_rwkv")
            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
            nn.init.constant_(self.norm1.weight, scale)
            nn.init.constant_(self.self_attn.ln2.weight, scale)
            # nn.init.constant_(self.self_attn.ln2.weight, scale)
        if args.get("init_rwkv", True):
            print("init_rwkv")
            nn.init.orthogonal_(self.self_attn.receptance.weight, gain=1)
            nn.init.orthogonal_(self.self_attn.key.weight, gain=0.1)
            nn.init.orthogonal_(self.self_attn.value.weight, gain=1)
            nn.init.orthogonal_(self.self_attn.gate.weight, gain=0.1)
            nn.init.zeros_(self.self_attn.output.weight)
    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.
@@ -117,6 +147,8 @@
        if self.layer_id == 0 and self.ln0 is not None:
            tgt = self.ln0(tgt)
        if self.args.get("datatype", "bf16") == "bf16":
            tgt = tgt.bfloat16()
        residual = tgt
        tgt = self.norm1(tgt)
@@ -132,7 +164,8 @@
            x = residual + self.dropout(self.self_attn(tgt, mask=tgt_q_mask))
            x = x[:, -1, :]
        if self.args.get("datatype", "bf16") == "bf16":
            x = x.to(torch.float32)
        # x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        residual = x
@@ -370,17 +403,16 @@
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
        # from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
        rwkv_cfg = kwargs.get("rwkv_cfg", {})
        args = OmegaConf.create(rwkv_cfg)
        # self.attn = RWKVLayer(args=args, layer_id=layer_id)
        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                RWKVLayer(args=args, layer_id=lnum),
                MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
funasr/models/sense_voice/cuda/wkv_cuda.cu
New file
@@ -0,0 +1,125 @@
#include <stdio.h>
#include <assert.h>
#define MIN_VALUE (-1e38)
// WKV forward kernel (RWKV v4): each thread owns one (batch, channel) pair and
// scans sequentially over the T time steps, producing the attention-like
// weighted average y[t] for that channel.
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;  // batch index for this thread
    const int _c = idx % C;  // channel index for this thread
    const int _offset = _b * T * C + _c;  // start of this (b, c) column in the (B, T, C) layout
    F u = _u[_c];  // per-channel bonus applied to the current token
    F w = _w[_c];  // per-channel decay (the host wrapper passes w = -exp(time_decay))
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;
    F p = 0, q = 0, o = MIN_VALUE;
    // p and q are running sums divided by exp(o) (to avoid overflows)
    for (int i = 0; i < T; i++) {
        const int ii = i * C;  // stride between consecutive time steps of this channel
        // Emit y[i]: combine the decayed state (p, q) with token i boosted by u,
        // under a shared exponent no (log-sum-exp style rescaling).
        F no = max(o, u + k[ii]);
        F A = exp(o - no);          // rescale factor for the accumulated state
        F B = exp(u + k[ii] - no);  // weight of the current token
        y[ii] = (A * p + B * v[ii]) / (A * q + B);
        // Advance the state to step i+1: decay old state by w, fold in token i
        // (without the u bonus), and track the new shared exponent.
        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }
}
// WKV backward kernel (RWKV v4): one thread per (batch, channel) pair, like the
// forward kernel. Pass 1 recomputes the forward quantities (y, normalizer z,
// exponent zexp) and accumulates gw/gu; pass 2 walks time in reverse to produce
// gk/gv. The per-time scratch arrays are sized by the compile-time macro Tmax
// (set via -DTmax=... when the extension is built), so T must not exceed it.
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;  // batch index
    const int _c = idx % C;  // channel index
    const int _offset = _b * T * C + _c;
    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const gy = _gy + _offset;
    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;
    // Per-time-step scratch (register/local arrays), bounded by Tmax.
    F y[Tmax], z[Tmax], zexp[Tmax];
    F gw = 0, gu = 0;
    F p = 0, q = 0;
    F dpdw = 0, dqdw = 0;  // running d(p)/d(w), d(q)/d(w) for the gw accumulation
    F o = MIN_VALUE;       // shared exponent of the rescaled state, as in forward
    // Pass 1 (forward in time): recompute outputs and gather gw/gu.
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        F no = max(o, k[ii] + u);
        F A = exp(o - no);
        F B = exp(k[ii] + u - no);
        F num = A * p + B * v[ii];
        F iden = 1 / (A * q + B);
        y[i] = num * iden;
        z[i] = iden;
        zexp[i] = k[ii] + u - no;
        gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
        gu += gy[ii] * (v[ii] - y[i]) * B * iden;
        // Same state transition as the forward kernel, plus the w-derivatives.
        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        dpdw = A * (p + dpdw);
        dqdw = A * (q + dqdw);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }
    // Pass 2 (backward in time): propagate gp/gq to produce gk and gv.
    F gp = 0, gq = 0;
    o = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        F A = gy[ii] * z[i] * exp(zexp[i]);
        F B = exp(k[ii] + o);
        gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
        gv[ii] = A + B * gp;
        F no = max(w + o, zexp[i] - k[ii] - u);
        A = exp(w + o - no);
        B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
        gp = A * gp + B;
        gq = A * gq - B * y[i];
        o = no;
    }
    // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
    // gw/gu are accumulated per (b, c); the Python caller sums them over the
    // batch dimension afterwards.
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] += gw * _w[_c];
    _gu[_offsetBC] += gu;
}
// Host launcher for kernel_forward: one CUDA thread per (batch, channel) pair.
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    // Small blocks; requires --maxrregcount 60 for optimal performance.
    const int block_size = min(C, 32);
    assert(B * C % block_size == 0);  // the grid must tile B*C exactly
    const int grid_size = B * C / block_size;
    kernel_forward<<<grid_size, block_size>>>(B, T, C, w, u, k, v, y);
}
// Host launcher for kernel_backward; same one-thread-per-(b, c) decomposition
// as cuda_forward.
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
    // Small blocks; requires --maxrregcount 60 for optimal performance.
    const int block_size = min(C, 32);
    assert(B * C % block_size == 0);  // the grid must tile B*C exactly
    const int grid_size = B * C / block_size;
    kernel_backward<<<grid_size, block_size>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
}
funasr/models/sense_voice/cuda/wkv_op.cpp
New file
@@ -0,0 +1,21 @@
#include <torch/extension.h>
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}
// Expose the wrappers to Python as module-level functions (wkv.forward /
// wkv.backward), both via pybind11 and via the TorchScript operator registry.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv forward");
    m.def("backward", &backward, "wkv backward");
}
TORCH_LIBRARY(wkv, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
funasr/models/sense_voice/decoder.py
@@ -156,7 +156,6 @@
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
from omegaconf import OmegaConf
@@ -168,9 +167,25 @@
        rwkv_cfg = kwargs.get("rwkv_cfg", {})
        args = OmegaConf.create(rwkv_cfg)
        self.attn = RWKVLayer(args=args, layer_id=layer_id)
        if args.get("datatype", "bf16") == "bf16":
            self.attn.to(torch.bfloat16)
        if args.get("version", "v4") == "v4":
            from funasr.models.sense_voice.rwkv_v4 import RWKVLayer
            from funasr.models.sense_voice.rwkv_v4 import RWKV_TimeMix as RWKV_Tmix
        elif args.get("version", "v5") == "v5":
            from funasr.models.sense_voice.rwkv_v5 import RWKVLayer
            from funasr.models.sense_voice.rwkv_v5 import RWKV_Tmix_x052 as RWKV_Tmix
        else:
            from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
            from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
        # self.att = RWKVLayer(args=args, layer_id=layer_id)
        self.att = RWKV_Tmix(args, layer_id=layer_id)
        if args.get("init_rwkv", True):
            print("init_rwkv")
            nn.init.orthogonal_(self.att.receptance.weight, gain=1)
            nn.init.orthogonal_(self.att.key.weight, gain=0.1)
            nn.init.orthogonal_(self.att.value.weight, gain=1)
            nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
            nn.init.zeros_(self.att.output.weight)
        self.ln0 = None
        if layer_id == 0 and not args.get("ln0", True):
@@ -180,6 +195,7 @@
                layer_id = 0
                scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
                nn.init.constant_(self.ln0.weight, scale)
        self.layer_id = layer_id
        self.args = args
@@ -191,6 +207,11 @@
                print("init_rwkv")
                scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
                nn.init.constant_(self.ln1.weight, scale)
        if args.get("datatype", "bf16") == "bf16":
            self.att.to(torch.bfloat16)
            # if self.ln1 is not None:
            #     self.ln1.to(torch.bfloat16)
        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
@@ -213,10 +234,14 @@
        if self.layer_id == 0 and self.ln0 is not None:
            x = self.ln0(x)
        if self.args.get("datatype", "bf16") == "bf16":
            x = x.bfloat16()
        if self.ln1 is None:
            x = x + self.attn(x, mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
            x = x + self.att(x, mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
        else:
            x = x + self.attn(self.ln1(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
            x = x + self.att(self.ln1(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
        if self.args.get("datatype", "bf16") == "bf16":
            x = x.to(torch.float32)
        if self.cross_attn:
            x = (
funasr/models/sense_voice/rwkv_v4.py
New file
@@ -0,0 +1,412 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os, math, gc, importlib
import torch
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.nn import functional as F
def __nop(ob):
    """Identity decorator: returns its argument unchanged (used when JIT is off)."""
    return ob
# Default to eager-mode building blocks; switch to the TorchScript equivalents
# only when the RWKV_JIT_ON=1 environment variable is set.
MyModule = nn.Module
MyFunction = __nop
if "RWKV_JIT_ON" in os.environ and os.environ["RWKV_JIT_ON"] == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method
########################################################################################################
# CUDA Kernel
########################################################################################################
# Lazily-built handle to the compiled WKV CUDA extension (None until loaded).
wkv_cuda = None
def load_rwkv_kernel(
    HEAD_SIZE: int = 64,
    RWKV_CTXLEN: int = 512,
    T_MAX: int = 512,
):
    """JIT-compile and cache the RWKV v4 WKV CUDA kernel (no-op if already loaded).

    T_MAX is baked into the kernel as the Tmax macro and bounds the sequence
    length the backward pass can handle.
    NOTE(review): HEAD_SIZE and RWKV_CTXLEN are unused here — confirm whether
    they can be dropped or were meant to be forwarded to the build flags.
    """
    from torch.utils.cpp_extension import load
    global wkv_cuda
    if wkv_cuda is not None:
        return
    # Sources live in the cuda/ directory next to this file.
    absolute_file_path = os.path.abspath(__file__)
    cur_dir = os.path.dirname(absolute_file_path)
    wkv_cuda = load(
        name="wkv",
        sources=[f"{cur_dir}/cuda/wkv_op.cpp", f"{cur_dir}/cuda/wkv_cuda.cu"],
        verbose=True,
        extra_cuda_cflags=[
            "-res-usage",
            "--maxrregcount 60",
            "--use_fast_math",
            "-O3",
            "-Xptxas -O3",
            f"-DTmax={T_MAX}",
        ],
    )
class WKV(torch.autograd.Function):
    """Autograd wrapper around the fused RWKV v4 WKV CUDA kernel.

    The kernel computes in fp32: inputs are converted on the way in (per the
    RWKV_FLOAT_MODE environment variable) and outputs/gradients are cast back
    to the active dtype on the way out. Requires load_rwkv_kernel() to have
    populated the module-level ``wkv_cuda`` handle, and CUDA tensors.
    """

    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        # NOTE(review): T must not exceed the Tmax the kernel was compiled with
        # (load_rwkv_kernel's T_MAX); the kernel's scratch arrays are sized by
        # that macro and there is no runtime check here.
        assert B * C % min(C, 1024) == 0
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            # Already fp32: only negate-exp the decay and force contiguity.
            w = -torch.exp(w.contiguous())
            u = u.contiguous()
            k = k.contiguous()
            v = v.contiguous()
        else:
            # fp16/bf16 inputs: upcast to fp32 for the kernel.
            w = -torch.exp(w.float().contiguous())
            u = u.float().contiguous()
            k = k.float().contiguous()
            v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device="cuda", memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            return y
        elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
            return y.half()
        elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
            return y.bfloat16()

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        # BUG FIX: the original `assert T <= T_MAX` raised NameError on every
        # backward pass — T_MAX is only a parameter of load_rwkv_kernel and is
        # never defined at module scope. The bound is fixed at kernel-compile
        # time via -DTmax (and forward carries the same assert commented out).
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        # Per-(batch, channel) gradient buffers for w/u; summed over batch below.
        gw = torch.zeros((B, C), device="cuda").contiguous()
        gu = torch.zeros((B, C), device="cuda").contiguous()
        gk = torch.zeros((B, T, C), device="cuda").contiguous()
        gv = torch.zeros((B, T, C), device="cuda").contiguous()
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
        else:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        # First three None gradients correspond to the B, T, C int arguments.
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            return (None, None, None, gw, gu, gk, gv)
        elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
        elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
def RUN_CUDA(B, T, C, w, u, k, v):
    """Move the WKV operands onto the GPU and invoke the custom autograd op."""
    on_gpu = [t.cuda() for t in (w, u, k, v)]
    return WKV.apply(B, T, C, *on_gpu)
class RWKV_TimeMix(torch.jit.ScriptModule):
    """RWKV v4 time-mixing ("attention") block, backed by the fused CUDA kernel.

    NOTE(review): subclasses torch.jit.ScriptModule unconditionally rather than
    using the MyModule alias defined above — confirm this is intentional.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        # Ensure the module-level wkv_cuda kernel handle is compiled/loaded.
        load_rwkv_kernel()
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_embd = config.n_embd
        attn_sz = config.n_embd
        with torch.no_grad():  # fancy init
            # Layer-depth ratios used to shape the per-channel init curves.
            ratio_0_to_1 = layer_id / (config.n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer)  # 1 to ~0
            # fancy time_decay: per-channel decay speeds spread from -5 to 3
            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed)
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
            # fancy time_first: log(0.3) baseline with a small +/-0.5 zigzag
            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
            # fancy time_mix: per-channel mixing ratios between x[t] and x[t-1]
            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
        # Shifts the sequence one step right along T (pads the first step with 0).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
        # scale_init attributes are hints read by external init code, not by torch.
        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0
    @torch.jit.script_method
    def jit_func(self, x):
        # Mix x with the previous timestep to produce xk, xv, xr
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        # Use xk, xv, xr to produce k, v, r
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)  # receptance acts as a sigmoid gate on the WKV output
        return sr, k, v
    def forward(self, x):
        """Apply time-mixing to x of shape (Batch, Time, Channel)."""
        B, T, C = x.size()  # x = (Batch,Time,Channel)
        sr, k, v = self.jit_func(x)
        # Gate the fused CUDA WKV result, then project back to the embedding size.
        rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
        rwkv = self.output(rwkv)
        return rwkv
class RWKV_ChannelMix(torch.jit.ScriptModule):
    """RWKV v4 channel-mixing (feed-forward) block with a receptance gate.

    NOTE(review): subclasses torch.jit.ScriptModule unconditionally rather than
    the MyModule alias — confirm this is intentional.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        # Shifts the sequence one step right along T (pads the first step with 0).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        with torch.no_grad():  # fancy init of time_mix
            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer)  # 1 to ~0
            # Per-channel mixing ratios between x[t] and x[t-1].
            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
        # 4x expansion, like a transformer FFN.
        hidden_sz = 4 * config.n_embd
        self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
        # scale_init attributes are hints read by external init code, not by torch.
        self.value.scale_init = 0
        self.receptance.scale_init = 0
    @torch.jit.script_method
    def forward(self, x):
        # Token-shift mix, squared-ReLU FFN, then a sigmoid receptance gate.
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)
        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv
# class Block(nn.Module):
#     def __init__(self, args, layer_id):
#         super().__init__()
#         self.args = args
#         self.layer_id = layer_id
#
#         self.ln1 = nn.LayerNorm(args.n_embd)
#         self.ln2 = nn.LayerNorm(args.n_embd)
#
#         if self.layer_id == 0:
#             self.ln0 = nn.LayerNorm(args.n_embd)
#
#         self.att = RWKV_Tmix_x060(args, layer_id)
#
#         self.ffn = RWKV_CMix_x060(args, layer_id)
#
#         if args.dropout > 0:
#             self.drop0 = nn.Dropout(p=args.dropout)
#             self.drop1 = nn.Dropout(p=args.dropout)
#
#     def forward(self, x, x_emb=None):
#         args = self.args
#         B, T, C = x.size()
#         if self.layer_id == 0:
#             x = self.ln0(x)
#
#         if self.args.dropout == 0:
#             if self.layer_id == 0 and args.pre_ffn > 0:
#                 x = x + self.ffnPre(self.ln1(x))
#             else:
#                 x = x + self.att(self.ln1(x))
#             x = x + self.ffn(self.ln2(x))
#         else:
#             if self.layer_id == 0 and args.pre_ffn > 0:
#                 x = self.drop0(x + self.ffnPre(self.ln1(x)))
#             else:
#                 x = self.drop0(x + self.att(self.ln1(x)))
#             x = self.drop1(x + self.ffn(self.ln2(x)))
#
#         return x
class RWKVLayer(nn.Module):
    """Full RWKV v4 block: optional pre-norms + time-mix (att) + channel-mix (ffn).

    Config flags read from ``args`` (an OmegaConf-style node): ``ln0``/``ln1``
    toggle the pre-norms, ``dropout`` enables residual dropout, ``init_rwkv``
    applies the RWKV-LM weight initialization, ``datatype`` ("bf16" default)
    selects bf16 compute inside the block, and ``dim_ffn`` defaults to
    ~3.5x the embedding size rounded to a multiple of 32.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        # Derive the FFN width when not configured explicitly.
        if args.dim_ffn is None:
            args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)
        self.ln0 = None
        if self.layer_id == 0 and args.get("ln0", True):
            self.ln0 = nn.LayerNorm(args.n_embd)
        self.ln1 = None
        if args.get("ln1", True):
            self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)
        self.att = RWKV_TimeMix(args, layer_id)
        self.ffn = RWKV_ChannelMix(args, layer_id)
        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
            self.drop1 = nn.Dropout(p=args.dropout)
        # init
        if args.get("init_rwkv", True):
            print("init_rwkv")
            nn.init.orthogonal_(self.att.receptance.weight, gain=1)
            nn.init.orthogonal_(self.att.key.weight, gain=0.1)
            nn.init.orthogonal_(self.att.value.weight, gain=1)
            # BUG FIX: the v4 RWKV_TimeMix defined in this file has no `gate`
            # projection (that exists only in the v5/v6 variants), so the
            # original unconditional init crashed with AttributeError under the
            # default init_rwkv=True. Initialize it only when present.
            if hasattr(self.att, "gate"):
                nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
            nn.init.zeros_(self.att.output.weight)
            nn.init.orthogonal_(self.ffn.key.weight, gain=1)
            nn.init.zeros_(self.ffn.value.weight)
            nn.init.zeros_(self.ffn.receptance.weight)
            # Depth-dependent LayerNorm scaling, as in RWKV-LM.
            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
            nn.init.constant_(self.ln2.weight, scale)
            if self.ln0 is not None:
                nn.init.constant_(self.ln0.weight, scale)
            if self.ln1 is not None:
                nn.init.constant_(self.ln1.weight, scale)

    def forward(self, x, x_emb=None, mask=None, **kwargs):
        """Apply the block to x of shape (B, T, C); x_emb/mask are accepted but unused."""
        args = self.args
        # Compute in bf16 when configured; cast back to fp32 before returning.
        if args.get("datatype", "bf16") == "bf16":
            x = x.bfloat16()
        B, T, C = x.size()
        if self.layer_id == 0 and self.ln0 is not None:
            x = self.ln0(x)
        if self.args.dropout == 0:
            if self.ln1 is None:
                x = x + self.att(x)
            else:
                x = x + self.att(self.ln1(x))
            x = x + self.ffn(self.ln2(x))
        else:
            if self.ln1 is None:
                x = self.drop0(x + self.att(x))
            else:
                x = self.drop0(x + self.att(self.ln1(x)))
            x = self.drop1(x + self.ffn(self.ln2(x)))
        if args.get("datatype", "bf16") == "bf16":
            x = x.to(torch.float32)
        return x
class RWKV(nn.Module):
    """Full RWKV v4 language model (embedding -> blocks -> LN -> LM head).

    NOTE(review): this looks like dead code carried over from RWKV-LM — it
    references ``Block`` (only present in the commented-out section above) and
    ``deepspeed`` (never imported in this file), so instantiating or running it
    as written raises NameError. Confirm before use.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # Fill in derived defaults when the config omits them.
        if not hasattr(args, "dim_att"):
            args.dim_att = args.n_embd
        if not hasattr(args, "dim_ffn"):
            if "-f4" in os.environ["RWKV_MY_TESTING"]:
                args.dim_ffn = int((args.n_embd * 4) // 32 * 32)
            else:
                args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)  # default = 3.5x emb size
        if not hasattr(args, "tiny_att_layer"):
            args.tiny_att_layer = -1
        if not hasattr(args, "tiny_att_dim"):
            args.tiny_att_dim = -1
        assert args.n_embd % 32 == 0
        assert args.dim_att % 32 == 0
        assert args.dim_ffn % 32 == 0
        self.emb = nn.Embedding(args.vocab_size, args.n_embd)
        # NOTE(review): ``Block`` is commented out above in this file — this
        # line raises NameError as written.
        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
    def forward(self, idx):
        """Run token ids idx of shape (B, T) through the model; returns logits."""
        args = self.args
        B, T = idx.size()
        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
        x = self.emb(idx)
        x_emb = x
        if args.dropout > 0:
            x = self.drop0(x)
        # Optional gradient checkpointing per block when grad_cp == 1.
        if args.tiny_att_dim > 0:
            for block in self.blocks:
                if args.grad_cp == 1:
                    # NOTE(review): ``deepspeed`` is not imported in this file.
                    x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
                else:
                    x = block(x, x_emb)
        else:
            for block in self.blocks:
                if args.grad_cp == 1:
                    x = deepspeed.checkpointing.checkpoint(block, x)
                else:
                    x = block(x)
        x = self.ln_out(x)
        if args.head_qk > 0:
            # NOTE(review): head_q/head_k/copy_mask are never created in
            # __init__ here, so this branch cannot run as written.
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                c = c @ F.one_hot(idx, num_classes=args.vocab_size)
            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
            x = self.head(x) + c
        else:
            x = self.head(x)
        return x
funasr/models/sense_voice/rwkv_v5.py
New file
@@ -0,0 +1,597 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os, math, gc, importlib
import torch
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.nn import functional as F
def __nop(ob):
    return ob
# Default to eager-mode modules/methods; RWKV_JIT_ON=1 opts into TorchScript.
MyModule = nn.Module
MyFunction = __nop
if os.environ.get("RWKV_JIT_ON") == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method
########################################################################################################
# CUDA Kernel
########################################################################################################
wkv5_cuda = None
def load_rwkv_kernel(
    HEAD_SIZE: int = 64,
    RWKV_CTXLEN: int = 512,
):
    """JIT-compile the fused WKV-5 CUDA kernel once per process.

    The compiled extension is cached in the module-level `wkv5_cuda` global;
    repeated calls are no-ops. HEAD_SIZE is baked into the kernel via -D_N_.
    """
    global wkv5_cuda
    if wkv5_cuda is not None:
        return  # already built for this process

    from torch.utils.cpp_extension import load

    kernel_dir = os.path.dirname(os.path.abspath(__file__))
    wkv5_cuda = load(
        name="wkv5",
        sources=[f"{kernel_dir}/cuda/wkv5_op.cpp", f"{kernel_dir}/cuda/wkv5_cuda.cu"],
        verbose=True,
        extra_cuda_cflags=[
            "-res-usage",
            "--use_fast_math",
            "-O3",
            "-Xptxas -O3",
            "--extra-device-vectorization",
            f"-D_N_={HEAD_SIZE}",
        ],
    )
# dtype = torch.float
# Module-wide compute dtype; WKV_5 below allocates its kernel outputs in bf16.
dtype = torch.bfloat16
class WKV_5(torch.autograd.Function):
    """Custom autograd Function bridging to the fused RWKV-5 WKV CUDA kernel.

    r, k, v are (B, T, C) contiguous tensors; w, u are per-head decay/bonus
    parameters (reshaped to (H, C // H) by the caller). Both passes dispatch to
    the module-level `wkv5_cuda` extension built by `load_rwkv_kernel`.
    """

    @staticmethod
    def forward(ctx, B, T, C, H, r, k, v, w, u):
        # B/T/C/H are plain ints (batch, time, channels, heads); no grad flows to them.
        with torch.no_grad():
            # assert r.dtype == torch.bfloat16
            # assert k.dtype == torch.bfloat16
            # assert v.dtype == torch.bfloat16
            # assert w.dtype == torch.bfloat16
            # assert u.dtype == torch.bfloat16
            # assert HEAD_SIZE == C // H
            ctx.B = B
            ctx.T = T
            ctx.C = C
            ctx.H = H
            # The CUDA kernel indexes raw memory, so every tensor must be contiguous.
            assert r.is_contiguous()
            assert k.is_contiguous()
            assert v.is_contiguous()
            assert w.is_contiguous()
            assert u.is_contiguous()
            # eew = exp(-exp(w)): the effective per-channel decay factor in (0, 1).
            ew = (-torch.exp(w.float())).contiguous()
            eew = (torch.exp(ew)).contiguous()
            ctx.save_for_backward(r, k, v, eew, ew, u)
            # Output buffer; the kernel writes y in place (hence empty, not zeros).
            y = torch.empty(
                (B, T, C),
                device=r.device,
                dtype=torch.bfloat16,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-1, 1)
            wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
            return y

    @staticmethod
    def backward(ctx, gy):
        # All gradients are computed by the CUDA kernel into preallocated buffers.
        with torch.no_grad():
            assert gy.dtype == torch.bfloat16
            B = ctx.B
            T = ctx.T
            C = ctx.C
            H = ctx.H
            assert gy.is_contiguous()
            r, k, v, eew, ew, u = ctx.saved_tensors
            gr = torch.empty(
                (B, T, C),
                device=gy.device,
                requires_grad=False,
                dtype=torch.bfloat16,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-1, 1)
            gk = torch.empty(
                (B, T, C),
                device=gy.device,
                requires_grad=False,
                dtype=torch.bfloat16,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-1, 1)
            gv = torch.empty(
                (B, T, C),
                device=gy.device,
                requires_grad=False,
                dtype=torch.bfloat16,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-1, 1)
            # gw/gu are accumulated per batch element, then reduced below.
            gw = torch.empty(
                (B, C),
                device=gy.device,
                requires_grad=False,
                dtype=torch.bfloat16,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-1, 1)
            gu = torch.empty(
                (B, C),
                device=gy.device,
                requires_grad=False,
                dtype=torch.bfloat16,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-1, 1)
            wkv5_cuda.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu)
            # Sum over the batch dim and reshape to the (H, head_size) parameter layout.
            gw = torch.sum(gw, 0).view(H, C // H)
            gu = torch.sum(gu, 0).view(H, C // H)
            # Leading four Nones correspond to the non-tensor inputs B, T, C, H.
            return (None, None, None, None, gr, gk, gv, gw, gu)
def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
    """Thin functional wrapper dispatching to the WKV_5 autograd Function."""
    return WKV_5.apply(B, T, C, H, r, k, v, w, u)
class RWKV_Tmix_x052(MyModule):
    """RWKV v5.2 time-mix ("attention") block.

    Projects token-shifted blends of the input into r/k/v/g streams, runs the
    fused WKV-5 CUDA kernel over them, then applies a per-head GroupNorm and a
    g-gated output projection.
    """

    def __init__(self, args, layer_id):
        # args: dict-like config; reads head_size_a, ctx_len, dim_att, n_embd,
        # n_layer and head_size_divisor.
        super().__init__()
        self.args = args
        load_rwkv_kernel(args.head_size_a, args.ctx_len)
        self.layer_id = layer_id
        self.head_size = args.head_size_a
        # assert HEAD_SIZE == self.head_size  # change HEAD_SIZE to match args.head_size_a
        self.n_head = args.dim_att // self.head_size
        assert args.dim_att % self.n_head == 0
        self.head_size_divisor = args.head_size_divisor
        with torch.no_grad():
            # Layer-depth ratios used to shape the per-channel init curves below.
            ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            # ddd[0, 0, i] = i / n_embd: monotone per-channel ramp in [0, 1).
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd
            # fancy time_mix
            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
            self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
            # fancy time_decay: per-channel decay speeds spanning [-6, -1], per head
            decay_speed = torch.ones(args.dim_att)
            for n in range(args.dim_att):
                decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed.reshape(self.n_head, self.head_size))
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
            # "u" bonus term: decaying ramp with a small +/-0.1 zigzag every 3 channels.
            tmp = torch.zeros(args.dim_att)
            for n in range(args.dim_att):
                zigzag = ((n + 1) % 3 - 1) * 0.1
                tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
            self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
        # Shifts the sequence one step back in time (zero-padded at t=0).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
        self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.ln_x = nn.GroupNorm(self.n_head, args.dim_att)

    @MyFunction
    def jit_func(self, x):
        # Blend each channel of x with its previous timestep, then project.
        B, T, C = x.size()
        xx = self.time_shift(x)  # Mix x with the previous timestep to produce xk, xv, xr
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        xg = x * self.time_mix_g + xx * (1 - self.time_mix_g)
        r = self.receptance(xr)
        k = self.key(xk)
        v = self.value(xv)
        g = F.silu(self.gate(xg))
        return r, k, v, g

    @MyFunction
    def jit_func_2(self, x, g):
        # Per-head GroupNorm (scaled by head_size_divisor), gate by g, project out.
        B, T, C = x.size()
        x = x.view(B * T, C)
        x = self.ln_x(x / self.head_size_divisor).view(B, T, C)
        x = self.output(x * g)
        return x

    def forward(self, x, **kwargs):
        # x: (B, T, C) -> (B, T, n_embd); extra kwargs are accepted and ignored.
        B, T, C = x.size()
        H = self.n_head
        r, k, v, g = self.jit_func(x)
        x = RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w=self.time_decay, u=self.time_faaaa)
        return self.jit_func_2(x, g)
#
# class RWKV_Tmix_x060(MyModule):
#     def __init__(self, args, layer_id):
#         super().__init__()
#         self.args = args
#
#         load_rwkv_kernel(args.head_size_a, args.ctx_len)
#
#         self.layer_id = layer_id
#
#         self.head_size = args.head_size_a
#         self.n_head = args.dim_att // self.head_size
#         assert args.dim_att % self.n_head == 0
#
#         with torch.no_grad():
#             ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
#             ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
#             ddd = torch.ones(1, 1, args.n_embd)
#             for i in range(args.n_embd):
#                 ddd[0, 0, i] = i / args.n_embd
#
#             # fancy time_mix
#             self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
#             self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
#             self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
#             self.time_maa_v = nn.Parameter(
#                 1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
#             )
#             self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
#             self.time_maa_g = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
#
#             D_MIX_LORA = 32  # generate TIME_MIX for w,k,v,r,g
#             self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5))
#             self.time_maa_w2 = nn.Parameter(
#                 torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01)
#             )
#
#             # fancy time_decay
#             decay_speed = torch.ones(args.dim_att)
#             for n in range(args.dim_att):
#                 decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
#             self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))
#
#             D_DECAY_LORA = 64
#             self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
#             self.time_decay_w2 = nn.Parameter(
#                 torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01)
#             )
#
#             tmp = torch.zeros(args.dim_att)
#             for n in range(args.dim_att):
#                 zigzag = ((n + 1) % 3 - 1) * 0.1
#                 tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
#
#             self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
#
#         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
#         self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
#         self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
#
#         self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
#         self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
#         self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
#         self.ln_x = nn.GroupNorm(
#             self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
#         )
#
#     @MyFunction
#     def jit_func(self, x):
#         B, T, C = x.size()
#
#         xx = self.time_shift(x) - x
#
#         xxx = x + xx * self.time_maa_x
#         xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
#         xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
#         mw, mk, mv, mr, mg = xxx.unbind(dim=0)
#
#         xw = x + xx * (self.time_maa_w + mw)
#         xk = x + xx * (self.time_maa_k + mk)
#         xv = x + xx * (self.time_maa_v + mv)
#         xr = x + xx * (self.time_maa_r + mr)
#         xg = x + xx * (self.time_maa_g + mg)
#
#         r = self.receptance(xr)
#         k = self.key(xk)
#         v = self.value(xv)
#         g = F.silu(self.gate(xg))
#
#         ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
#         w = self.time_decay + ww
#
#         return r, k, v, g, w
#
#     @MyFunction
#     def jit_func_2(self, x, g):
#         B, T, C = x.size()
#         x = x.view(B * T, C)
#
#         x = self.ln_x(x).view(B, T, C)
#         x = self.output(x * g)
#         return x
#
#     def forward(self, x):
#         B, T, C = x.size()
#         H = self.n_head
#
#         r, k, v, g, w = self.jit_func(x)
#         x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)
#
#         return self.jit_func_2(x, g)
class RWKV_CMix_x052(MyModule):
    """RWKV v5.2 channel-mix (FFN) block: squared-ReLU key/value path gated by
    a sigmoid receptance over a token-shifted input blend."""

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        # Shifts the sequence one step back in time (zero-padded at t=0).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        with torch.no_grad():  # fancy init of time_mix
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            # Per-channel ramp ddd[0, 0, i] = i / n_embd, powered by the depth ratio.
            ddd = torch.ones(1, 1, args.n_embd)
            for ch in range(args.n_embd):
                ddd[0, 0, ch] = ch / args.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        shifted = self.time_shift(x)  # previous timestep of each token
        xk = x * self.time_mix_k + shifted * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + shifted * (1 - self.time_mix_r)
        hidden = torch.relu(self.key(xk)) ** 2  # squared-ReLU activation
        kv = self.value(hidden)
        return torch.sigmoid(self.receptance(xr)) * kv
# class RWKV_CMix_x060(MyModule):
#     def __init__(self, args, layer_id):
#         super().__init__()
#         self.args = args
#         self.layer_id = layer_id
#         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
#
#         with torch.no_grad():  # fancy init of time_mix
#             ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
#             ddd = torch.ones(1, 1, args.n_embd)
#             for i in range(args.n_embd):
#                 ddd[0, 0, i] = i / args.n_embd
#             self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
#             self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
#
#         self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
#         self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
#         self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
#
#     @MyFunction
#     def forward(self, x):
#         xx = self.time_shift(x) - x
#         xk = x + xx * self.time_maa_k
#         xr = x + xx * self.time_maa_r
#
#         k = self.key(xk)
#         k = torch.relu(k) ** 2
#         kv = self.value(k)
#         return torch.sigmoid(self.receptance(xr)) * kv
# class Block(nn.Module):
#     def __init__(self, args, layer_id):
#         super().__init__()
#         self.args = args
#         self.layer_id = layer_id
#
#         self.ln1 = nn.LayerNorm(args.n_embd)
#         self.ln2 = nn.LayerNorm(args.n_embd)
#
#         if self.layer_id == 0:
#             self.ln0 = nn.LayerNorm(args.n_embd)
#
#         self.att = RWKV_Tmix_x060(args, layer_id)
#
#         self.ffn = RWKV_CMix_x060(args, layer_id)
#
#         if args.dropout > 0:
#             self.drop0 = nn.Dropout(p=args.dropout)
#             self.drop1 = nn.Dropout(p=args.dropout)
#
#     def forward(self, x, x_emb=None):
#         args = self.args
#         B, T, C = x.size()
#         if self.layer_id == 0:
#             x = self.ln0(x)
#
#         if self.args.dropout == 0:
#             if self.layer_id == 0 and args.pre_ffn > 0:
#                 x = x + self.ffnPre(self.ln1(x))
#             else:
#                 x = x + self.att(self.ln1(x))
#             x = x + self.ffn(self.ln2(x))
#         else:
#             if self.layer_id == 0 and args.pre_ffn > 0:
#                 x = self.drop0(x + self.ffnPre(self.ln1(x)))
#             else:
#                 x = self.drop0(x + self.att(self.ln1(x)))
#             x = self.drop1(x + self.ffn(self.ln2(x)))
#
#         return x
class RWKVLayer(nn.Module):
    """One RWKV-v5 block: optional LayerNorms + time-mix attention + optional
    channel-mix FFN, with residual connections and optional dropout.

    `args` is a dict-like config (supports .get) providing at least n_embd,
    n_layer, dim_att, dim_ffn, head_size_a, head_size_divisor, ctx_len, dropout.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        if args.dim_ffn is None:
            # Default FFN width = 3.5x embedding, rounded down to a multiple of 32.
            args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)
        # ln0 normalizes the raw input once, on the first layer only.
        self.ln0 = None
        if self.layer_id == 0 and args.get("ln0", True):
            self.ln0 = nn.LayerNorm(args.n_embd)
        self.ln1 = None
        if args.get("ln1", True):
            self.ln1 = nn.LayerNorm(args.n_embd)
        self.att = RWKV_Tmix_x052(args, layer_id)
        self.ln2 = None
        self.ffn = None
        if args.get("use_rwkv_ffn", True):
            self.ln2 = nn.LayerNorm(args.n_embd)
            self.ffn = RWKV_CMix_x052(args, layer_id)
        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
            self.drop1 = nn.Dropout(p=args.dropout)
        # init
        if args.get("init_rwkv", True):
            print("init_rwkv")
            nn.init.orthogonal_(self.att.receptance.weight, gain=1)
            nn.init.orthogonal_(self.att.key.weight, gain=0.1)
            nn.init.orthogonal_(self.att.value.weight, gain=1)
            nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
            nn.init.zeros_(self.att.output.weight)
            # BUGFIX: ffn/ln2 are None when use_rwkv_ffn is False; the previous
            # unconditional init crashed with AttributeError in that configuration.
            if self.ffn is not None:
                nn.init.orthogonal_(self.ffn.key.weight, gain=1)
                nn.init.zeros_(self.ffn.value.weight)
                nn.init.zeros_(self.ffn.receptance.weight)
            # Depth-dependent LayerNorm gain: deeper layers start closer to 1.
            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
            if self.ln2 is not None:
                nn.init.constant_(self.ln2.weight, scale)
            if self.ln0 is not None:
                nn.init.constant_(self.ln0.weight, scale)
            if self.ln1 is not None:
                nn.init.constant_(self.ln1.weight, scale)

    def forward(self, x, x_emb=None, mask=None, **kwargs):
        """Apply the block to x of shape (B, T, C); x_emb/mask/kwargs are
        accepted for interface compatibility and currently unused."""
        args = self.args
        # The fused kernel path runs in bf16; cast in, and back to fp32 on exit.
        if args.get("datatype", "bf16") == "bf16":
            x = x.bfloat16()
        B, T, C = x.size()
        if self.layer_id == 0 and self.ln0 is not None:
            x = self.ln0(x)
        if self.args.dropout == 0:
            if self.ln1 is None:
                x = x + self.att(x)
            else:
                x = x + self.att(self.ln1(x))
            if self.ffn is not None:
                x = x + self.ffn(self.ln2(x))
        else:
            if self.ln1 is None:
                x = self.drop0(x + self.att(x))
            else:
                x = self.drop0(x + self.att(self.ln1(x)))
            if self.ffn is not None:
                x = self.drop1(x + self.ffn(self.ln2(x)))
        if args.get("datatype", "bf16") == "bf16":
            x = x.to(torch.float32)
        return x
# class RWKV(nn.Module):
#     def __init__(self, args):
#         super().__init__()
#         self.args = args
#         if not hasattr(args, "dim_att"):
#             args.dim_att = args.n_embd
#         if not hasattr(args, "dim_ffn"):
#             if "-f4" in os.environ["RWKV_MY_TESTING"]:
#                 args.dim_ffn = int((args.n_embd * 4) // 32 * 32)
#             else:
#                 args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)  # default = 3.5x emb size
#         if not hasattr(args, "tiny_att_layer"):
#             args.tiny_att_layer = -1
#         if not hasattr(args, "tiny_att_dim"):
#             args.tiny_att_dim = -1
#         assert args.n_embd % 32 == 0
#         assert args.dim_att % 32 == 0
#         assert args.dim_ffn % 32 == 0
#
#         self.emb = nn.Embedding(args.vocab_size, args.n_embd)
#
#         self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
#
#         self.ln_out = nn.LayerNorm(args.n_embd)
#         self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
#
#         if args.dropout > 0:
#             self.drop0 = nn.Dropout(p=args.dropout)
#
#     def forward(self, idx):
#         args = self.args
#         B, T = idx.size()
#         assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
#
#         x = self.emb(idx)
#         x_emb = x
#
#         if args.dropout > 0:
#             x = self.drop0(x)
#         if args.tiny_att_dim > 0:
#             for block in self.blocks:
#                 if args.grad_cp == 1:
#                     x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
#                 else:
#                     x = block(x, x_emb)
#         else:
#             for block in self.blocks:
#                 if args.grad_cp == 1:
#                     x = deepspeed.checkpointing.checkpoint(block, x)
#                 else:
#                     x = block(x)
#
#         x = self.ln_out(x)
#
#         if args.head_qk > 0:
#             q = self.head_q(x)[:, :T, :]
#             k = self.head_k(x)[:, :T, :]
#             c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
#             c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
#
#             if "32" in os.environ["RWKV_FLOAT_MODE"]:
#                 c = c @ F.one_hot(idx, num_classes=args.vocab_size)
#             elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
#                 c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
#             elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
#                 c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
#
#             x = self.head(x) + c
#         else:
#             x = self.head(x)
#
#         return x
funasr/models/sense_voice/rwkv_v6.py
@@ -244,7 +244,7 @@
        x = self.output(x * g)
        return x
    def forward(self, x):
    def forward(self, x, **kwargs):
        B, T, C = x.size()
        H = self.n_head
@@ -341,10 +341,13 @@
        self.ln1 = None
        if args.get("ln1", True):
            self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)
        self.att = RWKV_Tmix_x060(args, layer_id)
        self.ln2 = None
        self.ffn = None
        if args.get("use_rwkv_ffn", True):
            self.ln2 = nn.LayerNorm(args.n_embd)
        self.ffn = RWKV_CMix_x060(args, layer_id)
        if args.dropout > 0:
@@ -364,11 +367,13 @@
            nn.init.zeros_(self.ffn.value.weight)
            nn.init.zeros_(self.ffn.receptance.weight)
            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
            nn.init.constant_(self.ln2.weight, scale)
            if self.ln0 is not None:
                nn.init.constant_(self.ln0.weight, scale)
            if self.ln1 is not None:
                nn.init.constant_(self.ln1.weight, scale)
            if self.ln2 is not None:
                nn.init.constant_(self.ln2.weight, scale)
    def forward(self, x, x_emb=None, mask=None, **kwargs):
@@ -384,12 +389,14 @@
                x = x + self.att(x)
            else:
                x = x + self.att(self.ln1(x))
            if self.ffn is not None:
            x = x + self.ffn(self.ln2(x))
        else:
            if self.ln1 is None:
                x = self.drop0(x + self.att(x))
            else:
                x = self.drop0(x + self.att(self.ln1(x)))
            if self.ffn is not None:
            x = self.drop1(x + self.ffn(self.ln2(x)))
        if args.get("datatype", "bf16") == "bf16":
funasr/train_utils/model_summary.py
@@ -47,6 +47,8 @@
def model_summary(model: torch.nn.Module) -> str:
    message = "Model structure:\n"
    message += str(model)
    # for p in model.parameters():
    #     print(f"{p.numel()}")
    tot_params = sum(p.numel() for p in model.parameters())
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params)
funasr/train_utils/trainer.py
@@ -15,6 +15,11 @@
from funasr.train_utils.average_nbest_models import average_checkpoints
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
try:
    import wandb
except:
    wandb = None
@contextmanager
def maybe_autocast(enabled):
@@ -109,7 +114,20 @@
        self.val_loss_step_or_eoch = {}
        
        self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
        self.start_data_split_i = 0
        self.start_step = 0
        self.use_wandb = kwargs.get("use_wandb", False)
        if self.use_wandb:
            wandb.login(key=kwargs.get("wandb_token"))
            wandb.init(
                config=kwargs,
                project=kwargs.get("wandb_project", "my_project"),
                entity=kwargs.get("wandb_team", "my_team"),
                name=kwargs.get("wandb_exp_name", "my_exp"),
                dir=output_dir,
                job_type="training",
                reinit=True,
            )
    def save_checkpoint(
        self,
@@ -142,6 +160,7 @@
                "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
                "best_step_or_epoch": self.best_step_or_epoch,
                "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
                "step": step,
            }
            if hasattr(model, "module"):
                state["state_dict"] = model.module.state_dict()
@@ -268,6 +287,13 @@
                self.best_step_or_epoch = (
                    checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
                )
                self.start_data_split_i = (
                    checkpoint["start_data_split_i"] if "start_data_split_i" in checkpoint else 0
                )
                self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
                self.start_step = checkpoint["step"] if "step" in checkpoint else 0
                self.start_step = 0 if self.start_step is None else self.start_step
                model.to(self.device)
                print(f"Checkpoint loaded successfully from '{ckpt}'")
            else:
@@ -598,7 +624,7 @@
            acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
            description = (
                f"{tag}, "
                f"rank: {self.local_rank}, "
                f"rank: {self.rank}, "
                f"epoch: {epoch}/{self.max_epoch}, "
                f"data_slice: {data_split_i}/{data_split_num}, "
                f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
@@ -613,17 +639,28 @@
            )
            logging.info(description)
            description_dict = {
                f"rank{self.rank}_loss/{tag}": loss,
                f"rank{self.rank}_lr/{tag}": lr,
            }
            if writer is not None:
                writer.add_scalar(f"rank{self.local_rank}_loss/{tag}", loss, self.batch_total)
                writer.add_scalar(f"rank{self.local_rank}_lr/{tag}", lr, self.batch_total)
                writer.add_scalar(f"rank{self.local_rank}_lr/{tag}", lr, self.batch_total)
                writer.add_scalar(f"rank{self.rank}_loss/{tag}", loss, self.batch_total)
                writer.add_scalar(f"rank{self.rank}_lr/{tag}", lr, self.batch_total)
                for key, var in stats.items():
                    writer.add_scalar(
                        f"stats_rank{self.local_rank}_{key}/{tag}", var.item(), self.batch_total
                        f"stats_rank{self.rank}_{key}/{tag}", var.item(), self.batch_total
                    )
                    description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = var.item()
                for key, var in speed_stats.items():
                    writer.add_scalar(
                        f"stats_rank{self.local_rank}_{key}/{tag}", eval(var), self.batch_total
                        f"stats_rank{self.rank}_{key}/{tag}", eval(var), self.batch_total
                    )
                    description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = eval(var)
            if self.use_wandb and wandb is not None:
                wandb.log(
                    description_dict,
                    setp=self.batch_total,
                    )
    def close(self, writer=None):
wandb/debug-internal.log
New file
@@ -0,0 +1 @@
run-20240425_211446-lkqptn01/logs/debug-internal.log
wandb/debug.log
New file
@@ -0,0 +1 @@
run-20240425_211446-lkqptn01/logs/debug.log
wandb/latest-run
New file
@@ -0,0 +1 @@
run-20240425_211446-lkqptn01
wandb/run-20240425_211446-lkqptn01/files/config.yaml
New file
@@ -0,0 +1,26 @@
wandb_version: 1
a:
  desc: null
  value: 1
_wandb:
  desc: null
  value:
    python_version: 3.8.15
    cli_version: 0.16.6
    is_jupyter_run: false
    is_kaggle_kernel: false
    start_time: 1714050886.0
    t:
      1:
      - 55
      - 105
      3:
      - 13
      - 16
      - 23
      4: 3.8.15
      5: 0.16.6
      8:
      - 5
      13: darwin-arm64
wandb/run-20240425_211446-lkqptn01/logs/debug-internal.log
New file
@@ -0,0 +1,146 @@
2024-04-25 21:14:46,163 INFO    StreamThr :90024 [internal.py:wandb_internal():86] W&B internal server running at pid: 90024, started at: 2024-04-25 21:14:46.156788
2024-04-25 21:14:46,164 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status
2024-04-25 21:14:46,172 INFO    WriterThread:90024 [datastore.py:open_for_write():87] open: ./wandb/run-20240425_211446-lkqptn01/run-lkqptn01.wandb
2024-04-25 21:14:46,173 DEBUG   SenderThread:90024 [sender.py:send():379] send: header
2024-04-25 21:14:46,284 DEBUG   SenderThread:90024 [sender.py:send():379] send: run
2024-04-25 21:14:46,828 ERROR   SenderThread:90024 [internal_api.py:execute():373] 403 response executing GraphQL.
2024-04-25 21:14:46,828 ERROR   SenderThread:90024 [internal_api.py:execute():374] {"errors":[{"message":"permission denied","path":["upsertBucket"],"extensions":{"code":"PERMISSION_ERROR"}}],"data":{"upsertBucket":null}}
2024-04-25 21:14:46,828 ERROR   SenderThread:90024 [sender.py:send_run():971] It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 403: Forbidden)
Traceback (most recent call last):
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/lib/retry.py", line 131, in __call__
    result = self._call_fn(*args, **kwargs)
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/internal/internal_api.py", line 369, in execute
    return self.client.execute(*args, **kwargs)  # type: ignore
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py", line 52, in execute
    result = self._get_result(document, *args, **kwargs)
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py", line 60, in _get_result
    return self.transport.execute(document, *args, **kwargs)
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/lib/gql_request.py", line 59, in execute
    request.raise_for_status()
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/requests/models.py", line 1021, in raise_for_status
    raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 403 Client Error: Forbidden for url: https://api.wandb.ai/graphql
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/internal/sender.py", line 969, in send_run
    server_run = self._init_run(run, config_value_dict)
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/internal/sender.py", line 1014, in _init_run
    server_run, inserted, server_messages = self._api.upsert_run(
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/apis/normalize.py", line 73, in wrapper
    raise err
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/apis/normalize.py", line 41, in wrapper
    return func(*args, **kwargs)
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/internal/internal_api.py", line 2217, in upsert_run
    response = self.gql(
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/internal/internal_api.py", line 341, in gql
    ret = self._retry_gql(
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/lib/retry.py", line 147, in __call__
    retry_timedelta_triggered = check_retry_fn(e)
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/util.py", line 968, in check_retry_fn
    return fallback_retry_fn(e)
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/util.py", line 910, in no_retry_auth
    raise CommError(
wandb.errors.CommError: It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 403: Forbidden)
2024-04-25 21:14:51,848 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:14:56,865 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:15:01,882 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:15:06,910 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:15:11,930 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:15:16,954 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:15:21,968 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:15:26,992 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:15:32,007 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:15:37,026 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:15:42,047 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:15:47,060 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:15:52,088 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:15:57,107 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:16:02,125 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:16:07,154 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:16:12,184 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:16:17,193 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:16:22,218 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:16:27,236 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:16:32,248 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:16:37,267 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:16:42,285 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:16:47,302 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:16:52,323 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:16:57,345 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:17:02,359 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:17:07,370 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:17:12,391 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:17:17,410 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:17:22,427 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:17:27,449 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:17:32,472 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:17:37,492 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:17:42,514 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:17:47,529 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:17:52,547 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:17:57,565 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:18:02,584 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:18:07,605 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:18:12,630 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:18:17,649 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:18:22,670 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:18:27,689 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:18:32,706 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:18:37,721 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:18:42,738 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:18:47,760 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:18:52,776 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:18:57,797 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:19:02,816 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:19:07,831 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:19:12,849 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:19:17,886 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:19:22,903 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:19:27,914 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:19:32,935 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:19:37,956 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:19:42,977 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:19:47,997 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:19:53,013 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:19:58,037 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:20:03,054 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:20:08,074 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:20:13,095 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:20:18,115 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:20:23,136 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:20:28,158 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:20:33,175 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:20:38,195 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:20:43,216 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:20:48,237 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:20:53,254 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:20:58,276 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:21:03,293 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:21:08,306 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:21:13,328 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:21:18,349 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:21:23,372 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:21:28,393 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:21:33,418 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:21:38,428 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:21:43,452 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:21:48,475 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:21:53,493 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:21:58,512 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:22:03,530 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:22:08,546 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:22:13,563 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:22:18,586 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:22:23,604 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:22:28,621 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:22:33,641 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:22:38,688 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:22:43,704 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:22:48,722 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:22:53,742 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:22:58,756 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:23:03,775 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:23:08,798 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
2024-04-25 21:23:13,815 DEBUG   HandlerThread:90024 [handler.py:handle_request():146] handle_request: status_report
wandb/run-20240425_211446-lkqptn01/logs/debug.log
New file
@@ -0,0 +1,29 @@
2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Current SDK version is 0.16.6
2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Configure stats pid to 89997
2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Loading settings from /Users/zhifu/.config/wandb/settings
2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Loading settings from /Users/zhifu/funasr1.0/wandb/settings
2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Loading settings from environment variables: {}
2024-04-25 21:14:46,147 WARNING MainThread:89997 [wandb_setup.py:_flush():76] Could not find program at <input>
2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Inferring run settings from compute environment: {'program_relpath': None, 'program': '<input>'}
2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Applying login settings: {'api_key': '***REDACTED***'}
2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Applying login settings: {'api_key': '***REDACTED***'}
2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_setup.py:_flush():76] Applying login settings: {}
2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_init.py:_log_setup():521] Logging user logs to ./wandb/run-20240425_211446-lkqptn01/logs/debug.log
2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_init.py:_log_setup():522] Logging internal logs to ./wandb/run-20240425_211446-lkqptn01/logs/debug-internal.log
2024-04-25 21:14:46,147 INFO    MainThread:89997 [wandb_init.py:init():561] calling init triggers
2024-04-25 21:14:46,148 INFO    MainThread:89997 [wandb_init.py:init():568] wandb.init called with sweep_config: {}
config: {'a': 1}
2024-04-25 21:14:46,148 INFO    MainThread:89997 [wandb_init.py:init():611] starting backend
2024-04-25 21:14:46,148 INFO    MainThread:89997 [wandb_init.py:init():615] setting up manager
2024-04-25 21:14:46,153 INFO    MainThread:89997 [backend.py:_multiprocessing_setup():105] multiprocessing start_methods=spawn,fork,forkserver, using: spawn
2024-04-25 21:14:46,156 INFO    MainThread:89997 [wandb_init.py:init():623] backend started and connected
2024-04-25 21:14:46,171 INFO    MainThread:89997 [wandb_init.py:init():715] updated telemetry
2024-04-25 21:14:46,282 INFO    MainThread:89997 [wandb_init.py:init():748] communicating run to backend with 90.0 second timeout
2024-04-25 21:14:46,836 ERROR   MainThread:89997 [wandb_init.py:init():774] encountered error: It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 403: Forbidden)
2024-04-25 21:14:48,147 ERROR   MainThread:89997 [wandb_init.py:init():1199] It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 403: Forbidden)
Traceback (most recent call last):
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/wandb_init.py", line 1181, in init
    run = wi.init()
  File "/Users/zhifu/miniconda3/envs/funasr/lib/python3.8/site-packages/wandb/sdk/wandb_init.py", line 780, in init
    raise error
wandb.errors.CommError: It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 403: Forbidden)
wandb/run-20240425_211446-lkqptn01/run-lkqptn01.wandb