
Commit 81a50ce

cpu: aarch64: implement brgemm ukernel api
1 parent c8d998d

17 files changed, +1249 −13 lines

src/cpu/aarch64/brgemm/brgemm_types.hpp

Lines changed: 17 additions & 2 deletions
@@ -215,7 +215,7 @@ struct brgemm_desc_t {
     dim_t stride_b = 0;
 
     brgemm_layout_t layout = brgemm_layout_undef;
-    brgemm_batch_kind_t type;
+    brgemm_batch_kind_t type = brgemm_batch_kind_t::brgemm_addr;
     bool is_dgmm = false; // set to true in brdgmm_desc_init
     bool with_sum = false;
     bool req_cal_comp_pads = false;
@@ -292,7 +292,22 @@ struct brgemm_desc_t {
         return sz;
     }
 
-    bool is_b_data_layout_vnni() { return true; }
+    // A class version of the `static` version of the function.
+    // Note: used in benchdnn only, not used inside the library.
+    bool is_b_data_layout_vnni() const { return is_b_data_layout_vnni(dt_b); }
+
+    static bool is_b_data_layout_vnni(data_type_t dt_b) {
+        using namespace data_type;
+        return utils::one_of(dt_b, s8, u8, bf16);
+    }
+
+    bool are_post_ops_applicable() const {
+        const bool has_zero_points = !utils::everyone_is(
+                brgemm_broadcast_t::none, zp_type_a, zp_type_b, zp_type_c);
+        return dt_c != dt_d || with_eltwise || with_binary || with_bias
+                || with_sum || req_s8s8_compensation || has_zero_points
+                || with_scales || with_dst_scales;
+    }
 
     bool operator==(const brgemm_desc_t &rhs) const;
     bool operator<(const brgemm_desc_t &rhs) const;

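For context, the two helpers added to brgemm_desc_t above are introspection utilities: is_b_data_layout_vnni() reports whether the B matrix is expected in VNNI-packed layout for a given data type, and are_post_ops_applicable() reports whether any epilogue work is needed. A minimal sketch of a caller, assuming an already-initialized descriptor (the helper function below is illustrative, not part of this commit):

    #include "cpu/aarch64/brgemm/brgemm_types.hpp"

    using namespace dnnl::impl::cpu::aarch64;

    // Hypothetical helper: query the new brgemm_desc_t introspection methods.
    void inspect_desc(const brgemm_desc_t &desc) {
        // Static form: depends only on the B data type (s8/u8/bf16 -> VNNI).
        const bool b_is_vnni = brgemm_desc_t::is_b_data_layout_vnni(desc.dt_b);
        // Member form wraps the static one; per the comment it is meant for benchdnn.
        const bool b_is_vnni_member = desc.is_b_data_layout_vnni();
        // True when bias, eltwise/binary/sum post-ops, scales, zero points,
        // s8s8 compensation, or a dt_c != dt_d conversion must run in the epilogue.
        const bool has_epilogue = desc.are_post_ops_applicable();
        (void)b_is_vnni; (void)b_is_vnni_member; (void)has_epilogue;
    }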
src/cpu/aarch64/brgemm/jit_brgemm_kernel.cpp

Lines changed: 8 additions & 6 deletions
@@ -149,11 +149,12 @@ struct jit_brgemm_kernel_t : public jit_generator_t {
     const XReg reg_a_offset = x2;
     const XReg reg_b_offset = x6;
 
-    const XReg reg_aux1_batch = x5;
-    const XReg reg_aux1_A = x5;
-    const XReg reg_aux1_B = x7; //from jit_generator.hpp in x64
+    const XReg reg_aux1_A = x4;
+    const XReg reg_aux1_batch = reg_aux1_A;
 
-    const XReg reg_offs_batch = reg_aux1_A;
+    const XReg reg_aux1_B = x7; //from jit_generator_t.hpp in x64
+
+    const XReg reg_offs_batch = x5;
     const XReg reg_strd_batch = reg_rdb_loop;
 
     const XReg reg_bias = reg_rdb_loop;
@@ -1232,8 +1233,9 @@ void jit_brgemm_kernel_t::set_A_B_matrices() {
         add(reg_aux_A, reg_aux_A, X_TMP_0);
         ldr(X_TMP_1, ptr(reg_offs_batch, GET_OFF_BATCH_ELEMENT(offset.B)));
         add(reg_aux_B, reg_aux_B, X_TMP_1);
-        mov_imm(X_TMP_2, sizeof(brgemm_batch_element_t));
-        add(reg_offs_batch, reg_offs_batch, X_TMP_2);
+        if (brg.brgattr.max_bs > 1)
+            add_imm(reg_offs_batch, reg_offs_batch,
+                    sizeof(brgemm_batch_element_t), X_TMP_2);
     } else if (brg.type == brgemm_strd) {
         mov(reg_aux_A, reg_aux1_A);
         mov(reg_aux_B, reg_aux1_B);

src/cpu/aarch64/matmul/brgemm_matmul.cpp

Lines changed: 1 addition & 1 deletion
@@ -193,7 +193,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
 
                 int idx = get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K);
                 if (idx < 0) continue;
-                brgemm_t &brg = brg_descs_[idx];
+                brgemm_desc_t &brg = brg_descs_[idx];
                 auto LDA = i_K && bgmmc_.use_buffer_a_tail_only
                         ? (dim_t)bgmmc_.wei_k_blk
                         : bgmmc_.LDA;

src/cpu/aarch64/matmul/brgemm_matmul.hpp

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,7 @@
 /*******************************************************************************
 * Copyright 2021 Intel Corporation
 * Copyright 2024 FUJITSU LIMITED
+* Copyright 2025 Arm Ltd. and affiliates
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
@@ -53,13 +54,15 @@ struct brgemm_matmul_t : public primitive_t {
         status_t init(engine_t *engine);
         int get_brg_kernel_idx(bool is_bs_tail, bool do_initialization,
                 int m_ker_idx, bool is_N_tail, bool is_K_tail) const;
-        const brgemm_t &get_brg_desc(int idx) const { return brg_descs_[idx]; }
+        const brgemm_desc_t &get_brg_desc(int idx) const {
+            return brg_descs_[idx];
+        }
         const brgemm_matmul_conf_t &get_brgemm_matmul_conf() const {
             return bgmmc_;
         }
 
     private:
-        brgemm_t brg_descs_[max_num_brg_kernels_matmul];
+        brgemm_desc_t brg_descs_[max_num_brg_kernels_matmul];
         brgemm_matmul_conf_t bgmmc_;
     };
 

src/cpu/aarch64/matmul/brgemm_matmul_copy_utils.cpp

Lines changed: 3 additions & 1 deletion
@@ -21,6 +21,7 @@
 #include "common/type_helpers.hpp"
 #include "common/utils.hpp"
 #include "cpu/aarch64/jit_generator.hpp"
+#include "xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_reg.h"
 
 #include "cpu/aarch64/matmul/brgemm_matmul_copy_utils.hpp"
 
@@ -586,7 +587,7 @@ void jit_brgemm_matmul_copy_b_f32_t::copy_16_8_x_n_block(
             continue;
         }
 
-        const opmask_t curr_msk = zero_padding < n_blk_step ? kTail : kFFFF;
+        const opmask_t curr_msk = zero_padding < n_blk_step ? kTail : P_ALL_ONE;
         const int blk_idx = iter % max_regs_available;
         load(blk_idx, k, n, curr_msk);
         add_imm(X_DEFAULT_ADDR, reg_tr_src, tr_src_off, X_TMP_0);
@@ -621,6 +622,7 @@ void jit_brgemm_matmul_copy_b_f32_t::compute_k_loop(int ncolumns) {
 }
 
 void jit_brgemm_matmul_copy_b_f32_t::generate() {
+
     preamble();
     eor(zmm_zero.d, zmm_zero.d, zmm_zero.d);
     LDR_IMM(reg_src, param1, GET_OFF(src));

src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp

Lines changed: 61 additions & 0 deletions
@@ -989,6 +989,67 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
     return status::success;
 }
 
+status_t init_conf(brgemm_matmul_conf_t &conf, dim_t batch, dim_t M, dim_t K,
+        dim_t N, dim_t in_ld, dim_t n_blk, data_type_t in_type,
+        data_type_t out_type, format_tag_t in_tag) {
+    if (n_blk <= 0 && M <= 0) return status::invalid_arguments;
+
+    const auto vnni_granularity = data_type_vnni_granularity(out_type);
+    if (vnni_granularity <= 0) return status::invalid_arguments;
+
+    // Zero initialize the `conf` to avoid access to 'garbage' in members.
+    conf = brgemm_matmul_conf_t();
+
+    const bool is_bf16 = one_of(in_type, bf16) || one_of(out_type, bf16);
+    const bool is_s8u8 = one_of(in_type, s8, u8) || one_of(out_type, s8, u8);
+
+    VCONDCHECK_BG(!(is_bf16 || is_s8u8), VERBOSE_UNSUPPORTED_DT);
+
+    const bool is_copyB = N > 0;
+    conf.isa = get_max_cpu_isa(); // Just use the best ISA possible.
+    conf.is_bf32 = false;
+    conf.batch = batch;
+    conf.src_dt = conf.wei_dt = out_type;
+    conf.orig_src_dt = conf.orig_wei_dt = in_type;
+    // Note: will need to change `tr_a_dt_sz` for copyA in cases where src_dt != dst_dt
+    conf.a_dt_sz = conf.tr_a_dt_sz = types::data_type_size(conf.src_dt);
+    conf.N = N;
+    conf.M = M;
+    conf.K = K;
+    const dim_t copyA_K_blk = isa_num_vregs(conf.isa) / 2;
+    const dim_t copyB_K_blk = 16 * vnni_granularity;
+    conf.K_blk = is_copyB ? copyB_K_blk : copyA_K_blk;
+    conf.K_tail = conf.K % conf.K_blk;
+    if (!is_copyB) {
+        // Note: current implementation always calls the transposed kernel.
+        conf.transposed_A = true;
+        conf.M_blk = (dim_t)isa_max_vlen(conf.isa) / conf.a_dt_sz;
+        conf.M_tail = conf.M % conf.M_blk;
+        conf.copy_A_src_stride = in_ld * conf.a_dt_sz;
+        // setting LDA parameter required for plain transpose
+        conf.LDA = conf.K;
+    } else {
+        conf.blocked_B = !utils::one_of(in_tag, ab, ba, abc, acb);
+        conf.transposed_B = utils::one_of(in_tag, ba, acb);
+        conf.wei_tag = in_tag;
+        conf.wei_n_blk = conf.N_blk = conf.LDB = n_blk;
+        conf.N_tail = conf.N % conf.N_blk;
+        conf.b_dt_sz = types::data_type_size(in_type);
+        conf.tr_b_dt_sz = types::data_type_size(conf.wei_dt);
+        conf.copy_B_wei_stride = in_ld * conf.b_dt_sz;
+        conf.N_chunk_elems = conf.N;
+        conf.s8s8_comp_b_str = utils::rnd_up(conf.N, conf.wei_n_blk);
+        conf.s8s8_comp_n_str = conf.wei_n_blk;
+    }
+
+    conf.s8s8_compensation_required = false;
+    conf.src_zp_type = brgemm_broadcast_t::none;
+    conf.has_zero_point_a = false;
+    conf.has_zero_point_b = false;
+
+    return status::success;
+}
+
 void init_aux_values(brgemm_matmul_conf_t &bgmmc,
         const memory_desc_wrapper &src_d, const memory_desc_wrapper &wei_d,
         const memory_desc_wrapper &dst_d) {

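A hedged sketch of how the new init_conf() helper might be called to configure the standalone copy (transform) kernels; only the signature and the N > 0 copy-B vs. copy-A split come from this diff, while the f32 shapes, strides, tags, and the driver function below are illustrative:

    #include "common/utils.hpp"
    #include "cpu/aarch64/matmul/brgemm_matmul_utils.hpp"

    using namespace dnnl::impl;
    using namespace dnnl::impl::cpu::aarch64::matmul;

    // Illustrative driver: build copy-B and copy-A configurations.
    status_t make_copy_confs() {
        brgemm_matmul_conf_t conf_b, conf_a;

        // copy-B path: N > 0 selects the B branch; n_blk becomes
        // N_blk/LDB/wei_n_blk and in_tag describes the incoming B layout.
        CHECK(init_conf(conf_b, /*batch=*/1, /*M=*/0, /*K=*/128, /*N=*/64,
                /*in_ld=*/64, /*n_blk=*/64, data_type::f32, data_type::f32,
                format_tag::ab));

        // copy-A path: N == 0 selects the A branch; the current implementation
        // always uses the transposed copy kernel and sets LDA = K.
        CHECK(init_conf(conf_a, /*batch=*/1, /*M=*/48, /*K=*/128, /*N=*/0,
                /*in_ld=*/128, /*n_blk=*/0, data_type::f32, data_type::f32,
                format_tag::ab));

        return status::success;
    }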
src/cpu/aarch64/matmul/brgemm_matmul_utils.hpp

Lines changed: 8 additions & 0 deletions
@@ -1,6 +1,7 @@
 /*******************************************************************************
 * Copyright 2021 Intel Corporation
 * Copyright 2023-2024 FUJITSU LIMITED
+* Copyright 2025 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -121,6 +122,8 @@ struct brgemm_matmul_conf_t {
     data_type_t wei_dt;
     data_type_t acc_dt;
     data_type_t bia_dt;
+    data_type_t orig_src_dt;
+    data_type_t orig_wei_dt;
     int nthr;
     int nthr_k;
 
@@ -166,6 +169,7 @@
     bool has_zero_point_a, has_zero_point_b, has_zero_point_c;
     bool post_ops_applicable;
     bool transposed_A;
+    bool transposed_B;
     bool blocked_B;
 
     dim_t zp_a_comp_shift_n;
@@ -301,6 +305,10 @@ struct brgemm_matmul_conf_utils_t {
     const cpu_isa_t isa_;
 };
 
+status_t init_conf(brgemm_matmul_conf_t &conf, dim_t batch, dim_t M, dim_t K,
+        dim_t N, dim_t in_ld, dim_t n_blk, data_type_t in_type,
+        data_type_t out_type, format_tag_t in_tag);
+
 void init_aux_values(brgemm_matmul_conf_t &bgmmc,
         const memory_desc_wrapper &src_d, const memory_desc_wrapper &wei_d,
         const memory_desc_wrapper &dst_d);
src/cpu/aarch64/ukernel/attr_params.cpp

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+/*******************************************************************************
+* Copyright 2025 Arm Ltd. and affiliates
+* Copyright 2025 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "common/utils.hpp"
+
+#include "cpu/aarch64/ukernel/attr_params.hpp"
+
+#ifdef DNNL_EXPERIMENTAL_UKERNEL
+
+using namespace dnnl::impl;
+using namespace dnnl::impl::cpu::ukernel;
+
+status_t attr_params_t::set_post_ops_args(const void **post_ops_args) {
+    post_ops_args_ = post_ops_args;
+    return status::success;
+}
+
+status_t attr_params_t::set_scales(const void *scales, int arg) {
+    switch (arg) {
+        case DNNL_ARG_SRC: a_scales_ = scales; break;
+        case DNNL_ARG_WEIGHTS: b_scales_ = scales; break;
+        case DNNL_ARG_DST: d_scales_ = scales; break;
+        default: assert(!"unsupported arg");
+    }
+    return status::success;
+}
+
+const void *attr_params_t::get_scales(int arg) const {
+    switch (arg) {
+        case DNNL_ARG_SRC: return a_scales_;
+        case DNNL_ARG_WEIGHTS: return b_scales_;
+        case DNNL_ARG_DST: return d_scales_;
+        default: assert(!"unsupported arg");
+    }
+    return nullptr;
+}
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace aarch64 {
+namespace ukernel {
+
+status_t dnnl_ukernel_attr_params_create(attr_params_t **attr_params) {
+    *attr_params = new attr_params_t();
+    return status::success;
+}
+
+status_t dnnl_ukernel_attr_params_set_post_ops_args(
+        attr_params_t *attr_params, const void **post_ops_args) {
+    if (attr_params == nullptr) return status::invalid_arguments;
+
+    CHECK(attr_params->set_post_ops_args(post_ops_args));
+    return status::success;
+}
+
+status_t dnnl_ukernel_attr_params_set_A_scales(
+        attr_params_t *attr_params, const void *a_scales) {
+    return status::unimplemented;
+}
+
+status_t dnnl_ukernel_attr_params_set_B_scales(
+        attr_params_t *attr_params, const void *b_scales) {
+    return status::unimplemented;
+}
+
+status_t dnnl_ukernel_attr_params_set_D_scales(
+        attr_params_t *attr_params, const void *d_scales) {
+    if (attr_params == nullptr) return status::invalid_arguments;
+
+    CHECK(attr_params->set_scales(d_scales, DNNL_ARG_DST));
+    return status::success;
+}
+
+status_t dnnl_ukernel_attr_params_destroy(attr_params_t *attr_params) {
+    delete attr_params;
+    return status::success;
+}
+
+} // namespace ukernel
+} // namespace aarch64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
+
+#endif
+
+//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
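The new translation unit above provides the aarch64 side of the experimental ukernel attribute-parameters object (guarded by DNNL_EXPERIMENTAL_UKERNEL). A hedged usage sketch of the entry points it defines; only functions visible in this diff are called, while the driver function, the post-op argument contents, and the eventual execute call are assumptions:

    #include "common/utils.hpp"
    #include "cpu/aarch64/ukernel/attr_params.hpp"

    using namespace dnnl::impl;
    using namespace dnnl::impl::cpu::ukernel;
    // Declarations of the functions below are assumed visible from the ukernel headers.
    namespace aarch64_uk = dnnl::impl::cpu::aarch64::ukernel;

    // Illustrative driver, assuming DNNL_EXPERIMENTAL_UKERNEL is defined.
    status_t run_with_attr_params(const void *binary_po_arg) {
        attr_params_t *params = nullptr;
        CHECK(aarch64_uk::dnnl_ukernel_attr_params_create(&params));

        // One pointer per post-op argument, in post-op order (illustrative).
        const void *post_ops_args[] = {binary_po_arg};
        CHECK(aarch64_uk::dnnl_ukernel_attr_params_set_post_ops_args(
                params, post_ops_args));

        // A/B scale setters currently return status::unimplemented on aarch64;
        // D scales are accepted via dnnl_ukernel_attr_params_set_D_scales().
        // ... pass `params` to the brgemm ukernel execute call (not in this diff) ...

        CHECK(aarch64_uk::dnnl_ukernel_attr_params_destroy(params));
        return status::success;
    }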
