Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,143 @@ void jit_softsign_emitter::register_table_entries() {
push_arg_entry_of("one", CONST_1_F);
}

/// SoftPlus ///
jit_softplus_emitter::jit_softplus_emitter(ov::intel_cpu::riscv64::jit_generator_t* host,
ov::intel_cpu::riscv64::cpu_isa_t host_isa,
ov::element::Type exec_prc)
: jit_emitter(host, host_isa, exec_prc) {
prepare_table();
exp_emitter = std::make_unique<jit_exp_emitter>(h, host_isa, exec_prc);
}

jit_softplus_emitter::jit_softplus_emitter(ov::intel_cpu::riscv64::jit_generator_t* host,
ov::intel_cpu::riscv64::cpu_isa_t host_isa,
[[maybe_unused]] const std::shared_ptr<ov::Node>& node,
ov::element::Type exec_prc)
: jit_emitter(host, host_isa, exec_prc) {
prepare_table();
exp_emitter = std::make_unique<jit_exp_emitter>(h, host_isa, exec_prc);
}

size_t jit_softplus_emitter::get_inputs_num() const {
return 1;
}

size_t jit_softplus_emitter::aux_gprs_count() const {
return std::max<size_t>(exp_emitter->aux_gprs_count(), 1) + 1;
}

size_t jit_softplus_emitter::aux_vecs_count() const {
return std::max<size_t>(exp_emitter->aux_vecs_count() + 2, 5);
}

size_t jit_softplus_emitter::aux_fp_gprs_count() const {
return std::max<size_t>(exp_emitter->aux_fp_gprs_count(), 1);
}

