Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] feat: implementing KNN regression #17

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ A General purpose k-nearest neighbor classifier algorithm based on the k-d tree

## API

### new KNN(dataset, labels[, options])
### new KNNClassifier(dataset, labels[, options])

Instantiates the KNN algorithm.

Expand All @@ -40,7 +40,7 @@ var train_dataset = [
[2, 1, 2],
];
var train_labels = [0, 0, 0, 1, 1, 1];
var knn = new KNN(train_dataset, train_labels, { k: 2 }); // consider 2 nearest neighbors
var knn = new KNNClassifier(train_dataset, train_labels, { k: 2 }); // consider 2 nearest neighbors
```

### predict(newDataset)
Expand Down Expand Up @@ -75,14 +75,14 @@ console.log(ans);
Returns an object representing the model. This function is automatically called if `JSON.stringify(knn)` is used.
Be aware that the serialized model takes about 1.3 times the size of the input dataset (it actually is the dataset in a tree structure). Stringification can fail if the resulting string is too large.

### KNN.load(model[, distance])
### KNNClassifier.load(model[, distance])

Loads a model previously exported by `knn.toJSON()`. If a custom distance function was provided, it must be passed again.

## External links

Check this cool blog post for a detailed example:
https://hackernoon.com/machine-learning-with-javascript-part-2-da994c17d483
Check [this cool blog post](https://hackernoon.com/machine-learning-with-javascript-part-2-da994c17d483) for a detailed example.


## License

Expand Down
16 changes: 10 additions & 6 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,19 @@
},
"homepage": "https://github.com/mljs/knn#readme",
"devDependencies": {
"@babel/plugin-transform-modules-commonjs": "^7.16.8",
"eslint": "^8.10.0",
"eslint-config-cheminfo": "^7.2.2",
"jest": "^27.5.1",
"@babel/plugin-transform-modules-commonjs": "^7.20.11",
"@types/jest": "^29.4.0",
"eslint": "^8.34.0",
"eslint-config-cheminfo": "^8.1.3",
"jest": "^29.4.2",
"ml-dataset-iris": "^1.2.1",
"prettier": "^2.5.1",
"rollup": "^2.69.1"
"prettier": "^2.8.4",
"rollup": "^3.15.0"
},
"dependencies": {
"ml-distance-euclidean": "^2.0.0"
},
"jest": {
"testEnvironment": "node"
}
}
Empty file added src/__tests__/base.test.js
Empty file.
22 changes: 13 additions & 9 deletions src/__tests__/test.js → src/__tests__/classifier.test.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { getNumbers, getClasses } from 'ml-dataset-iris';

import KNN from '..';
import { KNNClassifier } from '../classifier';

describe('knn', () => {
const cases = [
Expand All @@ -13,7 +13,7 @@ describe('knn', () => {
];
const labels = [0, 0, 0, 1, 1, 1];

const knn = new KNN(cases, labels, {
const knn = new KNNClassifier(cases, labels, {
k: 3,
});

Expand All @@ -40,7 +40,7 @@ describe('knn', () => {

it('load', () => {
const model = JSON.parse(JSON.stringify(knn));
const newKnn = KNN.load(model);
const newKnn = KNNClassifier.load(model);
const result = newKnn.predict([
[1.81, 1.81, 1.81],
[0.5, 0.5, 0.5],
Expand All @@ -53,11 +53,15 @@ describe('knn', () => {
});

it('load errors', () => {
expect(() => KNN.load({})).toThrow('invalid model: undefined');
expect(() => KNN.load({ name: 'KNN', isEuclidean: true }, () => 1)).toThrow(
expect(() => KNNClassifier.load({})).toThrow('invalid model: undefined');
expect(() =>
KNNClassifier.load({ name: 'KNNClassifier', isEuclidean: true }, () => 1),
).toThrow(
'the model was created with the default distance function. Do not load it with another one',
);
expect(() => KNN.load({ name: 'KNN', isEuclidean: false })).toThrow(
expect(() =>
KNNClassifier.load({ name: 'KNNClassifier', isEuclidean: false }),
).toThrow(
'a custom distance function was used to create the model. Please provide it again',
);
});
Expand All @@ -66,7 +70,7 @@ describe('knn', () => {
let data = getNumbers();
let labels = getClasses();

let knn = new KNN(data, labels, { k: 5 });
let knn = new KNNClassifier(data, labels, { k: 5 });
let test = [
[5.1, 3.5, 1.4, 0.2],
[4.9, 3.0, 1.4, 0.2],
Expand All @@ -85,7 +89,7 @@ describe('knn', () => {
[6.5, 3.0, 5.2, 2.0],
];

knn = KNN.load(JSON.parse(JSON.stringify(knn)));
knn = KNNClassifier.load(JSON.parse(JSON.stringify(knn)));
let expected = [
'setosa',
'setosa',
Expand Down Expand Up @@ -117,7 +121,7 @@ describe('knn', () => {
[2, 1, 2],
];
const predictions = [0, 0, 0, 1, 1, 1];
const knn = new KNN(dataset, predictions);
const knn = new KNNClassifier(dataset, predictions);

expect(knn.k).toBe(3);

Expand Down
74 changes: 74 additions & 0 deletions src/base.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import { euclidean as euclideanDistance } from 'ml-distance-euclidean';

import KDTree from './KDTree';
import { loadDistanceCheck } from './utils';

export class KNN {
  /**
   * Base k-nearest-neighbors model backed by a k-d tree.
   *
   * Two call forms are supported:
   *  - training: `new KNN(dataset, labels, options)`;
   *  - restoring (used by `load`): `new KNN(true, model, distance)`, where the
   *    second argument is a serialized model and the third one is the
   *    distance function handed to the k-d tree.
   *
   * @param {Array<Array<number>>|boolean} dataset - Training rows, or `true` to restore a serialized model.
   * @param {Array|object} labels - One label per training row, or the serialized model when restoring.
   * @param {object|function} [options={}] - Training options, or the distance function when restoring.
   * @param {number} [options.k=numberOfClasses + 1] - Number of neighbors to classify.
   * @param {function} [options.distance=euclideanDistance] - Distance function that takes two parameters.
   * @throws {RangeError} If `dataset` and `labels` do not have the same length.
   */
  constructor(dataset, labels, options = {}) {
    if (dataset === true) {
      // Restore form: `labels` is the serialized model, `options` is the distance function.
      const model = labels;
      this.kdTree = new KDTree(model.kdTree, options);
      this.k = model.k;
      this.classes = new Set(model.classes);
      this.isEuclidean = model.isEuclidean;
      return;
    }

    // Every row gets its label appended as the last column. A length mismatch
    // would either crash on a missing row or leave rows without a label column.
    if (dataset.length !== labels.length) {
      throw new RangeError(
        `dataset and labels must have the same length (got ${dataset.length} and ${labels.length})`,
      );
    }

    const classes = new Set(labels);

    const { distance = euclideanDistance, k = classes.size + 1 } = options;

    // Copy each row so the caller's dataset is never mutated.
    const points = new Array(dataset.length);
    for (let i = 0; i < points.length; ++i) {
      const point = dataset[i].slice();
      point.push(labels[i]);
      points[i] = point;
    }

    this.kdTree = new KDTree(points, distance);
    this.k = k;
    this.classes = classes;
    this.isEuclidean = distance === euclideanDistance;
  }

  /**
   * Create a new KNN instance with the given model.
   * @param {object} model - Object previously produced by `toJSON()`.
   * @param {function} [distance=euclideanDistance] - distance function must be provided if the model wasn't trained with euclidean distance.
   * @return {KNN}
   * @throws {TypeError} If `model.name` is not `'KNN'`.
   */
  static load(model, distance = euclideanDistance) {
    if (model.name !== 'KNN') {
      throw new TypeError(`invalid model: ${model.name}`);
    }
    // Consistency between the model and the distance function is checked by
    // the helper imported from ./utils.
    loadDistanceCheck(model, distance);
    return new KNN(true, model, distance);
  }

  /**
   * Return a JSON containing the kd-tree model.
   * @return {object} JSON KNN model.
   */
  toJSON() {
    return {
      name: 'KNN',
      kdTree: this.kdTree,
      k: this.k,
      // A Set is not JSON-serializable, so export the classes as an array.
      classes: Array.from(this.classes),
      isEuclidean: this.isEuclidean,
    };
  }

  /**
   * Abstract prediction hook: the base class always throws, subclasses must
   * override it.
   * @param {Array} dataset - Data to predict (unused here).
   * @throws {Error} Always.
   */
  // eslint-disable-next-line no-unused-vars
  predict(dataset) {
    throw new Error('predict method must be implemented');
  }
}
81 changes: 81 additions & 0 deletions src/classifier.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import { euclidean as euclideanDistance } from 'ml-distance-euclidean';

import { KNN } from './base';
import { loadDistanceCheck } from './utils';

export class KNNClassifier extends KNN {
  /**
   * Export the classifier as a plain object (the k-d tree plus its settings).
   * @return {object} JSON KNNClassifier model.
   */
  toJSON() {
    const { kdTree, k, classes, isEuclidean } = this;
    return {
      name: 'KNNClassifier',
      kdTree,
      k,
      classes: [...classes],
      isEuclidean,
    };
  }

  /**
   * Rebuild a KNNClassifier from an object produced by `toJSON()`.
   * @param {object} model
   * @param {function} [distance=euclideanDistance] - must be supplied again when the model wasn't trained with euclidean distance.
   * @return {KNNClassifier}
   * @throws {TypeError} When `model.name` is not `'KNNClassifier'`.
   */
  static load(model, distance = euclideanDistance) {
    const modelName = model.name;
    if (modelName !== 'KNNClassifier') {
      throw new TypeError(`invalid model: ${modelName}`);
    }
    loadDistanceCheck(model, distance);
    return new KNNClassifier(true, model, distance);
  }

  /**
   * Predicts the class of one case or of every row of a matrix.
   * Only the first element of `dataset` is inspected to tell the two apart.
   * @param {Array<number>|Array<Array<number>>} dataset
   * @return {*|Array} a single prediction for one case, an array of predictions for a matrix
   * @throws {TypeError} When `dataset` is neither an array of numbers nor an array of number arrays.
   */
  predict(dataset) {
    const invalidInput = () =>
      new TypeError('dataset to predict must be an array or a matrix');

    if (!Array.isArray(dataset)) {
      throw invalidInput();
    }

    const head = dataset[0];
    if (typeof head === 'number') {
      // A flat array of numbers is a single case.
      return getSinglePrediction(this, dataset);
    }
    if (Array.isArray(head) && typeof head[0] === 'number') {
      // A matrix: one prediction per row, in order.
      return Array.from(dataset, (row) => getSinglePrediction(this, row));
    }

    throw invalidInput();
  }
}

/**
 * Predicts the class of one case by majority vote among its `knn.k` nearest
 * neighbors. On a tie, the class that reaches the highest count first (in the
 * order returned by the k-d tree) wins.
 *
 * @param {object} knn - Model exposing `kdTree.nearest(point, k)` and `k`.
 * @param {Array<number>} currentCase - Point to classify.
 * @return {*} The winning label, exactly as stored in the tree.
 * @throws {Error} If the k-d tree returns no neighbor.
 */
function getSinglePrediction(knn, currentCase) {
  const nearestPoints = knn.kdTree.nearest(currentCase, knn.k);
  if (nearestPoints.length === 0) {
    throw new Error('cannot predict: the k-d tree returned no neighbors');
  }

  // Each neighbor entry holds the stored point at index 0; the label is the
  // last column of that point.
  const labelIndex = nearestPoints[0][0].length - 1;

  // A Map keeps labels as-is: no string coercion of keys and no clash with
  // Object.prototype members such as `__proto__`.
  const votes = new Map();
  let predictedClass = -1;
  let maxVotes = 0;

  for (let i = 0; i < nearestPoints.length; ++i) {
    const label = nearestPoints[i][0][labelIndex];
    const count = (votes.get(label) || 0) + 1;
    votes.set(label, count);
    // Strict comparison: the first class to reach the top count keeps the lead.
    if (count > maxVotes) {
      predictedClass = label;
      maxVotes = count;
    }
  }

  return predictedClass;
}
Loading