@@ -989,6 +989,67 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
989989 return status::success;
990990}
991991
992+ status_t init_conf (brgemm_matmul_conf_t &conf, dim_t batch, dim_t M, dim_t K,
993+ dim_t N, dim_t in_ld, dim_t n_blk, data_type_t in_type,
994+ data_type_t out_type, format_tag_t in_tag) {
995+ if (n_blk <= 0 && M <= 0 ) return status::invalid_arguments;
996+
997+ const auto vnni_granularity = data_type_vnni_granularity (out_type);
998+ if (vnni_granularity <= 0 ) return status::invalid_arguments;
999+
1000+ // Zero initialize the `conf` to avoid access to 'garbage' in members.
1001+ conf = brgemm_matmul_conf_t ();
1002+
1003+ const bool is_bf16 = one_of (in_type, bf16 ) || one_of (out_type, bf16 );
1004+ const bool is_s8u8 = one_of (in_type, s8, u8 ) || one_of (out_type, s8, u8 );
1005+
1006+ VCONDCHECK_BG (!(is_bf16 || is_s8u8), VERBOSE_UNSUPPORTED_DT);
1007+
1008+ const bool is_copyB = N > 0 ;
1009+ conf.isa = get_max_cpu_isa (); // Just use the best ISA possible.
1010+ conf.is_bf32 = false ;
1011+ conf.batch = batch;
1012+ conf.src_dt = conf.wei_dt = out_type;
1013+ conf.orig_src_dt = conf.orig_wei_dt = in_type;
1014+ // Note: will need to change `tr_a_dt_sz` for copyA in cases where src_dt != dst_dt
1015+ conf.a_dt_sz = conf.tr_a_dt_sz = types::data_type_size (conf.src_dt );
1016+ conf.N = N;
1017+ conf.M = M;
1018+ conf.K = K;
1019+ const dim_t copyA_K_blk = isa_num_vregs (conf.isa ) / 2 ;
1020+ const dim_t copyB_K_blk = 16 * vnni_granularity;
1021+ conf.K_blk = is_copyB ? copyB_K_blk : copyA_K_blk;
1022+ conf.K_tail = conf.K % conf.K_blk ;
1023+ if (!is_copyB) {
1024+ // Note: current implementation always calls the transposed kernel.
1025+ conf.transposed_A = true ;
1026+ conf.M_blk = (dim_t )isa_max_vlen (conf.isa ) / conf.a_dt_sz ;
1027+ conf.M_tail = conf.M % conf.M_blk ;
1028+ conf.copy_A_src_stride = in_ld * conf.a_dt_sz ;
1029+ // setting LDA parameter required for plain transpose
1030+ conf.LDA = conf.K ;
1031+ } else {
1032+ conf.blocked_B = !utils::one_of (in_tag, ab, ba, abc, acb);
1033+ conf.transposed_B = utils::one_of (in_tag, ba, acb);
1034+ conf.wei_tag = in_tag;
1035+ conf.wei_n_blk = conf.N_blk = conf.LDB = n_blk;
1036+ conf.N_tail = conf.N % conf.N_blk ;
1037+ conf.b_dt_sz = types::data_type_size (in_type);
1038+ conf.tr_b_dt_sz = types::data_type_size (conf.wei_dt );
1039+ conf.copy_B_wei_stride = in_ld * conf.b_dt_sz ;
1040+ conf.N_chunk_elems = conf.N ;
1041+ conf.s8s8_comp_b_str = utils::rnd_up (conf.N , conf.wei_n_blk );
1042+ conf.s8s8_comp_n_str = conf.wei_n_blk ;
1043+ }
1044+
1045+ conf.s8s8_compensation_required = false ;
1046+ conf.src_zp_type = brgemm_broadcast_t ::none;
1047+ conf.has_zero_point_a = false ;
1048+ conf.has_zero_point_b = false ;
1049+
1050+ return status::success;
1051+ }
1052+
9921053void init_aux_values (brgemm_matmul_conf_t &bgmmc,
9931054 const memory_desc_wrapper &src_d, const memory_desc_wrapper &wei_d,
9941055 const memory_desc_wrapper &dst_d) {
0 commit comments