-
Notifications
You must be signed in to change notification settings - Fork 1.1k
tests: benchdnn: graph: support validation for SDPA with non-contiguous strides in mask add #4398
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| , ctx_exe(ctx_exe) | ||
| , impl_filter(impl_filter) { | ||
| repro = set_repro_line(); // must be last in ctor to collect right info | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please extend the existing one instead, it will require just one modification with it.
It would be also nice if binary driver gets this capability, too, through its interface.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the suggestion, updated! Could you pls take another look?
with non-contiguous strides in mask add
0bb3060 to
9902832
Compare
|
|
||
| logical_tensor::dims src_strides = base_op_ref.in_lts_[0].stride_; | ||
| logical_tensor::dims wei_strides = base_op_ref.in_lts_[1].stride_; | ||
| const logical_tensor::dims &dst_strides = base_op_ref.out_lts_[0].stride_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it on purpose to only have dst_strides in const ref, while copy for src_strides and wei_strides?
| --reset --op-attrs=0:transpose_b:1 --case=complex_fusion/mha/gemma2-bf16-f32.json | ||
|
|
||
| # sdpa with non-contiguous strides in mask add | ||
| --reset --in-shapes=5:1x1x384x384*147840x147840x385x1 --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just for my curiosity, could you please paste here the verbose logs of this case? Thanks.
Description
MFDNN-14395 reported a correctness issue for SDPA when the mask add is having a non-contiguous strides (shape: 1x1x1024x1024, strides:1049600x1049600x1025x1). Currently benchdnn graph doesn't support such case's validation (segfault). This PR aims to fix this.