graph: backend: dnnl, tests: benchdnn: support sdpa / gqa training with gradients for mask #4404
base: main
Conversation
w.r.t gradients for mask
for sdpa/gqa training backward w.r.t gradients for mask
        data_type::s8, data_type::u8, data_type::s32,
-       data_type::undef}))
+       data_type::undef})
+               .set_shape_inference_function(infer_dummy_output_shape))
Why do we need this? By definition, there is no output for an End op.
The graph's shape inference pass iterates over every op in the graph. We could certainly check whether an op has outputs to avoid registering a dummy infer-shape function for the End op here, but I don't think that is the better choice: with that approach, every op would have to perform that check.
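To make the trade-off concrete, here is a minimal, self-contained sketch (hypothetical, simplified types rather than the actual oneDNN graph classes) of why registering a no-op infer_dummy_output_shape keeps the graph-level pass uniform:

    #include <string>
    #include <vector>

    // Hypothetical stand-ins for the graph IR types discussed above.
    struct op_t {
        std::string kind;
        // Every op registers a shape-inference callback, even if it
        // produces no outputs.
        void (*infer_shape)(op_t &);
    };

    // No-op callback for ops such as End that have no outputs to infer.
    void infer_dummy_output_shape(op_t &) {}

    // Because End registers the dummy callback, the graph-level pass can
    // treat every op uniformly instead of branching on "does this op have
    // outputs?" for each one.
    void infer_graph_shapes(std::vector<op_t> &ops) {
        for (auto &op : ops)
            op.infer_shape(op);
    }

    int main() {
        // For brevity both ops register the dummy callback here; a real
        // MatMul would register its own shape-inference function.
        std::vector<op_t> ops {{"MatMul", infer_dummy_output_shape},
                               {"End", infer_dummy_output_shape}};
        infer_graph_shapes(ops);
        return 0;
    }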
make test
TaoLv left a comment:
Are any changes required for the documentation? What's the main difference in the fusion graph?
    auto end_input = cur_op->get_input_value(0);
    outputs_.push_back(end_input->get_logical_tensor());
Suggested change:
-   auto end_input = cur_op->get_input_value(0);
-   outputs_.push_back(end_input->get_logical_tensor());
+   outputs_.push_back(end_input->get_input_logical_tensor(0));
        const dnnl::engine &p_engine, pd_cache_t &pd_cache,
        const fpmath_t &fpmath, bool use_block_layout,
        subgraph_rewriter_t &rewriter) {
    logical_tensor_t dst_lt = op->get_input_value(0)->get_logical_tensor();
Suggested change:
-   logical_tensor_t dst_lt = op->get_input_value(0)->get_logical_tensor();
+   logical_tensor_t dst_lt = op->get_input_logical_tensor(0);
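For context, a sketch of how such a convenience accessor could be built as a thin wrapper over the existing two-step lookup (simplified, hypothetical types; the actual helper in the codebase may have a different signature):

    #include <cstddef>
    #include <vector>

    // Hypothetical, simplified stand-ins for the IR classes in this file.
    struct logical_tensor_t {};

    struct value_t {
        logical_tensor_t lt_;
        logical_tensor_t get_logical_tensor() const { return lt_; }
    };

    struct op_t {
        std::vector<value_t *> inputs_;
        value_t *get_input_value(size_t i) const { return inputs_[i]; }
        // The accessor the suggestions rely on: one call replaces
        // get_input_value(i)->get_logical_tensor().
        logical_tensor_t get_input_logical_tensor(size_t i) const {
            return get_input_value(i)->get_logical_tensor();
        }
    };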
    auto src_md = make_dnnl_memory_desc(
            op->get_input_value(0)->get_logical_tensor());
Suggested change:
-   auto src_md = make_dnnl_memory_desc(
-           op->get_input_value(0)->get_logical_tensor());
+   auto src_md = make_dnnl_memory_desc(op->get_input_logical_tensor(0));
        pgraph->create_input_port(2, matmul_dv, 1);
        pgraph->create_input_port(2, matmul_v_do, 0);
    })
    .set_attr<FCreatePattern>("FCreatePattern",
Just curious: why a new pattern instead of incorporating this into the existing one?
    if (cur_op_refs.size() == 2 && cur_op_refs[0].kind_ == "End") {
        matmul_idx = 1;
    }
Would it be better if:

    while (cur_op_refs[matmul_idx].kind_ != "MatMul") {
        matmul_idx++;
    }
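A minimal sketch of that idea with a bounds check added, so the scan cannot run past the end if no MatMul consumer is present (simplified, hypothetical types; cur_op_refs in the real pass is an internal structure):

    #include <cstddef>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for the consumer references iterated by the pass.
    struct op_ref_t {
        std::string kind_;
    };

    // Scan for the MatMul consumer instead of assuming a fixed index; the
    // bound keeps the loop from walking past the end when no MatMul exists.
    size_t find_matmul_idx(const std::vector<op_ref_t> &cur_op_refs) {
        size_t matmul_idx = 0;
        while (matmul_idx < cur_op_refs.size()
                && cur_op_refs[matmul_idx].kind_ != "MatMul")
            ++matmul_idx;
        return matmul_idx; // equals cur_op_refs.size() when not found
    }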
Description
MFDNN-14010.