Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrapper interface #314

Merged
merged 6 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/AnomalyDetectors/LocalOutlierFactor.php
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class LocalOutlierFactor implements Estimator, Learner, Scoring, Persistable
*
* @var Spatial
*/
protected \Rubix\ML\Graph\Trees\Spatial $tree;
protected Spatial $tree;

/**
* The precomputed k distances between each training sample and its k'th nearest neighbor.
Expand Down
2 changes: 1 addition & 1 deletion src/AnomalyDetectors/Loda.php
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class Loda implements Estimator, Learner, Online, Scoring, Persistable
*
* @var \Tensor\Matrix|null
*/
protected ?\Tensor\Matrix $r = null;
protected ?Matrix $r = null;

/**
* The edges and bin counts of each histogram.
Expand Down
4 changes: 2 additions & 2 deletions src/AnomalyDetectors/OneClassSVM.php
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class OneClassSVM implements Estimator, Learner
*
* @var svm
*/
protected \svm $svm;
protected svm $svm;

/**
* The hyper-parameters of the model.
Expand All @@ -58,7 +58,7 @@ class OneClassSVM implements Estimator, Learner
*
* @var \svmmodel|null
*/
protected ?\svmmodel $model = null;
protected ?svmmodel $model = null;

/**
* @param float $nu
Expand Down
2 changes: 1 addition & 1 deletion src/BootstrapAggregator.php
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class BootstrapAggregator implements Estimator, Learner, Parallel, Persistable
*
* @var Learner
*/
protected \Rubix\ML\Learner $base;
protected Learner $base;

/**
* The number of base learners to train in the ensemble.
Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/AdaBoost.php
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class AdaBoost implements Estimator, Learner, Probabilistic, Verbose, Persistabl
*
* @var Learner
*/
protected \Rubix\ML\Learner $base;
protected Learner $base;

/**
* The learning rate of the ensemble i.e. the *shrinkage* applied to each step.
Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/KDNeighbors.php
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class KDNeighbors implements Estimator, Learner, Probabilistic, Persistable
*
* @var Spatial
*/
protected \Rubix\ML\Graph\Trees\Spatial $tree;
protected Spatial $tree;

/**
* The zero vector for the possible class outcomes.
Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/KNearestNeighbors.php
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class KNearestNeighbors implements Estimator, Learner, Online, Probabilistic, Pe
*
* @var Distance
*/
protected \Rubix\ML\Kernels\Distance\Distance $kernel;
protected Distance $kernel;

/**
* The zero vector for the possible class outcomes.
Expand Down
6 changes: 3 additions & 3 deletions src/Classifiers/LogisticRegression.php
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class LogisticRegression implements Estimator, Learner, Online, Probabilistic, R
*
* @var Optimizer
*/
protected \Rubix\ML\NeuralNet\Optimizers\Optimizer $optimizer;
protected Optimizer $optimizer;

/**
* The amount of L2 regularization applied to the weights of the output layer.
Expand Down Expand Up @@ -103,14 +103,14 @@ class LogisticRegression implements Estimator, Learner, Online, Probabilistic, R
*
* @var ClassificationLoss
*/
protected \Rubix\ML\NeuralNet\CostFunctions\ClassificationLoss $costFn;
protected ClassificationLoss $costFn;

/**
* The underlying neural network instance.
*
* @var \Rubix\ML\NeuralNet\FeedForward|null
*/
protected ?\Rubix\ML\NeuralNet\FeedForward $network = null;
protected ?FeedForward $network = null;

/**
* The unique class labels.
Expand Down
4 changes: 2 additions & 2 deletions src/Classifiers/LogitBoost.php
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class LogitBoost implements Estimator, Learner, Probabilistic, RanksFeatures, Ve
*
* @var Learner
*/
protected \Rubix\ML\Learner $booster;
protected Learner $booster;

/**
* The learning rate of the ensemble i.e. the *shrinkage* applied to each step.
Expand Down Expand Up @@ -138,7 +138,7 @@ class LogitBoost implements Estimator, Learner, Probabilistic, RanksFeatures, Ve
*
* @var Metric
*/
protected \Rubix\ML\CrossValidation\Metrics\Metric $metric;
protected Metric $metric;

/**
* The ensemble of boosters.
Expand Down
8 changes: 4 additions & 4 deletions src/Classifiers/MultilayerPerceptron.php
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class MultilayerPerceptron implements Estimator, Learner, Online, Probabilistic,
*
* @var Optimizer
*/
protected \Rubix\ML\NeuralNet\Optimizers\Optimizer $optimizer;
protected Optimizer $optimizer;

