-
Notifications
You must be signed in to change notification settings - Fork 2.9k
[CPU][RISCV64] Implement CPU Plugin JIT emitter for SoftPlus activation #33144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
bf5b093
359b362
e59d93e
c4ba973
2daa0f1
1710de0
8e9fc33
0616c92
d886e6c
6c38ae9
141dde5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -947,6 +947,143 @@ void jit_softsign_emitter::register_table_entries() { | |
| push_arg_entry_of("one", CONST_1_F); | ||
| } | ||
|
|
||
| /// SoftPlus /// | ||
| jit_softplus_emitter::jit_softplus_emitter(ov::intel_cpu::riscv64::jit_generator_t* host, | ||
| ov::intel_cpu::riscv64::cpu_isa_t host_isa, | ||
| ov::element::Type exec_prc) | ||
| : jit_emitter(host, host_isa, exec_prc) { | ||
| prepare_table(); | ||
| exp_emitter = std::make_unique<jit_exp_emitter>(h, host_isa, exec_prc); | ||
| } | ||
|
|
||
| jit_softplus_emitter::jit_softplus_emitter(ov::intel_cpu::riscv64::jit_generator_t* host, | ||
| ov::intel_cpu::riscv64::cpu_isa_t host_isa, | ||
| [[maybe_unused]] const std::shared_ptr<ov::Node>& node, | ||
| ov::element::Type exec_prc) | ||
| : jit_emitter(host, host_isa, exec_prc) { | ||
| prepare_table(); | ||
| exp_emitter = std::make_unique<jit_exp_emitter>(h, host_isa, exec_prc); | ||
| } | ||
|
|
||
| size_t jit_softplus_emitter::get_inputs_num() const { | ||
| return 1; | ||
| } | ||
|
|
||
| size_t jit_softplus_emitter::aux_gprs_count() const { | ||
| return std::max<size_t>(exp_emitter->aux_gprs_count(), 1) + 1; | ||
| } | ||
|
|
||
| size_t jit_softplus_emitter::aux_vecs_count() const { | ||
| return std::max<size_t>(exp_emitter->aux_vecs_count() + 2, 5); | ||
| } | ||
|
|
||
| size_t jit_softplus_emitter::aux_fp_gprs_count() const { | ||
| return std::max<size_t>(exp_emitter->aux_fp_gprs_count(), 1); | ||
| } | ||
|
|
||
| void jit_softplus_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, | ||
| const std::vector<size_t>& out_vec_idxs) const { | ||
| if (host_isa_ == ov::intel_cpu::riscv64::cpu_isa_t::gv) { | ||
| emit_isa<ov::intel_cpu::riscv64::cpu_isa_t::gv>(in_vec_idxs, out_vec_idxs); | ||
| } else { | ||
| OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel"); | ||
| } | ||
| } | ||
|
|
||
| template <ov::intel_cpu::riscv64::cpu_isa_t isa> | ||
| void jit_softplus_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, | ||
| const std::vector<size_t>& out_vec_idxs) const { | ||
| OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "Unsupported precision: ", exec_prc_); | ||
|
|
||
| // SoftPlus: softplus(x) = ln(1 + exp(x)) | ||
| // For x >= 0: softplus(x) = x + ln(1 + exp(-x)) to avoid overflow | ||
| // For x < 0: softplus(x) = ln(1 + exp(x)) | ||
| // For x > 20: softplus(x) ≈ x | ||
|
|
||
| auto src = VReg(in_vec_idxs[0]); | ||
| auto dst = VReg(out_vec_idxs[0]); | ||
|
|
||
| auto aux0 = VReg(aux_vec_idxs[0]); | ||
| auto aux1 = VReg(aux_vec_idxs[1]); | ||
| auto aux2 = VReg(aux_vec_idxs[2]); | ||
| auto aux3 = VReg(aux_vec_idxs[3]); | ||
| auto aux4 = VReg(aux_vec_idxs[4]); | ||
|
|
||
| auto fp0 = FReg(aux_fp_gpr_idxs[0]); | ||
| auto tmp = Reg(aux_gpr_idxs[0]); | ||
|
|
||
| h->vmv_v_v(aux0, src); | ||
|
|
||
| load_table_val("large_threshold", fp0); | ||
| h->vmfgt_vf(mask_vreg(), aux0, fp0); | ||
| h->vmv_v_v(dst, aux0, VM::masked); | ||
|
|
||
| h->fmv_w_x(fp0, zero); | ||
| h->vmfge_vf(mask_vreg(), aux0, fp0); | ||
| h->vfneg_vv(aux1, aux0, VM::masked); | ||
|
|
||
| h->vmv_v_v(aux1, aux0, VM::unmasked); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems this particular operation will override the result that is stored in |
||
|
|
||
| auto exp_aux_vec_idxs = aux_vec_idxs; | ||
| exp_aux_vec_idxs.erase( | ||
| std::find(exp_aux_vec_idxs.begin(), exp_aux_vec_idxs.end(), static_cast<size_t>(aux0.getIdx()))); | ||
| exp_aux_vec_idxs.erase( | ||
| std::find(exp_aux_vec_idxs.begin(), exp_aux_vec_idxs.end(), static_cast<size_t>(aux1.getIdx()))); | ||
|
Comment on lines
+1027
to
+1031
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe it makes more sense to create the |
||
|
|
||
| exp_emitter->emit_code({static_cast<size_t>(aux1.getIdx())}, | ||
| {static_cast<size_t>(aux1.getIdx())}, | ||
| exp_aux_vec_idxs, | ||
| aux_gpr_idxs, | ||
| aux_fp_gpr_idxs); | ||
|
|
||
| load_table_val("one", fp0); | ||
| h->vfadd_vf(aux2, aux1, fp0); | ||
|
|
||
| load_table_val("log_taylor_threshold", fp0); | ||
| h->vmflt_vf(mask_vreg(), aux1, fp0); | ||
|
|
||
| h->vfmul_vv(aux3, aux1, aux1); | ||
| load_table_val("half", fp0); | ||
| h->vfmul_vf(aux3, aux3, fp0); | ||
| h->vfsub_vv(aux4, aux1, aux3, VM::masked); | ||
|
|
||
| h->vfmv_v_v(aux3, aux2); | ||
|
|
||
| load_table_val("log_c1", fp0); | ||
| h->vfmv_v_f(dst, fp0); | ||
|
|
||
| load_table_val("log_c2", fp0); | ||
| h->vfmadd_vf(dst, fp0, aux2, VM::andnot); | ||
|
|
||
| h->vmerge_vvm(dst, dst, aux4); | ||
|
|
||
| h->fmv_w_x(fp0, zero); | ||
| h->vmfge_vf(mask_vreg(), aux0, fp0); | ||
| h->vfadd_vv(dst, dst, aux0, VM::masked); | ||
|
|
||
| load_table_val("large_threshold", fp0); | ||
| h->vmfle_vf(mask_vreg(), aux0, fp0); | ||
| } | ||
|
|
||
| std::set<std::vector<element::Type>> jit_softplus_emitter::get_supported_precisions( | ||
| [[maybe_unused]] const std::shared_ptr<ov::Node>& node) { | ||
| return {{element::f32}}; | ||
| } | ||
|
|
||
| void jit_softplus_emitter::register_table_entries() { | ||
| push_arg_entry_of("one", CONST_1_F); | ||
| push_arg_entry_of("half", 0x3f000000); | ||
| push_arg_entry_of("large_threshold", 0x41a00000); | ||
| push_arg_entry_of("log_taylor_threshold", 0x3f000000); | ||
| push_arg_entry_of("log_c1", 0x3f800000); | ||
| push_arg_entry_of("log_c2", 0x3f000000); | ||
| } | ||
|
|
||
| void jit_softplus_emitter::emit_data() const { | ||
| jit_emitter::emit_data(); | ||
| exp_emitter->emit_data(); | ||
| } | ||
|
|
||
| /// Greater /// | ||
| jit_greater_emitter::jit_greater_emitter(jit_generator_t* host, cpu_isa_t host_isa, const element::Type exec_prc) | ||
| : jit_emitter(host, host_isa, exec_prc) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please, remove, since
tmpis not actually used in your code at all