-
-
Notifications
You must be signed in to change notification settings - Fork 11.8k
adds jais 2 support #30188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
adds jais 2 support #30188
Conversation
Signed-off-by: sarathc-cerebras <sarath.chandran@cerebras.net> Signed-off-by: sarathc-cerebras <sarath.chandran@cerebras.net> Signed-off-by: sarathc-cerebras <sarath.chandran@cerebras.net>
|
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run You ask your reviewers to trigger select CI tests on top of Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request adds support for the Jais-2 model. The implementation looks mostly correct, following the patterns of similar models in vLLM. However, I found a critical issue in the weight remapping logic for the MLP layers. The provided mapping is for a gated MLP, while the Jais-2 model uses a standard MLP, which will cause weight loading to fail for certain model formats. I've provided a suggestion to fix this.
| "w1": "gate_proj", | ||
| "w2": "down_proj", | ||
| "w3": "up_proj", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The mistral_mapping for the MLP layers appears to be incorrect for the Jais2MLP architecture. Jais2MLP is a standard 2-layer MLP with up_proj and down_proj, but the current mapping is for a 3-layer gated MLP (like SwiGLU), which includes gate_proj.
Specifically, mapping w1 to gate_proj will cause an AttributeError during weight loading because Jais2MLP does not have a gate_proj attribute. This seems to be a copy-paste from a Llama/Mistral implementation.
Assuming the 'mistral format' for a 2-layer MLP uses w1 for the up-projection and w2 for the down-projection, the mapping should be corrected.
| "w1": "gate_proj", | |
| "w2": "down_proj", | |
| "w3": "up_proj", | |
| "w1": "up_proj", | |
| "w2": "down_proj", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| # rotary embeds should be sliced | ||
| if "wk" in modules: | ||
| loaded_weight = permute(loaded_weight, self.config.num_key_value_heads) | ||
| elif "wq" in modules: | ||
| loaded_weight = permute(loaded_weight, self.config.num_attention_heads) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Guard rotary permutation for quantized q/k tensors
The Mistral remapping helper permutes every tensor whose name contains wk or wq (lines 614‑618), regardless of whether it is an actual weight matrix. Quantized checkpoints (e.g., GPTQ/AWQ) include tensors like wk.qweight, wk.qzeros, or 1‑element wk.qscale_weight, and reshaping those to [num_heads, …, hidden_size] will either raise a view error or corrupt the loaded parameters. The Llama implementation gates permutation to .weight/qscale_weight tensors; a similar suffix check is needed here so quantized Jais2 weights in Mistral format load correctly.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for contirbution, some NIT comments
| "embed_tokens": "input_embeddings", | ||
| "lm_head": "output_embeddings", | ||
| } | ||
| embedding_padding_modules = ["lm_head"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have removed this attribute recently.
| embedding_padding_modules = ["lm_head"] | ||
|
|
||
| # BitandBytes specific attributes | ||
| bitsandbytes_stacked_params_mapping = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bitsandbytes_stacked_params_mapping has been removed for a long time
| "v_proj": ("qkv_proj", 2), | ||
| } | ||
|
|
||
| mistral_mapping = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just wondering, do we really need this?
Purpose
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.