void jit_softplus_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
const std::vector<size_t>& out_vec_idxs) const {
if (host_isa_ == ov::intel_cpu::riscv64::cpu_isa_t::gv) {
emit_isa<ov::intel_cpu::riscv64::cpu_isa_t::gv>(in_vec_idxs, out_vec_idxs);
} else {
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

template <ov::intel_cpu::riscv64::cpu_isa_t isa>
void jit_softplus_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
const std::vector<size_t>& out_vec_idxs) const {
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "Unsupported precision: ", exec_prc_);

// SoftPlus: softplus(x) = ln(1 + exp(x))
// For x >= 0: softplus(x) = x + ln(1 + exp(-x)) to avoid overflow
// For x < 0: softplus(x) = ln(1 + exp(x))
// For x > 20: softplus(x) ≈ x

auto src = VReg(in_vec_idxs[0]);
auto dst = VReg(out_vec_idxs[0]);

auto aux0 = VReg(aux_vec_idxs[0]);
auto aux1 = VReg(aux_vec_idxs[1]);
auto aux2 = VReg(aux_vec_idxs[2]);
auto aux3 = VReg(aux_vec_idxs[3]);
auto aux4 = VReg(aux_vec_idxs[4]);

auto fp0 = FReg(aux_fp_gpr_idxs[0]);
auto tmp = Reg(aux_gpr_idxs[0]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please, remove, since tmp is not actually used in your code at all


h->vmv_v_v(aux0, src);

load_table_val("large_threshold", fp0);
h->vmfgt_vf(mask_vreg(), aux0, fp0);
h->vmv_v_v(dst, aux0, VM::masked);

h->fmv_w_x(fp0, zero);
h->vmfge_vf(mask_vreg(), aux0, fp0);
h->vfneg_vv(aux1, aux0, VM::masked);

h->vmv_v_v(aux1, aux0, VM::unmasked);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems this particular operation will override the result that is stored in aux1 at line 1023, isn't it?


auto exp_aux_vec_idxs = aux_vec_idxs;
exp_aux_vec_idxs.erase(
std::find(exp_aux_vec_idxs.begin(), exp_aux_vec_idxs.end(), static_cast<size_t>(aux0.getIdx())));
exp_aux_vec_idxs.erase(
std::find(exp_aux_vec_idxs.begin(), exp_aux_vec_idxs.end(), static_cast<size_t>(aux1.getIdx())));
Comment on lines +1027 to +1031
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it makes more sense to create the exp_aux_vec_idxs vector as a slice of aux_vec_idxs starting from index 2?


exp_emitter->emit_code({static_cast<size_t>(aux1.getIdx())},
{static_cast<size_t>(aux1.getIdx())},
exp_aux_vec_idxs,
aux_gpr_idxs,
aux_fp_gpr_idxs);

load_table_val("one", fp0);
h->vfadd_vf(aux2, aux1, fp0);

load_table_val("log_taylor_threshold", fp0);
h->vmflt_vf(mask_vreg(), aux1, fp0);

h->vfmul_vv(aux3, aux1, aux1);
load_table_val("half", fp0);
h->vfmul_vf(aux3, aux3, fp0);
h->vfsub_vv(aux4, aux1, aux3, VM::masked);

h->vfmv_v_v(aux3, aux2);

load_table_val("log_c1", fp0);
h->vfmv_v_f(dst, fp0);

load_table_val("log_c2", fp0);
h->vfmadd_vf(dst, fp0, aux2, VM::andnot);

h->vmerge_vvm(dst, dst, aux4);

h->fmv_w_x(fp0, zero);
h->vmfge_vf(mask_vreg(), aux0, fp0);
h->vfadd_vv(dst, dst, aux0, VM::masked);

load_table_val("large_threshold", fp0);
h->vmfle_vf(mask_vreg(), aux0, fp0);
}

std::set<std::vector<element::Type>> jit_softplus_emitter::get_supported_precisions(
[[maybe_unused]] const std::shared_ptr<ov::Node>& node) {
return {{element::f32}};
}

void jit_softplus_emitter::register_table_entries() {
push_arg_entry_of("one", CONST_1_F);
push_arg_entry_of("half", 0x3f000000);
push_arg_entry_of("large_threshold", 0x41a00000);
push_arg_entry_of("log_taylor_threshold", 0x3f000000);
push_arg_entry_of("log_c1", 0x3f800000);
push_arg_entry_of("log_c2", 0x3f000000);
}

void jit_softplus_emitter::emit_data() const {
jit_emitter::emit_data();
exp_emitter->emit_data();
}

/// Greater ///
jit_greater_emitter::jit_greater_emitter(jit_generator_t* host, cpu_isa_t host_isa, const element::Type exec_prc)
: jit_emitter(host, host_isa, exec_prc) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,38 @@ class jit_softsign_emitter : public jit_emitter {

void register_table_entries() override;
};
///SoftPlus///
class jit_softplus_emitter : public jit_emitter {
public:
jit_softplus_emitter(ov::intel_cpu::riscv64::jit_generator_t* host,
ov::intel_cpu::riscv64::cpu_isa_t host_isa,
ov::element::Type exec_prc = ov::element::f32);

jit_softplus_emitter(ov::intel_cpu::riscv64::jit_generator_t* host,
ov::intel_cpu::riscv64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node,
ov::element::Type exec_prc = ov::element::f32);

size_t get_inputs_num() const override;
size_t aux_vecs_count() const override;
size_t aux_gprs_count() const override;
size_t aux_fp_gprs_count() const override;

void emit_data() const override;

static std::set<std::vector<element::Type>> get_supported_precisions(
const std::shared_ptr<ov::Node>& node = nullptr);

private:
std::unique_ptr<jit_exp_emitter> exp_emitter{nullptr};

void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;

template <ov::intel_cpu::riscv64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;

void register_table_entries() override;
};
/// Greater///
class jit_greater_emitter : public jit_emitter {
public:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,8 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
OV_CASE(Algorithm::EltwiseSqrt, jit_sqrt_emitter),
OV_CASE(Algorithm::EltwiseSquaredDifference, jit_squared_difference_emitter),
OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter),
OV_CASE(Algorithm::EltwiseTanh, jit_tanh_emitter));
OV_CASE(Algorithm::EltwiseTanh, jit_tanh_emitter),
OV_CASE(Algorithm::EltwiseSoftPlus, jit_softplus_emitter));

OPENVINO_ASSERT(ctx.emitter, "Unsupported operation type '" + algToString(data.algo) + "' for Eltwise emitter");

Expand Down Expand Up @@ -684,7 +685,8 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_CASE(Algorithm::EltwiseSqrt, jit_sqrt_emitter),
OV_CASE(Algorithm::EltwiseSquaredDifference, jit_squared_difference_emitter),
OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter),
OV_CASE(Algorithm::EltwiseTanh, jit_tanh_emitter));
OV_CASE(Algorithm::EltwiseTanh, jit_tanh_emitter),
OV_CASE(Algorithm::EltwiseSoftPlus, jit_softplus_emitter));

OPENVINO_ASSERT(!precisions.empty(), "Unsupported operation type for Eltwise emitter");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType
(activation_type == utils::ActivationTypes::Sigmoid) ||
(activation_type == utils::ActivationTypes::SoftSign) ||
(activation_type == utils::ActivationTypes::Sqrt) ||
(activation_type == utils::ActivationTypes::Tanh))
(activation_type == utils::ActivationTypes::Tanh) ||
(activation_type == utils::ActivationTypes::SoftPlus))
return "jit";
}
#if defined(OV_CPU_WITH_SHL)
Expand Down
Loading