/**
* The amount of L2 regularization applied to the weights of the output layer.
Expand Down Expand Up @@ -127,21 +127,21 @@ class MultilayerPerceptron implements Estimator, Learner, Online, Probabilistic,
*
* @var ClassificationLoss
*/
protected \Rubix\ML\NeuralNet\CostFunctions\ClassificationLoss $costFn;
protected ClassificationLoss $costFn;

/**
* The validation metric used to score the generalization performance of the model during training.
*
* @var Metric
*/
protected \Rubix\ML\CrossValidation\Metrics\Metric $metric;
protected Metric $metric;

/**
* The underlying neural network instance.
*
* @var \Rubix\ML\NeuralNet\FeedForward|null
*/
protected ?\Rubix\ML\NeuralNet\FeedForward $network = null;
protected ?FeedForward $network = null;

/**
* The unique class labels.
Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/OneVsRest.php
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class OneVsRest implements Estimator, Learner, Probabilistic, Parallel, Persista
*
* @var Learner
*/
protected \Rubix\ML\Learner $base;
protected Learner $base;

/**
* A map of each class to its binary classifier.
Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/RadiusNeighbors.php
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class RadiusNeighbors implements Estimator, Learner, Probabilistic, Persistable
*
* @var Spatial
*/
protected \Rubix\ML\Graph\Trees\Spatial $tree;
protected Spatial $tree;

/**
* The class label for any samples that have 0 neighbors within the specified radius.
Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/RandomForest.php
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class RandomForest implements Estimator, Learner, Probabilistic, Parallel, Ranks
*
* @var Learner
*/
protected \Rubix\ML\Learner $base;
protected Learner $base;

/**
* The number of learners to train in the ensemble.
Expand Down
6 changes: 3 additions & 3 deletions src/Classifiers/SoftmaxClassifier.php
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class SoftmaxClassifier implements Estimator, Learner, Online, Probabilistic, Ve
*
* @var Optimizer
*/
protected \Rubix\ML\NeuralNet\Optimizers\Optimizer $optimizer;
protected Optimizer $optimizer;

/**
* The amount of L2 regularization applied to the weights of the output layer.
Expand Down Expand Up @@ -99,14 +99,14 @@ class SoftmaxClassifier implements Estimator, Learner, Online, Probabilistic, Ve
*
* @var ClassificationLoss
*/
protected \Rubix\ML\NeuralNet\CostFunctions\ClassificationLoss $costFn;
protected ClassificationLoss $costFn;

/**
* The underlying neural network instance.
*
* @var \Rubix\ML\NeuralNet\FeedForward|null
*/
protected ?\Rubix\ML\NeuralNet\FeedForward $network = null;
protected ?FeedForward $network = null;

/**
* The unique class labels.
Expand Down
2 changes: 1 addition & 1 deletion src/Clusterers/DBSCAN.php
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class DBSCAN implements Estimator
*
* @var Spatial
*/
protected \Rubix\ML\Graph\Trees\Spatial $tree;
protected Spatial $tree;

/**
* @param float $radius
Expand Down
4 changes: 2 additions & 2 deletions src/Clusterers/FuzzyCMeans.php
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ class FuzzyCMeans implements Estimator, Learner, Probabilistic, Verbose, Persist
*
* @var Distance
*/
protected \Rubix\ML\Kernels\Distance\Distance $kernel;
protected Distance $kernel;

/**
* The cluster centroid seeder.
*
* @var Seeder
*/
protected \Rubix\ML\Clusterers\Seeders\Seeder $seeder;
protected Seeder $seeder;

/**
* The computed centroid vectors of the training data.
Expand Down
2 changes: 1 addition & 1 deletion src/Clusterers/GaussianMixture.php
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class GaussianMixture implements Estimator, Learner, Probabilistic, Verbose, Per
*
* @var Seeder
*/
protected \Rubix\ML\Clusterers\Seeders\Seeder $seeder;
protected Seeder $seeder;

/**
* The precomputed log prior probabilities of each cluster.
Expand Down
4 changes: 2 additions & 2 deletions src/Clusterers/KMeans.php
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,14 @@ class KMeans implements Estimator, Learner, Online, Probabilistic, Verbose, Pers
*
* @var Distance
*/
protected \Rubix\ML\Kernels\Distance\Distance $kernel;
protected Distance $kernel;

/**
* The cluster centroid seeder.
*
* @var Seeder
*/
protected \Rubix\ML\Clusterers\Seeders\Seeder $seeder;
protected Seeder $seeder;

