5454//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
5555use std:: cmp:: Ordering ;
5656use std:: fmt:: Debug ;
57- use std:: marker:: PhantomData ;
5857
5958#[ cfg( feature = "serde" ) ]
6059use serde:: { Deserialize , Serialize } ;
@@ -79,9 +78,11 @@ pub enum LogisticRegressionSolverName {
7978/// Logistic Regression parameters
8079#[ cfg_attr( feature = "serde" , derive( Serialize , Deserialize ) ) ]
8180#[ derive( Debug , Clone ) ]
82- pub struct LogisticRegressionParameters {
81+ pub struct LogisticRegressionParameters < T : RealNumber > {
8382 /// Solver to use for estimation of regression coefficients.
8483 pub solver : LogisticRegressionSolverName ,
84+ /// Regularization parameter.
85+ pub alpha : T ,
8586}
8687
8788/// Logistic Regression
@@ -113,21 +114,27 @@ trait ObjectiveFunction<T: RealNumber, M: Matrix<T>> {
/// Optimization target for the two-class problem: negative log-likelihood
/// of the labels under a logistic model, plus an optional L2 penalty.
struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
    // Borrowed design matrix, one row per sample.
    x: &'a M,
    // Class label per row of `x`, as an index (presumably 0/1 in the
    // binary case — TODO confirm against the label remapping in `fit`).
    y: Vec<usize>,
    // L2 (ridge) penalty coefficient; 0 means unregularized.
    alpha: T,
}
118119
119- impl LogisticRegressionParameters {
120+ impl < T : RealNumber > LogisticRegressionParameters < T > {
120121 /// Solver to use for estimation of regression coefficients.
121122 pub fn with_solver ( mut self , solver : LogisticRegressionSolverName ) -> Self {
122123 self . solver = solver;
123124 self
124125 }
126+ /// Regularization parameter.
127+ pub fn with_alpha ( mut self , alpha : T ) -> Self {
128+ self . alpha = alpha;
129+ self
130+ }
125131}
126132
127- impl Default for LogisticRegressionParameters {
133+ impl < T : RealNumber > Default for LogisticRegressionParameters < T > {
128134 fn default ( ) -> Self {
129135 LogisticRegressionParameters {
130136 solver : LogisticRegressionSolverName :: LBFGS ,
137+ alpha : T :: zero ( ) ,
131138 }
132139 }
133140}
@@ -156,13 +163,22 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
156163{
157164 fn f ( & self , w_bias : & M ) -> T {
158165 let mut f = T :: zero ( ) ;
159- let ( n, _ ) = self . x . shape ( ) ;
166+ let ( n, p ) = self . x . shape ( ) ;
160167
161168 for i in 0 ..n {
162169 let wx = BinaryObjectiveFunction :: partial_dot ( w_bias, self . x , 0 , i) ;
163170 f += wx. ln_1pe ( ) - ( T :: from ( self . y [ i] ) . unwrap ( ) ) * wx;
164171 }
165172
173+ if self . alpha > T :: zero ( ) {
174+ let mut w_squared = T :: zero ( ) ;
175+ for i in 0 ..p {
176+ let w = w_bias. get ( 0 , i) ;
177+ w_squared += w * w;
178+ }
179+ f += T :: half ( ) * self . alpha * w_squared;
180+ }
181+
166182 f
167183 }
168184
@@ -180,14 +196,21 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
180196 }
181197 g. set ( 0 , p, g. get ( 0 , p) - dyi) ;
182198 }
199+
200+ if self . alpha > T :: zero ( ) {
201+ for i in 0 ..p {
202+ let w = w_bias. get ( 0 , i) ;
203+ g. set ( 0 , i, g. get ( 0 , i) + self . alpha * w) ;
204+ }
205+ }
183206 }
184207}
185208
/// Optimization target for the multi-class problem: softmax negative
/// log-likelihood plus an optional L2 penalty.
struct MultiClassObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
    // Borrowed design matrix, one row per sample.
    x: &'a M,
    // Class label per row of `x`, as an index in `0..k`.
    y: Vec<usize>,
    // Number of classes; the flat weight vector is laid out as `k` blocks
    // of `p + 1` values each (weights followed by a bias term), as the
    // `i * (p + 1) + j` indexing in `f`/`df` shows.
    k: usize,
    // L2 (ridge) penalty coefficient; 0 means unregularized.
    alpha: T,
}
192215
193216impl < ' a , T : RealNumber , M : Matrix < T > > ObjectiveFunction < T , M >
@@ -209,6 +232,17 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
209232 f -= prob. get ( 0 , self . y [ i] ) . ln ( ) ;
210233 }
211234
235+ if self . alpha > T :: zero ( ) {
236+ let mut w_squared = T :: zero ( ) ;
237+ for i in 0 ..self . k {
238+ for j in 0 ..p {
239+ let wi = w_bias. get ( 0 , i * ( p + 1 ) + j) ;
240+ w_squared += wi * wi;
241+ }
242+ }
243+ f += T :: half ( ) * self . alpha * w_squared;
244+ }
245+
212246 f
213247 }
214248
@@ -239,16 +273,27 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
239273 g. set ( 0 , j * ( p + 1 ) + p, g. get ( 0 , j * ( p + 1 ) + p) - yi) ;
240274 }
241275 }
276+
277+ if self . alpha > T :: zero ( ) {
278+ for i in 0 ..self . k {
279+ for j in 0 ..p {
280+ let pos = i * ( p + 1 ) ;
281+ let wi = w. get ( 0 , pos + j) ;
282+ g. set ( 0 , pos + j, g. get ( 0 , pos + j) + self . alpha * wi) ;
283+ }
284+ }
285+ }
242286 }
243287}
244288
245- impl < T : RealNumber , M : Matrix < T > > SupervisedEstimator < M , M :: RowVector , LogisticRegressionParameters >
289+ impl < T : RealNumber , M : Matrix < T > >
290+ SupervisedEstimator < M , M :: RowVector , LogisticRegressionParameters < T > >
246291 for LogisticRegression < T , M >
247292{
    // `SupervisedEstimator` entry point: delegates to the inherent
    // `LogisticRegression::fit`, forwarding the parameters (including the
    // regularization setting) unchanged.
    fn fit(
        x: &M,
        y: &M::RowVector,
        parameters: LogisticRegressionParameters<T>,
    ) -> Result<Self, Failed> {
        LogisticRegression::fit(x, y, parameters)
    }
@@ -268,7 +313,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
268313 pub fn fit (
269314 x : & M ,
270315 y : & M :: RowVector ,
271- _parameters : LogisticRegressionParameters ,
316+ parameters : LogisticRegressionParameters < T > ,
272317 ) -> Result < LogisticRegression < T , M > , Failed > {
273318 let y_m = M :: from_row_vector ( y. clone ( ) ) ;
274319 let ( x_nrows, num_attributes) = x. shape ( ) ;
@@ -302,7 +347,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
302347 let objective = BinaryObjectiveFunction {
303348 x,
304349 y : yi,
305- phantom : PhantomData ,
350+ alpha : parameters . alpha ,
306351 } ;
307352
308353 let result = LogisticRegression :: minimize ( x0, objective) ;
@@ -324,7 +369,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
324369 x,
325370 y : yi,
326371 k,
327- phantom : PhantomData ,
372+ alpha : parameters . alpha ,
328373 } ;
329374
330375 let result = LogisticRegression :: minimize ( x0, objective) ;
@@ -431,9 +476,9 @@ mod tests {
431476
432477 let objective = MultiClassObjectiveFunction {
433478 x : & x,
434- y,
479+ y : y . clone ( ) ,
435480 k : 3 ,
436- phantom : PhantomData ,
481+ alpha : 0.0 ,
437482 } ;
438483
439484 let mut g: DenseMatrix < f64 > = DenseMatrix :: zeros ( 1 , 9 ) ;
@@ -454,6 +499,24 @@ mod tests {
454499 ] ) ) ;
455500
456501 assert ! ( ( f - 408.0052230582765 ) . abs( ) < std:: f64 :: EPSILON ) ;
502+
503+ let objective_reg = MultiClassObjectiveFunction {
504+ x : & x,
505+ y : y. clone ( ) ,
506+ k : 3 ,
507+ alpha : 1.0 ,
508+ } ;
509+
510+ let f = objective_reg. f ( & DenseMatrix :: row_vector_from_array ( & [
511+ 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 9. ,
512+ ] ) ) ;
513+ assert ! ( ( f - 487.5052 ) . abs( ) < 1e-4 ) ;
514+
515+ objective_reg. df (
516+ & mut g,
517+ & DenseMatrix :: row_vector_from_array ( & [ 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 9. ] ) ,
518+ ) ;
519+ assert ! ( ( g. get( 0 , 0 ) . abs( ) - 32.0 ) . abs( ) < 1e-4 ) ;
457520 }
458521
459522 #[ test]
@@ -480,8 +543,8 @@ mod tests {
480543
481544 let objective = BinaryObjectiveFunction {
482545 x : & x,
483- y,
484- phantom : PhantomData ,
546+ y : y . clone ( ) ,
547+ alpha : 0.0 ,
485548 } ;
486549
487550 let mut g: DenseMatrix < f64 > = DenseMatrix :: zeros ( 1 , 3 ) ;
@@ -496,6 +559,20 @@ mod tests {
496559 let f = objective. f ( & DenseMatrix :: row_vector_from_array ( & [ 1. , 2. , 3. ] ) ) ;
497560
498561 assert ! ( ( f - 59.76994756647412 ) . abs( ) < std:: f64 :: EPSILON ) ;
562+
563+ let objective_reg = BinaryObjectiveFunction {
564+ x : & x,
565+ y : y. clone ( ) ,
566+ alpha : 1.0 ,
567+ } ;
568+
569+ let f = objective_reg. f ( & DenseMatrix :: row_vector_from_array ( & [ 1. , 2. , 3. ] ) ) ;
570+ assert ! ( ( f - 62.2699 ) . abs( ) < 1e-4 ) ;
571+
572+ objective_reg. df ( & mut g, & DenseMatrix :: row_vector_from_array ( & [ 1. , 2. , 3. ] ) ) ;
573+ assert ! ( ( g. get( 0 , 0 ) - 27.0511 ) . abs( ) < 1e-4 ) ;
574+ assert ! ( ( g. get( 0 , 1 ) - 12.239 ) . abs( ) < 1e-4 ) ;
575+ assert ! ( ( g. get( 0 , 2 ) - 3.8693 ) . abs( ) < 1e-4 ) ;
499576 }
500577
501578 #[ test]
@@ -547,6 +624,15 @@ mod tests {
547624 let y_hat = lr. predict ( & x) . unwrap ( ) ;
548625
549626 assert ! ( accuracy( & y_hat, & y) > 0.9 ) ;
627+
628+ let lr_reg = LogisticRegression :: fit (
629+ & x,
630+ & y,
631+ LogisticRegressionParameters :: default ( ) . with_alpha ( 10.0 ) ,
632+ )
633+ . unwrap ( ) ;
634+
635+ assert ! ( lr_reg. coefficients( ) . abs( ) . sum( ) < lr. coefficients( ) . abs( ) . sum( ) ) ;
550636 }
551637
552638 #[ test]
@@ -561,6 +647,15 @@ mod tests {
561647 let y_hat = lr. predict ( & x) . unwrap ( ) ;
562648
563649 assert ! ( accuracy( & y_hat, & y) > 0.9 ) ;
650+
651+ let lr_reg = LogisticRegression :: fit (
652+ & x,
653+ & y,
654+ LogisticRegressionParameters :: default ( ) . with_alpha ( 10.0 ) ,
655+ )
656+ . unwrap ( ) ;
657+
658+ assert ! ( lr_reg. coefficients( ) . abs( ) . sum( ) < lr. coefficients( ) . abs( ) . sum( ) ) ;
564659 }
565660
566661 #[ test]
@@ -622,6 +717,12 @@ mod tests {
622717 ] ;
623718
624719 let lr = LogisticRegression :: fit ( & x, & y, Default :: default ( ) ) . unwrap ( ) ;
720+ let lr_reg = LogisticRegression :: fit (
721+ & x,
722+ & y,
723+ LogisticRegressionParameters :: default ( ) . with_alpha ( 1.0 ) ,
724+ )
725+ . unwrap ( ) ;
625726
626727 let y_hat = lr. predict ( & x) . unwrap ( ) ;
627728
@@ -632,5 +733,6 @@ mod tests {
632733 . sum ( ) ;
633734
634735 assert ! ( error <= 1.0 ) ;
736+ assert ! ( lr_reg. coefficients( ) . abs( ) . sum( ) < lr. coefficients( ) . abs( ) . sum( ) ) ;
635737 }
636738}
0 commit comments