diff --git a/README.md b/README.md index d534053..9bc1a6b 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ A General purpose k-nearest neighbor classifier algorithm based on the k-d tree ## API -### new KNN(dataset, labels[, options]) +### new KNNClassifier(dataset, labels[, options]) Instantiates the KNN algorithm. @@ -40,7 +40,7 @@ var train_dataset = [ [2, 1, 2], ]; var train_labels = [0, 0, 0, 1, 1, 1]; -var knn = new KNN(train_dataset, train_labels, { k: 2 }); // consider 2 nearest neighbors +var knn = new KNNClassifier(train_dataset, train_labels, { k: 2 }); // consider 2 nearest neighbors ``` ### predict(newDataset) @@ -75,14 +75,14 @@ console.log(ans); Returns an object representing the model. This function is automatically called if `JSON.stringify(knn)` is used. Be aware that the serialized model takes about 1.3 times the size of the input dataset (it actually is the dataset in a tree structure). Stringification can fail if the resulting string is too large. -### KNN.load(model[, distance]) +### KNNClassifier.load(model[, distance]) Loads a model previously exported by `knn.toJSON()`. If a custom distance function was provided, it must be passed again. ## External links -Check this cool blog post for a detailed example: -https://hackernoon.com/machine-learning-with-javascript-part-2-da994c17d483 +Check [this cool blog post](https://hackernoon.com/machine-learning-with-javascript-part-2-da994c17d483) for a detailed example. + ## License diff --git a/package.json b/package.json index 1e195dc..fade391 100644 --- a/package.json +++ b/package.json @@ -36,15 +36,19 @@ }, "homepage": "https://github.com/mljs/knn#readme", "devDependencies": { - "@babel/plugin-transform-modules-commonjs": "^7.16.8", - "eslint": "^8.10.0", - "eslint-config-cheminfo": "^7.2.2", - "jest": "^27.5.1", + "@babel/plugin-transform-modules-commonjs": "^7.20.11", + "@types/jest": "^29.4.0", + "eslint": "^8.34.0", + "eslint-config-cheminfo": "^8.1.3", + "jest": "^29.4.2", "ml-dataset-iris": "^1.2.1", - "prettier": "^2.5.1", - "rollup": "^2.69.1" + "prettier": "^2.8.4", + "rollup": "^3.15.0" }, "dependencies": { "ml-distance-euclidean": "^2.0.0" + }, + "jest": { + "testEnvironment": "node" } } diff --git a/src/__tests__/base.test.js b/src/__tests__/base.test.js new file mode 100644 index 0000000..e69de29 diff --git a/src/__tests__/test.js b/src/__tests__/classifier.test.js similarity index 81% rename from src/__tests__/test.js rename to src/__tests__/classifier.test.js index 42d0d64..44061d5 100644 --- a/src/__tests__/test.js +++ b/src/__tests__/classifier.test.js @@ -1,6 +1,6 @@ import { getNumbers, getClasses } from 'ml-dataset-iris'; -import KNN from '..'; +import { KNNClassifier } from '../classifier'; describe('knn', () => { const cases = [ @@ -13,7 +13,7 @@ describe('knn', () => { ]; const labels = [0, 0, 0, 1, 1, 1]; - const knn = new KNN(cases, labels, { + const knn = new KNNClassifier(cases, labels, { k: 3, }); @@ -40,7 +40,7 @@ describe('knn', () => { it('load', () => { const model = JSON.parse(JSON.stringify(knn)); - const newKnn = KNN.load(model); + const newKnn = KNNClassifier.load(model); const result = newKnn.predict([ [1.81, 1.81, 1.81], [0.5, 0.5, 0.5], @@ -53,11 +53,15 @@ describe('knn', () => { }); it('load errors', () => { - expect(() => KNN.load({})).toThrow('invalid model: undefined'); - expect(() => KNN.load({ name: 'KNN', isEuclidean: true }, () => 1)).toThrow( + expect(() => KNNClassifier.load({})).toThrow('invalid model: undefined'); + expect(() => + KNNClassifier.load({ name: 'KNNClassifier', isEuclidean: true }, () => 1), + ).toThrow( 'the model was created with the default distance function. Do not load it with another one', ); - expect(() => KNN.load({ name: 'KNN', isEuclidean: false })).toThrow( + expect(() => + KNNClassifier.load({ name: 'KNNClassifier', isEuclidean: false }), + ).toThrow( 'a custom distance function was used to create the model. Please provide it again', ); }); @@ -66,7 +70,7 @@ describe('knn', () => { let data = getNumbers(); let labels = getClasses(); - let knn = new KNN(data, labels, { k: 5 }); + let knn = new KNNClassifier(data, labels, { k: 5 }); let test = [ [5.1, 3.5, 1.4, 0.2], [4.9, 3.0, 1.4, 0.2], @@ -85,7 +89,7 @@ describe('knn', () => { [6.5, 3.0, 5.2, 2.0], ]; - knn = KNN.load(JSON.parse(JSON.stringify(knn))); + knn = KNNClassifier.load(JSON.parse(JSON.stringify(knn))); let expected = [ 'setosa', 'setosa', @@ -117,7 +121,7 @@ describe('knn', () => { [2, 1, 2], ]; const predictions = [0, 0, 0, 1, 1, 1]; - const knn = new KNN(dataset, predictions); + const knn = new KNNClassifier(dataset, predictions); expect(knn.k).toBe(3); diff --git a/src/base.js b/src/base.js new file mode 100644 index 0000000..6a955b0 --- /dev/null +++ b/src/base.js @@ -0,0 +1,74 @@ +import { euclidean as euclideanDistance } from 'ml-distance-euclidean'; + +import KDTree from './KDTree'; +import { loadDistanceCheck } from './utils'; + +export class KNN { + /** + * @param {Array} dataset + * @param {Array} labels + * @param {object} options + * @param {number} [options.k=numberOfClasses + 1] - Number of neighbors to classify. + * @param {function} [options.distance=euclideanDistance] - Distance function that takes two parameters. + */ + constructor(dataset, labels, options = {}) { + if (dataset === true) { + const model = labels; + this.kdTree = new KDTree(model.kdTree, options); + this.k = model.k; + this.classes = new Set(model.classes); + this.isEuclidean = model.isEuclidean; + return; + } + + const classes = new Set(labels); + + const { distance = euclideanDistance, k = classes.size + 1 } = options; + + const points = new Array(dataset.length); + for (let i = 0; i < points.length; ++i) { + points[i] = dataset[i].slice(); + } + + for (let i = 0; i < labels.length; ++i) { + points[i].push(labels[i]); + } + + this.kdTree = new KDTree(points, distance); + this.k = k; + this.classes = classes; + this.isEuclidean = distance === euclideanDistance; + } + + /** + * Create a new KNN instance with the given model. + * @param {object} model + * @param {function} distance=euclideanDistance - distance function must be provided if the model wasn't trained with euclidean distance. + * @return {KNN} + */ + static load(model, distance = euclideanDistance) { + if (model.name !== 'KNN') { + throw new TypeError(`invalid model: ${model.name}`); + } + loadDistanceCheck(model, distance); + return new KNN(true, model, distance); + } + + /** + * Return a JSON containing the kd-tree model. + * @return {object} JSON KNN model. + */ + toJSON() { + return { + name: 'KNN', + kdTree: this.kdTree, + k: this.k, + classes: Array.from(this.classes), + isEuclidean: this.isEuclidean, + }; + } + + predict(dataset) { + throw new Error('predict method must be implemented'); + } +} diff --git a/src/classifier.js b/src/classifier.js new file mode 100644 index 0000000..58fefac --- /dev/null +++ b/src/classifier.js @@ -0,0 +1,81 @@ +import { euclidean as euclideanDistance } from 'ml-distance-euclidean'; + +import { KNN } from './base'; +import { loadDistanceCheck } from './utils'; + +export class KNNClassifier extends KNN { + /** + * Predicts the output given the matrix to predict. + * @param {Array} dataset + * @return {Array} predictions + */ + + /** + * Return a JSON containing the kd-tree model. + * @return {object} JSON KNN model. + */ + toJSON() { + return { + name: 'KNNClassifier', + kdTree: this.kdTree, + k: this.k, + classes: Array.from(this.classes), + isEuclidean: this.isEuclidean, + }; + } + + /** + * Create a new KNNClassifier instance with the given model. + * @param {object} model + * @param {function} distance=euclideanDistance - distance function must be provided if the model wasn't trained with euclidean distance. + * @return {KNNClassifier} + */ + static load(model, distance = euclideanDistance) { + if (model.name === undefined || model.name !== 'KNNClassifier') { + throw new TypeError(`invalid model: ${model.name}`); + } + loadDistanceCheck(model, distance); + return new KNNClassifier(true, model, distance); + } + + predict(dataset) { + if (Array.isArray(dataset)) { + if (typeof dataset[0] === 'number') { + return getSinglePrediction(this, dataset); + } else if ( + Array.isArray(dataset[0]) && + typeof dataset[0][0] === 'number' + ) { + const predictions = new Array(dataset.length); + for (let i = 0; i < dataset.length; i++) { + predictions[i] = getSinglePrediction(this, dataset[i]); + } + return predictions; + } + } + throw new TypeError('dataset to predict must be an array or a matrix'); + } +} + +function getSinglePrediction(knn, currentCase) { + let nearestPoints = knn.kdTree.nearest(currentCase, knn.k); + let pointsPerClass = {}; + let predictedClass = -1; + let maxPoints = -1; + let lastElement = nearestPoints[0][0].length - 1; + + for (let element of knn.classes) { + pointsPerClass[element] = 0; + } + + for (let i = 0; i < nearestPoints.length; ++i) { + let currentClass = nearestPoints[i][0][lastElement]; + let currentPoints = ++pointsPerClass[currentClass]; + if (currentPoints > maxPoints) { + predictedClass = currentClass; + maxPoints = currentPoints; + } + } + + return predictedClass; +} diff --git a/src/index.js b/src/index.js index 84cee35..e69de29 100644 --- a/src/index.js +++ b/src/index.js @@ -1,124 +0,0 @@ -import { euclidean as euclideanDistance } from 'ml-distance-euclidean'; - -import KDTree from './KDTree'; - -export default class KNN { - /** - * @param {Array} dataset - * @param {Array} labels - * @param {object} options - * @param {number} [options.k=numberOfClasses + 1] - Number of neighbors to classify. - * @param {function} [options.distance=euclideanDistance] - Distance function that takes two parameters. - */ - constructor(dataset, labels, options = {}) { - if (dataset === true) { - const model = labels; - this.kdTree = new KDTree(model.kdTree, options); - this.k = model.k; - this.classes = new Set(model.classes); - this.isEuclidean = model.isEuclidean; - return; - } - - const classes = new Set(labels); - - const { distance = euclideanDistance, k = classes.size + 1 } = options; - - const points = new Array(dataset.length); - for (let i = 0; i < points.length; ++i) { - points[i] = dataset[i].slice(); - } - - for (let i = 0; i < labels.length; ++i) { - points[i].push(labels[i]); - } - - this.kdTree = new KDTree(points, distance); - this.k = k; - this.classes = classes; - this.isEuclidean = distance === euclideanDistance; - } - - /** - * Create a new KNN instance with the given model. - * @param {object} model - * @param {function} distance=euclideanDistance - distance function must be provided if the model wasn't trained with euclidean distance. - * @return {KNN} - */ - static load(model, distance = euclideanDistance) { - if (model.name !== 'KNN') { - throw new Error(`invalid model: ${model.name}`); - } - if (!model.isEuclidean && distance === euclideanDistance) { - throw new Error( - 'a custom distance function was used to create the model. Please provide it again', - ); - } - if (model.isEuclidean && distance !== euclideanDistance) { - throw new Error( - 'the model was created with the default distance function. Do not load it with another one', - ); - } - return new KNN(true, model, distance); - } - - /** - * Return a JSON containing the kd-tree model. - * @return {object} JSON KNN model. - */ - toJSON() { - return { - name: 'KNN', - kdTree: this.kdTree, - k: this.k, - classes: Array.from(this.classes), - isEuclidean: this.isEuclidean, - }; - } - - /** - * Predicts the output given the matrix to predict. - * @param {Array} dataset - * @return {Array} predictions - */ - predict(dataset) { - if (Array.isArray(dataset)) { - if (typeof dataset[0] === 'number') { - return getSinglePrediction(this, dataset); - } else if ( - Array.isArray(dataset[0]) && - typeof dataset[0][0] === 'number' - ) { - const predictions = new Array(dataset.length); - for (let i = 0; i < dataset.length; i++) { - predictions[i] = getSinglePrediction(this, dataset[i]); - } - return predictions; - } - } - throw new TypeError('dataset to predict must be an array or a matrix'); - } -} - -function getSinglePrediction(knn, currentCase) { - let nearestPoints = knn.kdTree.nearest(currentCase, knn.k); - let pointsPerClass = {}; - let predictedClass = -1; - let maxPoints = -1; - let lastElement = nearestPoints[0][0].length - 1; - - for (let element of knn.classes) { - pointsPerClass[element] = 0; - } - - for (let i = 0; i < nearestPoints.length; ++i) { - let currentClass = nearestPoints[i][0][lastElement]; - let currentPoints = ++pointsPerClass[currentClass]; - if (currentPoints > maxPoints) { - predictedClass = currentClass; - maxPoints = currentPoints; - } - } - - return predictedClass; -} diff --git a/src/regressors.js b/src/regressors.js new file mode 100644 index 0000000..e69de29 diff --git a/src/utils.js b/src/utils.js new file mode 100644 index 0000000..e2eb747 --- /dev/null +++ b/src/utils.js @@ -0,0 +1,14 @@ +import { euclidean as euclideanDistance } from 'ml-distance-euclidean'; + +export function loadDistanceCheck(model, distance) { + if (!model.isEuclidean && distance === euclideanDistance) { + throw new Error( + 'a custom distance function was used to create the model. Please provide it again', + ); + } + if (model.isEuclidean && distance !== euclideanDistance) { + throw new Error( + 'the model was created with the default distance function. Do not load it with another one', + ); + } +}