/**
* The computed centroid vectors of the training data.
Expand Down
4 changes: 2 additions & 2 deletions src/Clusterers/MeanShift.php
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ class MeanShift implements Estimator, Learner, Probabilistic, Verbose, Persistab
*
* @var Spatial
*/
protected \Rubix\ML\Graph\Trees\Spatial $tree;
protected Spatial $tree;

/**
* The cluster centroid seeder.
*
* @var Seeder
*/
protected \Rubix\ML\Clusterers\Seeders\Seeder $seeder;
protected Seeder $seeder;

/**
* The computed centroid vectors of the training data.
Expand Down
2 changes: 1 addition & 1 deletion src/Clusterers/Seeders/KMC2.php
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class KMC2 implements Seeder
*
* @var Distance
*/
protected \Rubix\ML\Kernels\Distance\Distance $kernel;
protected Distance $kernel;

/**
* @param int $m
Expand Down
2 changes: 1 addition & 1 deletion src/Clusterers/Seeders/PlusPlus.php
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class PlusPlus implements Seeder
*
* @var Distance
*/
protected \Rubix\ML\Kernels\Distance\Distance $kernel;
protected Distance $kernel;

/**
* @param \Rubix\ML\Kernels\Distance\Distance|null $kernel
Expand Down
2 changes: 1 addition & 1 deletion src/Datasets/Generators/Blob.php
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Blob implements Generator
*
* @var Vector
*/
protected \Tensor\Vector $center;
protected Vector $center;

/**
* The standard deviation of the blob.
Expand Down
2 changes: 1 addition & 1 deletion src/Datasets/Generators/Circle.php
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Circle implements Generator
*
* @var Vector
*/
protected \Tensor\Vector $center;
protected Vector $center;

/**
* The scaling factor of the circle.
Expand Down
2 changes: 1 addition & 1 deletion src/Datasets/Generators/HalfMoon.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class HalfMoon implements Generator
*
* @var Vector
*/
protected \Tensor\Vector $center;
protected Vector $center;

/**
* The scaling factor of the half moon.
Expand Down
2 changes: 1 addition & 1 deletion src/Datasets/Generators/Hyperplane.php
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Hyperplane implements Generator
*
* @var Vector
*/
protected \Tensor\Vector $coefficients;
protected Vector $coefficients;

/**
* The y intercept term.
Expand Down
2 changes: 1 addition & 1 deletion src/Datasets/Generators/SwissRoll.php
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class SwissRoll implements Generator
*
* @var Vector
*/
protected \Tensor\Vector $center;
protected Vector $center;

/**
* The scaling factor of the swiss roll.
Expand Down
20 changes: 20 additions & 0 deletions src/EstimatorWrapper.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
<?php

namespace Rubix\ML;

/**
* Wrapper
*
* @category Machine Learning
* @package Rubix/ML
* @author Ronan Giron
*/
interface EstimatorWrapper extends Estimator
{
/**
* Return the base estimator instance.
*
* @return Estimator
*/
public function base() : Estimator;
}
2 changes: 1 addition & 1 deletion src/Extractors/SQLTable.php
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class SQLTable implements Extractor
*
* @var PDO
*/
protected \PDO $connection;
protected PDO $connection;

/**
* The name of the table to select from.
Expand Down
2 changes: 1 addition & 1 deletion src/Graph/Nodes/Clique.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Clique implements Hypersphere, BinaryNode
*
* @var Labeled
*/
protected \Rubix\ML\Datasets\Labeled $dataset;
protected Labeled $dataset;

/**
* The centroid or multivariate mean of the cluster.
Expand Down
2 changes: 1 addition & 1 deletion src/Graph/Nodes/Neighborhood.php
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Neighborhood implements Hypercube, BinaryNode
*
* @var Labeled
*/
protected \Rubix\ML\Datasets\Labeled $dataset;
protected Labeled $dataset;

/**
* The multivariate minimum of the bounding box.
Expand Down
4 changes: 2 additions & 2 deletions src/Graph/Nodes/Traits/HasBinaryChildrenTrait.php
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ trait HasBinaryChildrenTrait
*
* @var \Rubix\ML\Graph\Nodes\BinaryNode|null
*/
protected ?\Rubix\ML\Graph\Nodes\BinaryNode $left = null;
protected ?BinaryNode $left = null;

/**
* The right child node.
*
* @var \Rubix\ML\Graph\Nodes\BinaryNode|null
*/
protected ?\Rubix\ML\Graph\Nodes\BinaryNode $right = null;
protected ?BinaryNode $right = null;

/**
* Return the children of this node in a generator.
Expand Down
Loading
Loading