-
-
Notifications
You must be signed in to change notification settings - Fork 868
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
如何使用state tuning rwkv6-7B? #246
Comments
请问你是直接使用的RWKV-LM项目,还是自己修改的?如果是自己修改的项目,在冻结梯度是deepspeed的checkpoint会报错,你需要使用torch.checkpoint 详细可以参考RWKV-PEFT |
@JL-er 谢谢,请问为啥是只tune time_state 64*64的矩阵呀,另外的两组state为啥不一起微调呢? state[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
state[i*3+1] = state_xueshan_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
state[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() |
这两个参数非常小影响不大,所以只取核心部分的state,为了简单方便 |
好的了解了,谢谢! |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
我按照官网尝试,
应该是只训练time_state,然而报错
RuntimeError: element o of tensors does not require grad and does not have a grad_fn
The text was updated successfully, but these errors were encountered: