Draft: src/backends/onnx.js → src/backends/onnx.ts (25 changes: 14 additions & 11 deletions)
@@ -16,21 +16,24 @@
* @module backends/onnx
*/

-import { env, apis } from '../env.js';
+import { env, apis } from '../env';

// NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`.
// In either case, we select the default export if it exists, otherwise we use the named export.
import * as ONNX_NODE from 'onnxruntime-node';
import * as ONNX_WEB from 'onnxruntime-web';
+import { DeviceType } from '../utils/devices';
+import { InferenceSession as ONNXInferenceSession } from 'onnxruntime-common';

export { Tensor } from 'onnxruntime-common';

/**
* @typedef {import('onnxruntime-common').InferenceSession.ExecutionProviderConfig} ONNXExecutionProviders
*/
+type ONNXExecutionProviders = ONNXInferenceSession.ExecutionProviderConfig;

/** @type {Record<import("../utils/devices.js").DeviceType, ONNXExecutionProviders>} */
-const DEVICE_TO_EXECUTION_PROVIDER_MAPPING = Object.freeze({
+const DEVICE_TO_EXECUTION_PROVIDER_MAPPING: Record<DeviceType, ONNXExecutionProviders> = Object.freeze({
auto: null, // Auto-detect based on device and environment
gpu: null, // Auto-detect GPU
cpu: 'cpu', // CPU
@@ -49,10 +52,10 @@ const DEVICE_TO_EXECUTION_PROVIDER_MAPPING = Object.freeze({
* The list of supported devices, sorted by priority/performance.
* @type {import("../utils/devices.js").DeviceType[]}
*/
-const supportedDevices = [];
+const supportedDevices: DeviceType[] = [];

/** @type {ONNXExecutionProviders[]} */
-let defaultDevices;
+let defaultDevices: ONNXExecutionProviders[];
let ONNX;
const ORT_SYMBOL = Symbol.for('onnxruntime');

@@ -61,7 +64,7 @@ if (ORT_SYMBOL in globalThis) {
ONNX = globalThis[ORT_SYMBOL];

} else if (apis.IS_NODE_ENV) {
-ONNX = ONNX_NODE.default ?? ONNX_NODE;
+ONNX = ONNX_NODE;

// Updated as of ONNX Runtime 1.20.1
// The following table lists the supported versions of ONNX Runtime Node.js binding provided with pre-built binaries.
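The `ORT_SYMBOL` branch above lets an embedder supply its own ONNX Runtime build before this module loads. A minimal consumer-side sketch, assuming the `onnxruntime-web/webgpu` entry point as one possible module to inject (the library entry path is a placeholder):

```ts
// Hypothetical consumer code. Symbol.for() returns the same registered
// symbol this backend checks, so the `ORT_SYMBOL in globalThis` branch
// picks this module up instead of the bundled defaults.
import * as ort from 'onnxruntime-web/webgpu'; // any compatible ORT build

(globalThis as any)[Symbol.for('onnxruntime')] = ort;

// Import the library only after the global is set; a static import would
// be hoisted above the assignment.
const lib = await import('./library-entry-point'); // placeholder path
```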
@@ -109,7 +112,7 @@ const InferenceSession = ONNX.InferenceSession;
* @param {import("../utils/devices.js").DeviceType|"auto"|null} [device=null] (Optional) The device to run the inference on.
* @returns {ONNXExecutionProviders[]} The execution providers to use for the given device.
*/
-export function deviceToExecutionProviders(device = null) {
+export function deviceToExecutionProviders(device: DeviceType | "auto" | null = null): ONNXExecutionProviders[] {
// Use the default execution providers if the user hasn't specified anything
if (!device) return defaultDevices;
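A quick check of the converted signature; the return values in the comments are a sketch inferred from the mapping above, since the actual lookup lives in the collapsed lines:

```ts
deviceToExecutionProviders();        // no device: returns defaultDevices
deviceToExecutionProviders('cpu');   // mapped entry: ['cpu']
deviceToExecutionProviders('auto');  // null mapping: auto-detection path
```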

@@ -137,7 +140,7 @@ export function deviceToExecutionProviders(device = null) {
* will wait for this Promise to resolve before creating their own InferenceSession.
* @type {Promise<any>|null}
*/
-let wasmInitPromise = null;
+let wasmInitPromise: Promise<any> | null = null;

/**
* Create an ONNX inference session.
@@ -146,7 +149,7 @@ let wasmInitPromise = null;
* @param {Object} session_config ONNX inference session configuration.
* @returns {Promise<import('onnxruntime-common').InferenceSession & { config: Object}>} The ONNX inference session.
*/
-export async function createInferenceSession(buffer, session_options, session_config) {
+export async function createInferenceSession(buffer: Uint8Array, session_options: ONNXInferenceSession.SessionOptions, session_config: Object): Promise<ONNXInferenceSession & { config: Object; }> {
if (wasmInitPromise) {
// A previous session has already initialized the WASM runtime
// so we wait for it to resolve before creating this new session.
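The `wasmInitPromise` gate serializes WASM startup: whichever call runs first initializes the runtime, and every later call awaits that same promise before constructing its own session. A caller-side sketch, assuming a placeholder model URL:

```ts
// Hypothetical caller: two concurrent sessions over the same model bytes.
// The second createInferenceSession() awaits the first one's WASM init.
const resp = await fetch('https://example.com/model.onnx'); // placeholder URL
const bytes = new Uint8Array(await resp.arrayBuffer());

const opts = { executionProviders: deviceToExecutionProviders('cpu') };
const [a, b] = await Promise.all([
    createInferenceSession(bytes, opts, {}),
    createInferenceSession(bytes, opts, {}),
]);
```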
@@ -165,13 +168,13 @@ export async function createInferenceSession(buffer, session_options, session_config) {
* @param {any} x The object to check
* @returns {boolean} Whether the object is an ONNX tensor.
*/
-export function isONNXTensor(x) {
+export function isONNXTensor(x: any): boolean {
return x instanceof ONNX.Tensor;
}

/** @type {import('onnxruntime-common').Env} */
// @ts-ignore
-const ONNX_ENV = ONNX?.env;
+const ONNX_ENV: Env = ONNX?.env;
if (ONNX_ENV?.wasm) {
// Initialize wasm backend with suitable default settings.

@@ -202,7 +205,7 @@ if (ONNX_ENV?.webgpu) {
* Check if ONNX's WASM backend is being proxied.
* @returns {boolean} Whether ONNX's WASM backend is being proxied.
*/
-export function isONNXProxy() {
+export function isONNXProxy(): boolean {
// TODO: Update this when allowing non-WASM backends.
return ONNX_ENV?.wasm?.proxy;
}
(second changed file; path not shown in this capture)
@@ -1,17 +1,18 @@
import { FEATURE_EXTRACTOR_NAME } from "../utils/constants.js";
import { Callable } from "../utils/generic.js";
-import { getModelJSON } from "../utils/hub.js";
+import { getModelJSON, PretrainedOptions } from "../utils/hub.js";

/**
* Base class for feature extractors.
*/
export class FeatureExtractor extends Callable {
+config: Object;
/**
* Constructs a new FeatureExtractor instance.
*
* @param {Object} config The configuration for the feature extractor.
*/
-constructor(config) {
+constructor(config: Object) {
super();
this.config = config
}
@@ -27,11 +28,11 @@ export class FeatureExtractor extends Callable {
* Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
* user or organization name, like `dbmdz/bert-base-german-cased`.
* - A path to a *directory* containing feature_extractor files, e.g., `./my_model_directory/`.
-* @param {import('../utils/hub.js').PretrainedOptions} options Additional options for loading the feature_extractor.
+* @param {import('../utils/hub').PretrainedOptions} options Additional options for loading the feature_extractor.
*
* @returns {Promise<FeatureExtractor>} A new instance of the Feature Extractor class.
*/
-static async from_pretrained(pretrained_model_name_or_path, options) {
+static async from_pretrained(pretrained_model_name_or_path: string, options: PretrainedOptions): Promise<FeatureExtractor> {
const config = await getModelJSON(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, true, options);
return new this(config);
}
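A typing sanity check for the new signature; the subclass and model id are illustrative. Because the method constructs `new this(config)`, the runtime instance is the subclass even though the declared return type is the wider base `FeatureExtractor`, so callers needing subclass members would cast:

```ts
class MyFeatureExtractor extends FeatureExtractor {}

// from_pretrained resolves the feature extractor's config JSON (from the
// Hub or a local directory) and calls `new this(config)`.
const fe = await MyFeatureExtractor.from_pretrained('Xenova/whisper-tiny.en', {
    revision: 'main', // assumed PretrainedOptions field
} as PretrainedOptions);
```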
@@ -44,9 +45,10 @@ export class FeatureExtractor extends Callable {
* @param {string} feature_extractor The name of the feature extractor.
* @private
*/
-export function validate_audio_inputs(audio, feature_extractor) {
+export function validate_audio_inputs(audio: Float32Array | Float64Array, feature_extractor: string) {
if (!(audio instanceof Float32Array || audio instanceof Float64Array)) {
throw new Error(
+// @ts-expect-error TS2339
`${feature_extractor} expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead. ` +
`If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.`
)
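The added `// @ts-expect-error TS2339` is needed presumably because inside the failed `instanceof` guard the typed parameter narrows to `never`, so `audio?.constructor?.name` no longer type-checks even though it is exactly what should be printed at runtime. A usage sketch of the guard; the extractor name is illustrative:

```ts
// Passes: raw mono PCM as a Float32Array.
const pcm = new Float32Array(16000); // 1 second of silence at 16 kHz
validate_audio_inputs(pcm, 'WhisperFeatureExtractor');

// Throws with the message above: a plain number[] is not a typed array.
// validate_audio_inputs([0.1, 0.2] as any, 'WhisperFeatureExtractor');
```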