@ElaineBao
Contributor

Description

MFDNN-14010.

  • Use an End op to represent an intermediate output (a value that has consumers both inside and outside the partition).
  • Add patterns and the related implementation for SDPA/GQA training backward with gradients w.r.t. the mask.
  • Add test cases and modify benchdnn to support validating them.
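
To illustrate the first bullet, here is a minimal sketch of how an End op can mark a value as a partition output even though that value also has consumers inside the partition. It assumes simplified, hypothetical types (`toy_node_t`, integer tensor ids, `collect_end_outputs`) and is not the actual oneDNN graph implementation.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a graph op; tensors are identified by plain ids.
struct toy_node_t {
    std::string kind;
    std::vector<std::size_t> input_ids; // tensors this op consumes
    std::vector<std::size_t> output_ids; // tensors this op produces
};

// Collect the tensors that must be exposed as partition outputs because an
// End op consumes them. Internal consumers of the same tensor are unaffected.
inline std::vector<std::size_t> collect_end_outputs(
        const std::vector<toy_node_t> &ops) {
    std::vector<std::size_t> outputs;
    for (const auto &op : ops) {
        if (op.kind != "End" || op.input_ids.empty()) continue;
        outputs.push_back(op.input_ids[0]); // End has exactly one input here
    }
    return outputs;
}
```

In this toy model a tensor consumed by both a MatMul and an End op is still used internally, and the End op is the only thing that tells the partition to expose it.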

@ElaineBao ElaineBao self-assigned this Dec 2, 2025
@ElaineBao ElaineBao requested a review from a team as a code owner December 2, 2025 06:33
@ElaineBao ElaineBao added the component:graph-api Codeowner: @oneapi-src/onednn-graph label Dec 2, 2025
@github-actions github-actions bot added the component:tests Codeowner: @oneapi-src/onednn-arch label Dec 2, 2025
  data_type::s8, data_type::u8, data_type::s32,
- data_type::undef}))
+ data_type::undef})
+ .set_shape_inference_function(infer_dummy_output_shape))
Contributor

Why do we need this? By definition, there is no output for an End op.

Contributor Author

The graph's shape inference function iterates over every op in the graph. We could check whether an op has any outputs and so avoid adding a dummy shape inference function for the End op here, but I don't think that is the better choice, because every op would then have to go through that check.
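
The trade-off described above can be sketched as follows. This is a minimal illustration with hypothetical, simplified types (`toy_op_t`, `infer_identity_shape`, `infer_dummy_shape`, `infer_all_shapes`), not the oneDNN schema or graph classes: registering a no-op function for End keeps the graph-level loop free of any "does this op have outputs?" branch.

```cpp
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for an op with a registered shape inference function.
struct toy_op_t {
    std::string kind;
    std::vector<std::vector<int64_t>> input_shapes;
    std::vector<std::vector<int64_t>> output_shapes; // stays empty for End
    std::function<bool(toy_op_t &)> infer_shape;
};

// Element-wise style inference: every output takes the first input's shape.
inline bool infer_identity_shape(toy_op_t &op) {
    if (op.input_shapes.empty()) return false;
    for (auto &out : op.output_shapes)
        out = op.input_shapes[0];
    return true;
}

// End has no outputs, so there is nothing to infer.
inline bool infer_dummy_shape(toy_op_t &) {
    return true;
}

// The graph-level pass treats every op uniformly.
inline bool infer_all_shapes(std::vector<toy_op_t> &ops) {
    for (auto &op : ops) {
        if (!op.infer_shape || !op.infer_shape(op)) return false;
    }
    return true;
}
```

The alternative would be a check inside `infer_all_shapes` that skips ops without outputs, which is the per-op condition the reply argues against.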

@ElaineBao
Contributor Author

make test
set test_scope=NIGHTLY
disable benchdnn_all
enable benchdnn_graph

Contributor

@TaoLv TaoLv left a comment

Are any documentation changes required? What's the main difference in the fusion graph?

Comment on lines +68 to +69
auto end_input = cur_op->get_input_value(0);
outputs_.push_back(end_input->get_logical_tensor());
Contributor

Suggested change
auto end_input = cur_op->get_input_value(0);
- outputs_.push_back(end_input->get_logical_tensor());
+ outputs_.push_back(end_input->get_input_logical_tensor(0));

const dnnl::engine &p_engine, pd_cache_t &pd_cache,
const fpmath_t &fpmath, bool use_block_layout,
subgraph_rewriter_t &rewriter) {
logical_tensor_t dst_lt = op->get_input_value(0)->get_logical_tensor();
Contributor

Suggested change
- logical_tensor_t dst_lt = op->get_input_value(0)->get_logical_tensor();
+ logical_tensor_t dst_lt = op->get_input_logical_tensor(0);

Comment on lines +1679 to +1680
auto src_md = make_dnnl_memory_desc(
op->get_input_value(0)->get_logical_tensor());
Contributor

Suggested change
- auto src_md = make_dnnl_memory_desc(
- op->get_input_value(0)->get_logical_tensor());
+ auto src_md = make_dnnl_memory_desc(op->get_input_logical_tensor(0));

pgraph->create_input_port(2, matmul_dv, 1);
pgraph->create_input_port(2, matmul_v_do, 0);
})
.set_attr<FCreatePattern>("FCreatePattern",
Contributor

Just curious: why add a new pattern instead of incorporating this into the existing one?

Comment on lines +829 to +831
if (cur_op_refs.size() == 2 && cur_op_refs[0].kind_ == "End") {
matmul_idx = 1;
}
Contributor

Would it be better if:

while (cur_op_refs[matmul_idx].kind_ != "MatMul") {
    matmul_idx++;
}
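
One caveat with the loop above is that it runs past the end of the container if no MatMul entry exists. A bounded variant could look like the sketch below. It assumes a hypothetical `op_ref_sketch_t` type that only mirrors the `kind_` field seen in the snippet, plus a hypothetical helper `find_matmul_idx`; the real benchdnn type is not shown in this thread.

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <string>
#include <vector>

// Hypothetical stand-in; only the kind_ field is taken from the snippet above.
struct op_ref_sketch_t {
    std::string kind_;
};

// Returns the index of the first "MatMul" entry, or refs.size() if none exists,
// so the caller can detect the "not found" case instead of reading out of bounds.
inline std::size_t find_matmul_idx(const std::vector<op_ref_sketch_t> &refs) {
    auto it = std::find_if(refs.begin(), refs.end(),
            [](const op_ref_sketch_t &r) { return r.kind_ == "MatMul"; });
    return static_cast<std::size_t>(std::distance(refs.begin(), it));
}
```

A caller would compare the result against `refs.size()` before indexing with it.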
