Skip to content

Commit b600647

Browse files
feat: add silu activation function (huggingface#1706)
* feat: add silu activation function
* use silu/arg in grad
* update candle-nn
* use node
1 parent 14010a8 commit b600647

File tree

14 files changed

+206
-5
lines changed

14 files changed

+206
-5
lines changed

candle-core/src/accelerate.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
380380
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
381381
}
382382

383+
#[inline]
384+
pub fn vs_exp_inplace(y: &mut [f32]) {
385+
unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
386+
}
387+
388+
#[inline]
389+
pub fn vd_exp_inplace(y: &mut [f64]) {
390+
unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
391+
}
392+
383393
#[inline]
384394
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
385395
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@@ -402,6 +412,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
402412
}
403413
}
404414

415+
#[inline]
416+
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
417+
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
418+
*y = -v
419+
}
420+
vs_exp_inplace(ys);
421+
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
422+
*y = v / (1.0 + *y)
423+
}
424+
}
425+
426+
#[inline]
427+
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
428+
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
429+
*y = -v
430+
}
431+
vd_exp_inplace(ys);
432+
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
433+
*y = v / (1.0 + *y)
434+
}
435+
}
436+
405437
macro_rules! binary_op {
406438
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
407439
#[inline]

candle-core/src/backprop.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,13 @@ impl Tensor {
589589
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
590590
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
591591
}
592+
Op::Unary(arg, UnaryOp::Silu) => {
593+
let sum_grad = grads.or_insert(arg)?;
594+
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
595+
let sigmoid_arg = (*node / arg)?;
596+
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
597+
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
598+
}
592599
Op::Elu(arg, alpha) => {
593600
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
594601
let sum_grad = grads.or_insert(arg)?;

candle-core/src/metal_backend.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,7 @@ impl BackendStorage for MetalStorage {
679679
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
680680
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
681681
("uerf", DType::F32) => contiguous::erf::FLOAT,
682+
("usilu", DType::F32) => contiguous::silu::FLOAT,
682683
("uabs", DType::F32) => contiguous::abs::FLOAT,
683684
("uceil", DType::F32) => contiguous::ceil::FLOAT,
684685
("ufloor", DType::F32) => contiguous::floor::FLOAT,
@@ -696,6 +697,7 @@ impl BackendStorage for MetalStorage {
696697
("ugelu", DType::F16) => contiguous::gelu::HALF,
697698
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
698699
("uerf", DType::F16) => contiguous::erf::HALF,
700+
("usilu", DType::F16) => contiguous::silu::HALF,
699701
("uabs", DType::F16) => contiguous::abs::HALF,
700702
("uceil", DType::F16) => contiguous::ceil::HALF,
701703
("ufloor", DType::F16) => contiguous::floor::HALF,
@@ -730,6 +732,7 @@ impl BackendStorage for MetalStorage {
730732
("ugelu", DType::F32) => strided::gelu::FLOAT,
731733
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
732734
("uerf", DType::F32) => strided::erf::FLOAT,
735+
("usilu", DType::F32) => strided::silu::FLOAT,
733736
("uabs", DType::F32) => strided::abs::FLOAT,
734737
("uceil", DType::F32) => strided::ceil::FLOAT,
735738
("ufloor", DType::F32) => strided::floor::FLOAT,
@@ -745,6 +748,7 @@ impl BackendStorage for MetalStorage {
745748
("ugelu", DType::F16) => strided::gelu::HALF,
746749
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
747750
("uerf", DType::F16) => strided::erf::HALF,
751+
("usilu", DType::F16) => strided::silu::HALF,
748752
("uabs", DType::F16) => strided::abs::HALF,
749753
("uceil", DType::F16) => strided::ceil::HALF,
750754
("ufloor", DType::F16) => strided::floor::HALF,

candle-core/src/mkl.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
333333
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
334334
}
335335

336+
#[inline]
337+
pub fn vs_exp_inplace(y: &mut [f32]) {
338+
unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
339+
}
340+
341+
#[inline]
342+
pub fn vd_exp_inplace(y: &mut [f64]) {
343+
unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
344+
}
345+
336346
#[inline]
337347
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
338348
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@@ -355,6 +365,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
355365
}
356366
}
357367

368+
#[inline]
369+
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
370+
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
371+
*y = -v
372+
}
373+
vs_exp_inplace(ys);
374+
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
375+
*y = v / (1.0 + *y)
376+
}
377+
}
378+
379+
#[inline]
380+
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
381+
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
382+
*y = -v
383+
}
384+
vd_exp_inplace(ys);
385+
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
386+
*y = v / (1.0 + *y)
387+
}
388+
}
389+
358390
macro_rules! binary_op {
359391
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
360392
#[inline]

candle-core/src/op.rs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ pub enum UnaryOp {
6161
GeluErf,
6262
Erf,
6363
Relu,
64+
Silu,
6465
Tanh,
6566
Floor,
6667
Ceil,
@@ -390,6 +391,7 @@ pub(crate) struct Gelu;
390391
pub(crate) struct GeluErf;
391392
pub(crate) struct Erf;
392393
pub(crate) struct Relu;
394+
pub(crate) struct Silu;
393395
pub(crate) struct Tanh;
394396
pub(crate) struct Floor;
395397
pub(crate) struct Ceil;
@@ -724,6 +726,77 @@ impl UnaryOpT for Erf {
724726
}
725727
}
726728

729+
/// Silu operation
730+
impl UnaryOpT for Silu {
731+
const NAME: &'static str = "silu";
732+
const V: Self = Silu;
733+
#[inline(always)]
734+
fn bf16(v: bf16) -> bf16 {
735+
v / (bf16::ONE + (-v).exp())
736+
}
737+
#[inline(always)]
738+
fn f16(v: f16) -> f16 {
739+
v / (f16::ONE + (-v).exp())
740+
}
741+
#[inline(always)]
742+
fn f32(v: f32) -> f32 {
743+
v / (1.0 + (-v).exp())
744+
}
745+
#[inline(always)]
746+
fn f64(v: f64) -> f64 {
747+
v / (1.0 + (-v).exp())
748+
}
749+
#[inline(always)]
750+
fn u8(_: u8) -> u8 {
751+
0
752+
}
753+
#[inline(always)]
754+
fn u32(_: u32) -> u32 {
755+
0
756+
}
757+
#[inline(always)]
758+
fn i64(_: i64) -> i64 {
759+
0
760+
}
761+
const KERNEL: &'static str = "usilu";
762+
763+
#[cfg(feature = "mkl")]
764+
const F32_VEC: bool = true;
765+
766+
#[cfg(feature = "mkl")]
767+
#[inline(always)]
768+
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
769+
crate::mkl::vs_silu(xs, ys)
770+
}
771+
772+
#[cfg(feature = "mkl")]
773+
const F64_VEC: bool = true;
774+
775+
#[cfg(feature = "mkl")]
776+
#[inline(always)]
777+
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
778+
crate::mkl::vd_silu(xs, ys)
779+
}
780+
781+
#[cfg(feature = "accelerate")]
782+
const F32_VEC: bool = true;
783+
784+
#[cfg(feature = "accelerate")]
785+
#[inline(always)]
786+
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
787+
crate::accelerate::vs_silu(xs, ys)
788+
}
789+
790+
#[cfg(feature = "accelerate")]
791+
const F64_VEC: bool = true;
792+
793+
#[cfg(feature = "accelerate")]
794+
#[inline(always)]
795+
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
796+
crate::accelerate::vd_silu(xs, ys)
797+
}
798+
}
799+
727800
impl UnaryOpT for Abs {
728801
const NAME: &'static str = "abs";
729802
const KERNEL: &'static str = "uabs";

candle-core/src/tensor.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,7 @@ impl Tensor {
508508
unary_op!(gelu_erf, GeluErf);
509509
unary_op!(erf, Erf);
510510
unary_op!(relu, Relu);
511+
unary_op!(silu, Silu);
511512
unary_op!(ceil, Ceil);
512513
unary_op!(floor, Floor);
513514
unary_op!(round, Round);

candle-core/tests/grad_tests.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,19 @@ fn unary_grad(device: &Device) -> Result<()> {
270270
[0.7358, 2.0000, 0.2707, 1.0000]
271271
);
272272

273+
// testing compared to pytorch nn.Silu()
274+
let y = x.silu()?;
275+
let grads = y.backward()?;
276+
let grad_x = grads.get(&x).context("no grad for x")?;
277+
assert_eq!(
278+
test_utils::to_vec1_round(&y, 4)?,
279+
[2.8577, 0.7311, 3.9281, 0.0806]
280+
);
281+
assert_eq!(
282+
test_utils::to_vec1_round(grad_x, 4)?,
283+
[1.0881, 0.9277, 1.0527, 0.5747],
284+
);
285+
273286
// manually checked: see comments
274287
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
275288
let y = x.interpolate2d(6, 6)?.reshape(36)?;

candle-core/tests/tensor_tests.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ fn unary_op(device: &Device) -> Result<()> {
120120
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
121121
]
122122
);
123+
assert_eq!(
124+
test_utils::to_vec2_round(&tensor.silu()?, 4)?,
125+
[
126+
[-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
127+
[2.53, -0.2553, -0.1205, 1.5447, 2.6395]
128+
]
129+
);
123130
assert_eq!(
124131
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
125132
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]

candle-kernels/src/unary.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ __device__ __forceinline__ T relu_fwd(T x) {
5555
return maxg(x, zero);
5656
}
5757

58+
template<typename T>
59+
__device__ __forceinline__ T silu_fwd(T x) {
60+
return x / (static_cast<scalar_t>(1) + expg(-x));
61+
}
62+
5863
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
5964
extern "C" __global__ void FN_NAME( \
6065
const size_t numel, \
@@ -103,6 +108,7 @@ UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
103108
UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
104109
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
105110
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
111+
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
106112
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
107113
#endif
108114

@@ -127,6 +133,7 @@ UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
127133
UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
128134
UNARY_OP(__half, urelu_f16, relu_fwd(x))
129135
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
136+
UNARY_OP(__half, usilu_f16, silu_fwd(x))
130137
UNARY_OP1(__half, upowf_f16, powg(x, param))
131138
#endif
132139

@@ -173,5 +180,7 @@ UNARY_OP(float, urelu_f32, relu_fwd(x))
173180
UNARY_OP(double, urelu_f64, relu_fwd(x))
174181
UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
175182
UNARY_OP1(double, uelu_f64, elu_fwd(x, param))
183+
UNARY_OP(float, usilu_f32, silu_fwd(x))
184+
UNARY_OP(double, usilu_f64, silu_fwd(x))
176185
UNARY_OP1(float, upowf_f32, powg(x, param))
177186
UNARY_OP1(double, upowf_f64, powg(x, param))

candle-metal-kernels/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ macro_rules! ops{
183183
pub mod unary {
184184
ops!(
185185
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
186-
tanh, recip
186+
tanh, recip, silu
187187
);
188188
}
189189
pub mod binary {

0 commit comments

Comments
 (0)