From d819198aa8d8b992b55c6eac55e6e9e13efa72d3 Mon Sep 17 00:00:00 2001 From: Andy Wermke Date: Sat, 20 Jun 2020 11:30:43 +0200 Subject: [PATCH 1/3] Major code change: Introduce Callback() and use it for spawn()/expose(), too --- observable.d.ts | 2 +- observable.js | 2 +- observable.mjs | 2 +- package-lock.json | 6 +- package.json | 2 +- src/common/call-proxy.ts | 296 ++++++++++++++++++++ src/common/callbacks.ts | 72 +++++ src/{ => common}/observable-promise.ts | 0 src/{common.ts => common/serialization.ts} | 16 +- src/{ => common}/transferable.ts | 2 +- src/index.ts | 7 +- src/master/implementation.browser.ts | 7 +- src/master/implementation.node.ts | 2 +- src/master/index.ts | 1 + src/master/invocation-proxy.ts | 152 ---------- src/master/spawn.ts | 19 +- src/serializers.ts | 84 ------ src/serializers/callbacks.ts | 23 ++ src/serializers/errors.ts | 21 ++ src/serializers/index.ts | 58 ++++ src/symbols.ts | 1 + src/types/common.ts | 14 + src/types/master.ts | 14 +- src/types/messages.ts | 81 +++--- src/types/serializers.ts | 40 +++ src/types/worker.ts | 6 +- src/util/events.ts | 35 +++ src/{ => util}/observable.ts | 0 src/{ => util}/promise.ts | 0 src/worker/implementation.browser.ts | 30 +- src/worker/implementation.tiny-worker.ts | 30 +- src/worker/implementation.worker_threads.ts | 35 +-- src/worker/index.ts | 159 ++--------- test/callbacks.test.ts | 16 ++ test/observable-promise.test.ts | 2 +- test/observable.test.ts | 2 +- test/transferables.test.ts | 4 +- test/webpack/app.ts | 2 +- test/webpack/webpack.node.config.js | 2 +- test/workers/map.ts | 7 + test/workers/minmax.ts | 2 +- tsconfig.json | 2 +- 42 files changed, 739 insertions(+), 519 deletions(-) create mode 100644 src/common/call-proxy.ts create mode 100644 src/common/callbacks.ts rename src/{ => common}/observable-promise.ts (100%) rename src/{common.ts => common/serialization.ts} (52%) rename src/{ => common}/transferable.ts (98%) delete mode 100644 src/master/invocation-proxy.ts delete mode 100644 src/serializers.ts create mode 100644 src/serializers/callbacks.ts create mode 100644 src/serializers/errors.ts create mode 100644 src/serializers/index.ts create mode 100644 src/types/common.ts create mode 100644 src/types/serializers.ts create mode 100644 src/util/events.ts rename src/{ => util}/observable.ts (100%) rename src/{ => util}/promise.ts (100%) create mode 100644 test/callbacks.test.ts create mode 100644 test/workers/map.ts diff --git a/observable.d.ts b/observable.d.ts index 32a958a3..d9c71f0c 100644 --- a/observable.d.ts +++ b/observable.d.ts @@ -1 +1 @@ -export * from "./dist/observable" +export * from "./dist/util/observable" diff --git a/observable.js b/observable.js index 14610a92..a0a3196a 100644 --- a/observable.js +++ b/observable.js @@ -1 +1 @@ -module.exports = require("./dist/observable") +module.exports = require("./dist/util/observable") diff --git a/observable.mjs b/observable.mjs index 52dc86fa..ca97cbf5 100644 --- a/observable.mjs +++ b/observable.mjs @@ -1,4 +1,4 @@ -import Observables from "./dist/observable.js" +import Observables from "./dist/util/observable.js" export const Observable = Observables.Observable export const Subject = Observables.Subject diff --git a/package-lock.json b/package-lock.json index cf68d143..b3bd5508 100644 --- a/package-lock.json +++ b/package-lock.json @@ -10926,9 +10926,9 @@ } }, "threads-plugin": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/threads-plugin/-/threads-plugin-1.2.0.tgz", - "integrity": "sha512-sWJbMh7T+59hVr0MgwseNgVLuK53dNVvu4n+1C3/X3zEIkI/rwkwpaNML20fyKvwjTjLwWblwQpx6Gu+VJjBog==", + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/threads-plugin/-/threads-plugin-1.3.2.tgz", + "integrity": "sha512-62UGIGmDA9Vqx0vHhedmtcBIw7N/ByNvLUnU512i9l/Jzfx6XWR7XmbygWX82t7cbXDVr8PY4AHw6x0M3l3YSg==", "dev": true, "requires": { "loader-utils": "^1.1.0" diff --git a/package.json b/package.json index 39789844..6080cfc8 100644 --- a/package.json +++ b/package.json @@ -88,7 +88,7 @@ "rollup": "^1.16.2", "rollup-plugin-commonjs": "^10.0.1", "rollup-plugin-node-resolve": "^5.1.0", - "threads-plugin": "^1.2.0", + "threads-plugin": "^1.3.2", "tiny-worker": "^2.2.0", "ts-loader": "^6.0.1", "ts-node": "^8.10.2", diff --git a/src/common/call-proxy.ts b/src/common/call-proxy.ts new file mode 100644 index 00000000..c1753125 --- /dev/null +++ b/src/common/call-proxy.ts @@ -0,0 +1,296 @@ +/* + * This source file contains the code for proxying calls in the master thread to calls in the workers + * by `.postMessage()`-ing. + * + * Keep in mind that this code can make or break the program's performance! Need to optimize more… + */ + +import { Debugger } from "debug" +import isSomeObservable from "is-observable" +import { multicast, Observable, Subscription } from "observable-fns" +import { MessageRelay } from "../types/common" +import { + ModuleMethods, + ModuleProxy, + ProxyableFunction +} from "../types/master" +import { + CallCancelMessage, + CallErrorMessage, + CallInvocationMessage, + CallResultMessage, + CallRunningMessage, + CommonMessageType +} from "../types/messages" +import { SerializedError, Serializer } from "../types/serializers" +import { lookupLocalCallback, Callback } from "./callbacks" +import { ObservablePromise } from "./observable-promise" +import { isTransferDescriptor } from "./transferable" + +let nextCallID = 1 + +const activeSubscriptions = new Map>() + +const dedupe = (array: T[]): T[] => Array.from(new Set(array)) + +const isCallCancelMessage = (data: any): data is CallCancelMessage => data && data.type === CommonMessageType.cancel +const isCallErrorMessage = (data: any): data is CallErrorMessage => data && data.type === CommonMessageType.error +const isCallResultMessage = (data: any): data is CallResultMessage => data && data.type === CommonMessageType.result +const isCallRunningMessage = (data: any): data is CallRunningMessage => data && data.type === CommonMessageType.running +const isInvocationMessage = (data: any): data is CallInvocationMessage => data && data.type === CommonMessageType.invoke + +function isZenObservable(thing: any): thing is Observable { + return thing && typeof thing === "object" && typeof thing.subscribe === "function" +} + +/** + * There are issues with `is-observable` not recognizing zen-observable's instances. + * We are using `observable-fns`, but it's based on zen-observable, too. + */ +function isObservable(thing: any): thing is Observable { + return isSomeObservable(thing) || isZenObservable(thing) +} + +function deconstructTransfer(thing: any) { + return isTransferDescriptor(thing) + ? { payload: thing.send, transferables: thing.transferables } + : { payload: thing, transferables: undefined } +} + +function postCallError(relay: MessageRelay, uid: number, rawError: SerializedError) { + const { payload: error, transferables } = deconstructTransfer(rawError) + const errorMessage: CallErrorMessage = { + type: CommonMessageType.error, + uid, + error + } + relay.postMessage(errorMessage, transferables) +} + +function postCallResult(relay: MessageRelay, uid: number, completed: boolean, resultValue?: any) { + const { payload, transferables } = deconstructTransfer(resultValue) + const resultMessage: CallResultMessage = { + type: CommonMessageType.result, + uid, + complete: completed ? true : undefined, + payload + } + relay.postMessage(resultMessage, transferables) +} + +function postCallRunning(relay: MessageRelay, uid: number, resultType: CallRunningMessage["resultType"]) { + const startMessage: CallRunningMessage = { + type: CommonMessageType.running, + uid, + resultType + } + relay.postMessage(startMessage) +} + +function createObservableForJob( + relay: MessageRelay, + serializer: Serializer, + callID: number, + debug: Debugger +): Observable { + return new Observable(observer => { + let asyncType: "observable" | "promise" | undefined + + const messageHandler = ((event: MessageEvent) => { + const message = event.data + + if (!message || message.uid !== callID) return + debug(`Received message for running call ${callID}:`, message) + + if (isCallRunningMessage(message)) { + asyncType = message.resultType + } else if (isCallResultMessage(message)) { + if (asyncType === "promise") { + if (typeof message.payload !== "undefined") { + observer.next(serializer.deserialize(message.payload, relay)) + } + observer.complete() + relay.removeEventListener("message", messageHandler) + } else { + if (message.payload) { + observer.next(serializer.deserialize(message.payload, relay)) + } + if (message.complete) { + observer.complete() + relay.removeEventListener("message", messageHandler) + } + } + } else if (isCallErrorMessage(message)) { + const error = serializer.deserialize(message.error as any, relay) + if (asyncType === "promise" || !asyncType) { + observer.error(error) + } else { + observer.error(error) + } + relay.removeEventListener("message", messageHandler) + } + }) as EventListener + + relay.addEventListener("message", messageHandler) + + return () => { + if (asyncType === "observable" || !asyncType) { + const cancelMessage: CallCancelMessage = { + type: CommonMessageType.cancel, + uid: callID + } + relay.postMessage(cancelMessage) + } + relay.removeEventListener("message", messageHandler) + } + }) +} + +function prepareArguments(serializer: Serializer, rawArgs: any[]): { args: any[], transferables: Transferable[] } { + if (rawArgs.length === 0) { + // Exit early if possible + return { + args: [], + transferables: [] + } + } + + const args: any[] = [] + const transferables: Transferable[] = [] + + for (const arg of rawArgs) { + if (isTransferDescriptor(arg)) { + args.push(serializer.serialize(arg.send)) + transferables.push(...arg.transferables) + } else { + args.push(serializer.serialize(arg)) + } + } + + return { + args, + transferables: transferables.length === 0 ? transferables : dedupe(transferables) + } +} + +export function createProxyFunction( + relay: MessageRelay, + serializer: Serializer, + fid: number, + debug: Debugger +) { + return ((...rawArgs: Args) => { + const uid = nextCallID++ + const { args, transferables } = prepareArguments(serializer, rawArgs) + const runMessage: CallInvocationMessage = { + type: CommonMessageType.invoke, + fid, + uid, + args + } + + debug("Sending command to run function to worker:", runMessage) + + try { + relay.postMessage(runMessage, transferables) + } catch (error) { + return ObservablePromise.from(Promise.reject(error)) + } + + return ObservablePromise.from(multicast(createObservableForJob(relay, serializer, uid, debug))) + }) as any as ProxyableFunction +} + +export function createProxyModule( + relay: MessageRelay, + serializer: Serializer, + methods: Record, + debug: Debugger +): ModuleProxy { + const proxy: any = {} + + for (const methodName of Object.keys(methods)) { + proxy[methodName] = createProxyFunction(relay, serializer, methods[methodName], debug) + } + + return proxy +} + +async function invokeExposedLocalFunction( + relay: MessageRelay, + serializer: Serializer, + callback: Callback, + message: CallInvocationMessage +) { + let syncResult: any + const uid = message.uid + + try { + const args = message.args.map(arg => serializer.deserialize(arg, relay)) + syncResult = callback(...args) + } catch (error) { + postCallError(relay, uid, serializer.serialize(error) as any as SerializedError) + } + + const resultType = isObservable(syncResult) ? "observable" : "promise" + postCallRunning(relay, uid, resultType) + + if (isObservable(syncResult)) { + const subscription = syncResult.subscribe( + value => postCallResult(relay, uid, false, serializer.serialize(value)), + error => postCallError(relay, uid, serializer.serialize(error) as any), + () => postCallResult(relay, uid, true) + ) + activeSubscriptions.set(uid, subscription) + } else { + try { + const result = await syncResult + postCallResult(relay, uid, true, serializer.serialize(result)) + } catch (error) { + postCallError(relay, uid, serializer.serialize(error) as any) + } + } +} + +function handleRemoteInvocation( + relay: MessageRelay, + serializer: Serializer, + message: CallInvocationMessage, + debug: Debugger +) { + const callback = lookupLocalCallback(message.fid) + + if (!callback) { + debug(`Call to exposed local function failed: Function not found: UID ${message.uid}`) + return postCallError(relay, message.uid, serializer.serialize(Error(`Function not found: UID ${message.uid}`)) as any as SerializedError) + } + + debug(`Received invocation of local exposed function ${message.fid}, call UID ${message.uid} with arguments:`, message.args) + return invokeExposedLocalFunction(relay, serializer, callback, message) +} + +export function handleFunctionInvocations(relay: MessageRelay, serializer: Serializer, debug: Debugger) { + relay.addEventListener("message", (event: MessageEvent) => { + debug(`Received message:`, event.data) + + if (isInvocationMessage(event.data)) { + handleRemoteInvocation(relay, serializer, event.data, debug) + } + }) +} + +export function handleCallCancellations(relay: MessageRelay, debug: Debugger) { + relay.addEventListener("message", event => { + const messageData = event.data + + if (isCallCancelMessage(messageData)) { + const jobUID = messageData.uid + const subscription = activeSubscriptions.get(jobUID) + + if (subscription) { + subscription.unsubscribe() + activeSubscriptions.delete(jobUID) + } + } + }) +} diff --git a/src/common/callbacks.ts b/src/common/callbacks.ts new file mode 100644 index 00000000..fb68da20 --- /dev/null +++ b/src/common/callbacks.ts @@ -0,0 +1,72 @@ +// tslint:disable max-classes-per-file + +import { $callback } from "../symbols" + +export interface Callback any = (...args: any[]) => any> { + (...args: Parameters): ReturnType + [$callback]: true + id: number + release(): void +} + +let nextLocalCallbackID = 1 +let nextRemoteCallbackID = 1 + +const registeredLocalCallbacks = new Map() +const registeredRemoteCallbacks = new Map() + +export function isCallback(thing: any): thing is Callback { + return typeof thing === "function" && thing[$callback] +} + +export function lookupLocalCallback(id: number): Callback | undefined { + return registeredLocalCallbacks.get(id) +} + +function registerCallback(callback: Callback) { + registeredLocalCallbacks.set(callback.id, callback) + return callback +} + +function unregisterCallback(callback: Callback) { + registeredLocalCallbacks.delete(callback.id) + return callback +} + +export function lookupRemoteCallback(id: number) { + return registeredRemoteCallbacks.get(id) +} + +function registerRemoteCallback any>(callback: Callback): Callback { + registeredRemoteCallbacks.set(callback.id, callback) + return callback +} + +function unregisterRemoteCallback any>(callback: Callback): Callback { + registeredRemoteCallbacks.delete(callback.id) + return callback +} + +export function Callback any>(fn: Fn) { + const callback = ((...args: any[]) => fn(...args)) as Callback + callback[$callback] = true + callback.id = nextLocalCallbackID++ + callback.release = () => unregisterCallback(callback) + return registerCallback(callback) +} + +export function RemoteCallback any>(fn: Fn) { + const callback = ((...args: any[]) => fn(...args)) as Callback + callback[$callback] = true + callback.id = nextRemoteCallbackID++ + callback.release = () => unregisterRemoteCallback(callback) + return registerRemoteCallback(callback) +} + +export function SingleExposedCallback any>(fn: Fn) { + const callback = ((...args: any[]) => fn(...args)) as Callback + callback[$callback] = true + callback.id = 0 + callback.release = () => unregisterCallback(callback) + return registerCallback(callback) +} diff --git a/src/observable-promise.ts b/src/common/observable-promise.ts similarity index 100% rename from src/observable-promise.ts rename to src/common/observable-promise.ts diff --git a/src/common.ts b/src/common/serialization.ts similarity index 52% rename from src/common.ts rename to src/common/serialization.ts index c44d354a..42b3ddd9 100644 --- a/src/common.ts +++ b/src/common/serialization.ts @@ -1,19 +1,23 @@ +import { extendSerializer, DefaultSerializer } from "../serializers/index" +import { MessageRelay } from "../types/common" import { - extendSerializer, - DefaultSerializer, JsonSerializable, Serializer, SerializerImplementation -} from "./serializers" +} from "../types/serializers" -let registeredSerializer: Serializer = DefaultSerializer +let registeredSerializer: Serializer = DefaultSerializer() + +export function getRegisteredSerializer() { + return registeredSerializer +} export function registerSerializer(serializer: SerializerImplementation) { registeredSerializer = extendSerializer(registeredSerializer, serializer) } -export function deserialize(message: JsonSerializable): any { - return registeredSerializer.deserialize(message) +export function deserialize(message: JsonSerializable, origin: MessageRelay): any { + return registeredSerializer.deserialize(message, origin) } export function serialize(input: any): JsonSerializable { diff --git a/src/transferable.ts b/src/common/transferable.ts similarity index 98% rename from src/transferable.ts rename to src/common/transferable.ts index c83e85a4..e5bec321 100644 --- a/src/transferable.ts +++ b/src/common/transferable.ts @@ -1,4 +1,4 @@ -import { $transferable } from "./symbols" +import { $transferable } from "../symbols" export interface TransferDescriptor { [$transferable]: true diff --git a/src/index.ts b/src/index.ts index 8daf5286..faa844b3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,5 +1,6 @@ -export { registerSerializer } from "./common" +export { Callback } from "./common/callbacks" +export { registerSerializer } from "./common/serialization" export * from "./master/index" export { expose } from "./worker/index" -export { DefaultSerializer, JsonSerializable, Serializer, SerializerImplementation } from "./serializers" -export { Transfer, TransferDescriptor } from "./transferable" +export { DefaultSerializer, JsonSerializable, Serializer, SerializerImplementation } from "./serializers/index" +export { Transfer, TransferDescriptor } from "./common/transferable" diff --git a/src/master/implementation.browser.ts b/src/master/implementation.browser.ts index b2d898ee..4581b238 100644 --- a/src/master/implementation.browser.ts +++ b/src/master/implementation.browser.ts @@ -1,6 +1,7 @@ // tslint:disable max-classes-per-file -import { ImplementationExport, ThreadsWorkerOptions } from "../types/master" +import { ImplementationExport, ThreadsWorkerOptions, WorkerImplementation } from "../types/master" +import { multiplexEventTarget } from "../util/events" import { getBundleURL } from "./get-bundle-url.browser" export const defaultPoolSize = typeof navigator !== "undefined" && navigator.hardwareConcurrency @@ -28,7 +29,7 @@ function selectWorkerImplementation(): ImplementationExport { } as any } - class WebWorker extends Worker { + class WebWorker extends Worker implements WorkerImplementation { constructor(url: string | URL, options?: ThreadsWorkerOptions) { if (typeof url === "string" && options && options._baseURL) { url = new URL(url, options._baseURL) @@ -45,7 +46,7 @@ function selectWorkerImplementation(): ImplementationExport { } } - class BlobWorker extends WebWorker { + class BlobWorker extends WebWorker implements WorkerImplementation { constructor(blob: Blob, options?: ThreadsWorkerOptions) { const url = window.URL.createObjectURL(blob) super(url, options) diff --git a/src/master/implementation.node.ts b/src/master/implementation.node.ts index a8f5f517..1de86b5a 100644 --- a/src/master/implementation.node.ts +++ b/src/master/implementation.node.ts @@ -198,7 +198,7 @@ function initTinyWorker(): ImplementationExport { this.emitter = new EventEmitter() this.onerror = (error: Error) => this.emitter.emit("error", error) - this.onmessage = (message: MessageEvent) => this.emitter.emit("message", message) + this.onmessage = (event: MessageEvent) => this.emitter.emit("message", event) } public addEventListener(eventName: WorkerEventName, listener: EventListener) { diff --git a/src/master/index.ts b/src/master/index.ts index ed1b2da1..71c403d6 100644 --- a/src/master/index.ts +++ b/src/master/index.ts @@ -3,6 +3,7 @@ import type { BlobWorker as BlobWorkerClass } from "../types/master" import { Worker as WorkerType } from "../types/master" import { getWorkerImplementation, isWorkerRuntime } from "./implementation" +export { Callback } from "../common/callbacks" export { FunctionThread, ModuleThread } from "../types/master" export { Pool } from "./pool" export { spawn } from "./spawn" diff --git a/src/master/invocation-proxy.ts b/src/master/invocation-proxy.ts deleted file mode 100644 index 231eb1d2..00000000 --- a/src/master/invocation-proxy.ts +++ /dev/null @@ -1,152 +0,0 @@ -/* - * This source file contains the code for proxying calls in the master thread to calls in the workers - * by `.postMessage()`-ing. - * - * Keep in mind that this code can make or break the program's performance! Need to optimize more… - */ - -import DebugLogger from "debug" -import { multicast, Observable } from "observable-fns" -import { deserialize, serialize } from "../common" -import { ObservablePromise } from "../observable-promise" -import { isTransferDescriptor } from "../transferable" -import { - ModuleMethods, - ModuleProxy, - ProxyableFunction, - Worker as WorkerType -} from "../types/master" -import { - MasterJobCancelMessage, - MasterJobRunMessage, - MasterMessageType, - WorkerJobErrorMessage, - WorkerJobResultMessage, - WorkerJobStartMessage, - WorkerMessageType -} from "../types/messages" - -const debugMessages = DebugLogger("threads:master:messages") - -let nextJobUID = 1 - -const dedupe = (array: T[]): T[] => Array.from(new Set(array)) - -const isJobErrorMessage = (data: any): data is WorkerJobErrorMessage => data && data.type === WorkerMessageType.error -const isJobResultMessage = (data: any): data is WorkerJobResultMessage => data && data.type === WorkerMessageType.result -const isJobStartMessage = (data: any): data is WorkerJobStartMessage => data && data.type === WorkerMessageType.running - -function createObservableForJob(worker: WorkerType, jobUID: number): Observable { - return new Observable(observer => { - let asyncType: "observable" | "promise" | undefined - - const messageHandler = ((event: MessageEvent) => { - debugMessages("Message from worker:", event.data) - if (!event.data || event.data.uid !== jobUID) return - - if (isJobStartMessage(event.data)) { - asyncType = event.data.resultType - } else if (isJobResultMessage(event.data)) { - if (asyncType === "promise") { - if (typeof event.data.payload !== "undefined") { - observer.next(deserialize(event.data.payload)) - } - observer.complete() - worker.removeEventListener("message", messageHandler) - } else { - if (event.data.payload) { - observer.next(deserialize(event.data.payload)) - } - if (event.data.complete) { - observer.complete() - worker.removeEventListener("message", messageHandler) - } - } - } else if (isJobErrorMessage(event.data)) { - const error = deserialize(event.data.error as any) - if (asyncType === "promise" || !asyncType) { - observer.error(error) - } else { - observer.error(error) - } - worker.removeEventListener("message", messageHandler) - } - }) as EventListener - - worker.addEventListener("message", messageHandler) - - return () => { - if (asyncType === "observable" || !asyncType) { - const cancelMessage: MasterJobCancelMessage = { - type: MasterMessageType.cancel, - uid: jobUID - } - worker.postMessage(cancelMessage) - } - worker.removeEventListener("message", messageHandler) - } - }) -} - -function prepareArguments(rawArgs: any[]): { args: any[], transferables: Transferable[] } { - if (rawArgs.length === 0) { - // Exit early if possible - return { - args: [], - transferables: [] - } - } - - const args: any[] = [] - const transferables: Transferable[] = [] - - for (const arg of rawArgs) { - if (isTransferDescriptor(arg)) { - args.push(serialize(arg.send)) - transferables.push(...arg.transferables) - } else { - args.push(serialize(arg)) - } - } - - return { - args, - transferables: transferables.length === 0 ? transferables : dedupe(transferables) - } -} - -export function createProxyFunction(worker: WorkerType, method?: string) { - return ((...rawArgs: Args) => { - const uid = nextJobUID++ - const { args, transferables } = prepareArguments(rawArgs) - const runMessage: MasterJobRunMessage = { - type: MasterMessageType.run, - uid, - method, - args - } - - debugMessages("Sending command to run function to worker:", runMessage) - - try { - worker.postMessage(runMessage, transferables) - } catch (error) { - return ObservablePromise.from(Promise.reject(error)) - } - - return ObservablePromise.from(multicast(createObservableForJob(worker, uid))) - }) as any as ProxyableFunction -} - -export function createProxyModule( - worker: WorkerType, - methodNames: string[] -): ModuleProxy { - const proxy: any = {} - - for (const methodName of methodNames) { - proxy[methodName] = createProxyFunction(worker, methodName) - } - - return proxy -} diff --git a/src/master/spawn.ts b/src/master/spawn.ts index d9ef3cbe..78482f84 100644 --- a/src/master/spawn.ts +++ b/src/master/spawn.ts @@ -1,8 +1,9 @@ import DebugLogger from "debug" import { Observable } from "observable-fns" -import { deserialize } from "../common" -import { createPromiseWithResolver } from "../promise" +import { createProxyFunction, createProxyModule, handleFunctionInvocations } from "../common/call-proxy" +import { deserialize, getRegisteredSerializer } from "../common/serialization" import { $errors, $events, $terminate, $worker } from "../symbols" +import { MessageRelay } from "../types/common" import { FunctionThread, ModuleThread, @@ -17,7 +18,7 @@ import { } from "../types/master" import { WorkerInitMessage, WorkerUncaughtErrorMessage } from "../types/messages" import { WorkerFunction, WorkerModule } from "../types/worker" -import { createProxyFunction, createProxyModule } from "./invocation-proxy" +import { createPromiseWithResolver } from "../util/promise" type ArbitraryWorkerInterface = WorkerFunction & WorkerModule & { somekeythatisneverusedinproductioncode123: "magicmarker123" } type ArbitraryThreadType = FunctionThread & ModuleThread @@ -31,7 +32,6 @@ type ExposedToThreadType> = ? ModuleThread : never - const debugMessages = DebugLogger("threads:master:messages") const debugSpawn = DebugLogger("threads:master:spawn") const debugThreadUtils = DebugLogger("threads:master:thread-utils") @@ -58,7 +58,7 @@ async function withTimeout(promise: Promise, timeoutInMs: number, errorMes return result } -function receiveInitMessage(worker: WorkerType): Promise { +function receiveInitMessage(worker: MessageRelay): Promise { return new Promise((resolve, reject) => { const messageHandler = ((event: MessageEvent) => { debugMessages("Message from worker before finishing initialization:", event.data) @@ -67,7 +67,7 @@ function receiveInitMessage(worker: WorkerType): Promise { resolve(event.data) } else if (isUncaughtErrorMessage(event.data)) { worker.removeEventListener("message", messageHandler) - reject(deserialize(event.data.error)) + reject(deserialize(event.data.error, worker)) } }) as EventListener worker.addEventListener("message", messageHandler) @@ -151,12 +151,15 @@ export async function spawn = const { termination, terminate } = createTerminator(worker) const events = createEventObservable(worker, termination) + const serializer = getRegisteredSerializer() if (exposed.type === "function") { - const proxy = createProxyFunction(worker) + const proxy = createProxyFunction(worker, serializer, 0, debugMessages) + handleFunctionInvocations(worker, serializer, debugMessages) return setPrivateThreadProps(proxy, worker, events, terminate) as ExposedToThreadType } else if (exposed.type === "module") { - const proxy = createProxyModule(worker, exposed.methods) + const proxy = createProxyModule(worker, serializer, exposed.methods, debugMessages) + handleFunctionInvocations(worker, serializer, debugMessages) return setPrivateThreadProps(proxy, worker, events, terminate) as ExposedToThreadType } else { const type = (exposed as WorkerInitMessage["exposed"]).type diff --git a/src/serializers.ts b/src/serializers.ts deleted file mode 100644 index ea78b923..00000000 --- a/src/serializers.ts +++ /dev/null @@ -1,84 +0,0 @@ -import { SerializedError } from "./types/messages" - -export interface Serializer { - deserialize(message: Msg): Input - serialize(input: Input): Msg -} - -export interface SerializerImplementation { - deserialize(message: Msg, defaultDeserialize: ((msg: Msg) => Input)): Input - serialize(input: Input, defaultSerialize: ((inp: Input) => Msg)): Msg -} - -export function extendSerializer( - extend: Serializer, - implementation: SerializerImplementation -): Serializer { - const fallbackDeserializer = extend.deserialize.bind(extend) - const fallbackSerializer = extend.serialize.bind(extend) - - return { - deserialize(message: MessageType): InputType { - return implementation.deserialize(message, fallbackDeserializer) - }, - - serialize(input: InputType): MessageType { - return implementation.serialize(input, fallbackSerializer) - } - } -} - -type JsonSerializablePrimitive = string | number | boolean | null - -type JsonSerializableObject = { - [key: string]: - | JsonSerializablePrimitive - | JsonSerializablePrimitive[] - | JsonSerializableObject - | JsonSerializableObject[] - | undefined -} - -export type JsonSerializable = - | JsonSerializablePrimitive - | JsonSerializablePrimitive[] - | JsonSerializableObject - | JsonSerializableObject[] - - -const DefaultErrorSerializer: Serializer = { - deserialize(message: SerializedError): Error { - return Object.assign(Error(message.message), { - name: message.name, - stack: message.stack - }) - }, - serialize(error: Error): SerializedError { - return { - __error_marker: "$$error", - message: error.message, - name: error.name, - stack: error.stack - } - } -} - -const isSerializedError = (thing: any): thing is SerializedError => - thing && typeof thing === "object" && "__error_marker" in thing && thing.__error_marker === "$$error" - -export const DefaultSerializer: Serializer = { - deserialize(message: JsonSerializable): any { - if (isSerializedError(message)) { - return DefaultErrorSerializer.deserialize(message) - } else { - return message - } - }, - serialize(input: any): JsonSerializable { - if (input instanceof Error) { - return DefaultErrorSerializer.serialize(input) as any as JsonSerializable - } else { - return input - } - } -} diff --git a/src/serializers/callbacks.ts b/src/serializers/callbacks.ts new file mode 100644 index 00000000..4d4bd391 --- /dev/null +++ b/src/serializers/callbacks.ts @@ -0,0 +1,23 @@ +import DebugLogger from "debug" +import { createProxyFunction } from "../common/call-proxy" +import { Callback, RemoteCallback } from "../common/callbacks" +import { MessageRelay } from "../types/common" +import { SerializedCallback, Serializer } from "../types/serializers" + +const debug = DebugLogger("threads:callback:messages") + +export const DefaultCallbackSerializer = (rootSerializer: Serializer): Serializer> => ({ + deserialize(message: SerializedCallback, origin: MessageRelay): Callback { + const proxy = createProxyFunction(origin, rootSerializer, message.fid, debug) + return RemoteCallback(proxy) + }, + serialize(callback: Callback): SerializedCallback { + return { + __callback_marker: "$$callback", + fid: callback.id + } + } +}) + +export const isSerializedCallback = (thing: any): thing is SerializedCallback => + thing && typeof thing === "object" && "__callback_marker" in thing && thing.__callback_marker === "$$callback" diff --git a/src/serializers/errors.ts b/src/serializers/errors.ts new file mode 100644 index 00000000..db6187c9 --- /dev/null +++ b/src/serializers/errors.ts @@ -0,0 +1,21 @@ +import { SerializedError, Serializer } from "../types/serializers" + +export const DefaultErrorSerializer = (): Serializer => ({ + deserialize(message: SerializedError): Error { + return Object.assign(Error(message.message), { + name: message.name, + stack: message.stack + }) + }, + serialize(error: Error): SerializedError { + return { + __error_marker: "$$error", + message: error.message, + name: error.name, + stack: error.stack + } + } +}) + +export const isSerializedError = (thing: any): thing is SerializedError => + thing && typeof thing === "object" && "__error_marker" in thing && thing.__error_marker === "$$error" diff --git a/src/serializers/index.ts b/src/serializers/index.ts new file mode 100644 index 00000000..e5999582 --- /dev/null +++ b/src/serializers/index.ts @@ -0,0 +1,58 @@ +import { isCallback } from "../common/callbacks" +import { MessageRelay } from "../types/common" +import { JsonSerializable, Serializer, SerializerImplementation } from "../types/serializers" +import { isSerializedCallback, DefaultCallbackSerializer } from "./callbacks" +import { isSerializedError, DefaultErrorSerializer } from "./errors" + +export { + JsonSerializable, + Serializer, + SerializerImplementation +} + +export function extendSerializer( + extend: Serializer, + implementation: SerializerImplementation +): Serializer { + const fallbackSerializer = extend.serialize.bind(extend) + + return { + deserialize(message: MessageType, origin: MessageRelay): InputType { + const fallback = (msg: MessageType) => extend.deserialize(msg, origin) + return implementation.deserialize(message, fallback) + }, + + serialize(input: InputType): MessageType { + return implementation.serialize(input, fallbackSerializer) + } + } +} + + +export const DefaultSerializer = (): Serializer => { + const serializer: Serializer = { + deserialize(message: JsonSerializable, sender: MessageRelay | null): any { + if (isSerializedError(message)) { + return errorSerializer.deserialize(message, sender) + } else if (isSerializedCallback(message)) { + return callbackSerializer.deserialize(message, sender) + } else { + return message + } + }, + serialize(input: any): JsonSerializable { + if (input instanceof Error) { + return errorSerializer.serialize(input) as any as JsonSerializable + } else if (isCallback(input)) { + return callbackSerializer.serialize(input) as any as JsonSerializable + } else { + return input + } + } + } + + const callbackSerializer = DefaultCallbackSerializer(serializer) + const errorSerializer = DefaultErrorSerializer() + + return serializer +} diff --git a/src/symbols.ts b/src/symbols.ts index b53f1fd9..c0d3217e 100644 --- a/src/symbols.ts +++ b/src/symbols.ts @@ -1,3 +1,4 @@ +export const $callback = Symbol("thread.callback") export const $errors = Symbol("thread.errors") export const $events = Symbol("thread.events") export const $terminate = Symbol("thread.terminate") diff --git a/src/types/common.ts b/src/types/common.ts new file mode 100644 index 00000000..9520afc8 --- /dev/null +++ b/src/types/common.ts @@ -0,0 +1,14 @@ +export type TransferList = Transferable[] + +/** + * A thing than send and receive messages, usually a worker. + * Pretty much identical to a `MessagePort`, but with less strict types, + * so it really is an interface implemented by Worker, MessagePort, self. + */ +export interface MessageRelay { + addEventListener(event: "error", handler: (event: ErrorEvent) => void): any + addEventListener(event: "message", handler: (event: MessageEvent) => void): any + addEventListener(event: string, handler: EventListener): any + postMessage(value: any, transferList?: TransferList): void + removeEventListener(event: "error" | "message", handler: (arg: any) => any): any +} diff --git a/src/types/master.ts b/src/types/master.ts index d7b6a8d3..166b78f9 100644 --- a/src/types/master.ts +++ b/src/types/master.ts @@ -4,12 +4,14 @@ // Cannot use `compilerOptions.esModuleInterop` and default import syntax // See import { Observable } from "observable-fns" -import { ObservablePromise } from "../observable-promise" +import { ObservablePromise } from "../common/observable-promise" import { $errors, $events, $terminate, $worker } from "../symbols" +import { MessageRelay, TransferList } from "./common" interface ObservableLikeSubscription { unsubscribe(): any } + interface ObservableLike { subscribe(onNext: (value: T) => any, onError?: (error: any) => any, onComplete?: () => any): ObservableLikeSubscription subscribe(listeners: { @@ -61,11 +63,9 @@ interface AnyModuleThread extends PrivateThreadProps { /** Worker thread. Either a `FunctionThread` or a `ModuleThread`. */ export type Thread = AnyFunctionThread | AnyModuleThread -export type TransferList = Transferable[] - /** Worker instance. Either a web worker or a node.js Worker provided by `worker_threads` or `tiny-worker`. */ -export interface Worker extends EventTarget { - postMessage(value: any, transferList?: TransferList): void +export interface Worker extends MessageRelay { + removeEventListener(event: string, listener: EventListener): any terminate(callback?: (error?: Error, exitCode?: number) => void): void } @@ -81,12 +81,16 @@ export interface ThreadsWorkerOptions extends WorkerOptions { /** The size of a pre-allocated memory range used for generated code. */ codeRangeSizeMb?: number; } + timeout?: number } /** Worker implementation. Either web worker or a node.js Worker class. */ export declare class WorkerImplementation extends EventTarget implements Worker { constructor(path: string, options?: ThreadsWorkerOptions) + // Quick fix to use the more precise `addEventListener()` signatures: + public addEventListener: MessageRelay["addEventListener"] public postMessage(value: any, transferList?: TransferList): void + public removeEventListener(event: string, listener: EventListener): any public terminate(): void } diff --git a/src/types/messages.ts b/src/types/messages.ts index 38411b83..1489b0af 100644 --- a/src/types/messages.ts +++ b/src/types/messages.ts @@ -1,40 +1,51 @@ -export interface SerializedError { - __error_marker: "$$error" - message: string - name: string - stack?: string -} - -///////////////////////////// -// Messages sent by master: +import { SerializedError } from "./serializers" -export enum MasterMessageType { - cancel = "cancel", - run = "run" +export enum CommonMessageType { + cancel = "call:cancel", + error = "call:error", + invoke = "call:invoke", + result = "call:result", + running = "call:running" } -export type MasterJobCancelMessage = { - type: MasterMessageType.cancel, +export type CallCancelMessage = { + type: CommonMessageType.cancel, uid: number } -export type MasterJobRunMessage = { - type: MasterMessageType.run, +export type CallErrorMessage = { + type: CommonMessageType.error, + uid: number, + error: SerializedError +} + +export type CallInvocationMessage = { + type: CommonMessageType.invoke, + /** Function ID */ + fid: number, + /** Unique call ID */ uid: number, - method?: string, args: any[] } -export type MasterSentMessage = MasterJobCancelMessage | MasterJobRunMessage +export type CallResultMessage = { + type: CommonMessageType.result, + uid: number, + complete?: true, + payload?: any +} + +export type CallRunningMessage = { + type: CommonMessageType.running, + uid: number, + resultType: "observable" | "promise" +} //////////////////////////// // Messages sent by worker: export enum WorkerMessageType { - error = "error", init = "init", - result = "result", - running = "running", uncaughtError = "uncaughtError" } @@ -49,31 +60,5 @@ export type WorkerUncaughtErrorMessage = { export type WorkerInitMessage = { type: WorkerMessageType.init, - exposed: { type: "function" } | { type: "module", methods: string[] } + exposed: { type: "function" } | { type: "module", methods: Record } } - -export type WorkerJobErrorMessage = { - type: WorkerMessageType.error, - uid: number, - error: SerializedError -} - -export type WorkerJobResultMessage = { - type: WorkerMessageType.result, - uid: number, - complete?: true, - payload?: any -} - -export type WorkerJobStartMessage = { - type: WorkerMessageType.running, - uid: number, - resultType: "observable" | "promise" -} - -export type WorkerSentMessage = - | WorkerInitMessage - | WorkerJobErrorMessage - | WorkerJobResultMessage - | WorkerJobStartMessage - | WorkerUncaughtErrorMessage diff --git a/src/types/serializers.ts b/src/types/serializers.ts new file mode 100644 index 00000000..528e3938 --- /dev/null +++ b/src/types/serializers.ts @@ -0,0 +1,40 @@ +import { MessageRelay } from "./common" + +type JsonSerializablePrimitive = string | number | boolean | null + +type JsonSerializableObject = { + [key: string]: + | JsonSerializablePrimitive + | JsonSerializablePrimitive[] + | JsonSerializableObject + | JsonSerializableObject[] + | undefined +} + +export type JsonSerializable = + | JsonSerializablePrimitive + | JsonSerializablePrimitive[] + | JsonSerializableObject + | JsonSerializableObject[] + +export interface Serializer { + deserialize(message: Msg, sender: MessageRelay | null): Input + serialize(input: Input): Msg +} + +export interface SerializerImplementation { + deserialize(message: Msg, defaultDeserialize: ((msg: Msg) => Input)): Input + serialize(input: Input, defaultSerialize: ((inp: Input) => Msg)): Msg +} + +export interface SerializedCallback { + __callback_marker: "$$callback" + fid: number +} + +export interface SerializedError { + __error_marker: "$$error" + message: string + name: string + stack?: string +} diff --git a/src/types/worker.ts b/src/types/worker.ts index 3e1790e4..22251469 100644 --- a/src/types/worker.ts +++ b/src/types/worker.ts @@ -1,9 +1,7 @@ -type UnsubscribeFn = () => void +import { MessageRelay } from "./common" -export interface AbstractedWorkerAPI { +export interface AbstractedWorkerAPI extends MessageRelay { isWorkerRuntime(): boolean - postMessageToMaster(message: any, transferList?: Transferable[]): void - subscribeToMasterMessages(onMessage: (data: any) => void): UnsubscribeFn } export type WorkerFunction = ((...args: any[]) => any) | (() => any) diff --git a/src/util/events.ts b/src/util/events.ts new file mode 100644 index 00000000..400a7f08 --- /dev/null +++ b/src/util/events.ts @@ -0,0 +1,35 @@ +/** + * Make sure that there is only ever one listener set on that event emitter. + * Do so by setting a single event handler that then calls all the + * event listeners. + */ +export function multiplexEventTarget(emitter: Pick) { + const eventListeners = new Map any>>() + + function addEventListener(event: string, listener: EventListener) { + if (eventListeners.has(event)) { + eventListeners.get(event)!.add(listener) + } else { + eventListeners.set(event, new Set([listener])) + + emitter.addEventListener(event, (...args: any[]) => { + const listeners = eventListeners.get(event) || [] + + for (const callback of listeners) { + callback(...args) + } + }) + } + } + + function removeEventListener(event: string, listener: EventListener) { + if (eventListeners.has(event)) { + eventListeners.get(event)!.delete(listener) + } + } + + return { + addEventListener, + removeEventListener + } +} diff --git a/src/observable.ts b/src/util/observable.ts similarity index 100% rename from src/observable.ts rename to src/util/observable.ts diff --git a/src/promise.ts b/src/util/promise.ts similarity index 100% rename from src/promise.ts rename to src/util/promise.ts diff --git a/src/worker/implementation.browser.ts b/src/worker/implementation.browser.ts index 1988ad9c..b3ea645c 100644 --- a/src/worker/implementation.browser.ts +++ b/src/worker/implementation.browser.ts @@ -1,7 +1,9 @@ /// // tslint:disable no-shadowed-variable +import { MessageRelay } from "../types/common" import { AbstractedWorkerAPI } from "../types/worker" +import { multiplexEventTarget } from "../util/events" interface WorkerGlobalScope { addEventListener(eventName: string, listener: (event: Event) => void): void @@ -16,23 +18,23 @@ const isWorkerRuntime: AbstractedWorkerAPI["isWorkerRuntime"] = function isWorke return typeof self !== "undefined" && self.postMessage && !isWindowContext ? true : false } -const postMessageToMaster: AbstractedWorkerAPI["postMessageToMaster"] = function postMessageToMaster(data, transferList?) { +const postMessage: AbstractedWorkerAPI["postMessage"] = function postMessageToMaster(data, transferList?) { self.postMessage(data, transferList) } -const subscribeToMasterMessages: AbstractedWorkerAPI["subscribeToMasterMessages"] = function subscribeToMasterMessages(onMessage) { - const messageHandler = (messageEvent: MessageEvent) => { - onMessage(messageEvent.data) - } - const unsubscribe = () => { - self.removeEventListener("message", messageHandler as EventListener) - } - self.addEventListener("message", messageHandler as EventListener) - return unsubscribe -} +let muxedSelfEvents: Pick | undefined -export default { +const Implementation: AbstractedWorkerAPI = { + addEventListener(event: string, handler: (message: any) => any) { + muxedSelfEvents = muxedSelfEvents || multiplexEventTarget(self) + return muxedSelfEvents.addEventListener(event, handler) + }, + removeEventListener(event: string, handler: (message: any) => any) { + muxedSelfEvents = muxedSelfEvents || multiplexEventTarget(self) + return muxedSelfEvents.removeEventListener(event, handler) + }, isWorkerRuntime, - postMessageToMaster, - subscribeToMasterMessages + postMessage } + +export default Implementation diff --git a/src/worker/implementation.tiny-worker.ts b/src/worker/implementation.tiny-worker.ts index 97c18b11..8c2759a7 100644 --- a/src/worker/implementation.tiny-worker.ts +++ b/src/worker/implementation.tiny-worker.ts @@ -2,6 +2,7 @@ // tslint:disable no-shadowed-variable import { AbstractedWorkerAPI } from "../types/worker" +import { multiplexEventTarget } from "../util/events" interface WorkerGlobalScope { addEventListener(eventName: string, listener: (event: Event) => void): void @@ -19,32 +20,15 @@ const isWorkerRuntime: AbstractedWorkerAPI["isWorkerRuntime"] = function isWorke return typeof self !== "undefined" && self.postMessage ? true : false } -const postMessageToMaster: AbstractedWorkerAPI["postMessageToMaster"] = function postMessageToMaster(data) { +const postMessage: AbstractedWorkerAPI["postMessage"] = function postMessage(data) { // TODO: Warn that Transferables are not supported on first attempt to use feature self.postMessage(data) } -let muxingHandlerSetUp = false -const messageHandlers = new Set<(data: any) => void>() - -const subscribeToMasterMessages: AbstractedWorkerAPI["subscribeToMasterMessages"] = function subscribeToMasterMessages(onMessage) { - if (!muxingHandlerSetUp) { - // We have one multiplexing message handler as tiny-worker's - // addEventListener() only allows you to set a single message handler - self.addEventListener("message", ((event: MessageEvent) => { - messageHandlers.forEach(handler => handler(event.data)) - }) as EventListener) - muxingHandlerSetUp = true - } - - messageHandlers.add(onMessage) - - const unsubscribe = () => messageHandlers.delete(onMessage) - return unsubscribe -} - -export default { +const Implementation: AbstractedWorkerAPI = { + ...(multiplexEventTarget(self) as Pick), isWorkerRuntime, - postMessageToMaster, - subscribeToMasterMessages + postMessage } + +export default Implementation diff --git a/src/worker/implementation.worker_threads.ts b/src/worker/implementation.worker_threads.ts index 8bf79510..4b3a4d62 100644 --- a/src/worker/implementation.worker_threads.ts +++ b/src/worker/implementation.worker_threads.ts @@ -14,34 +14,29 @@ const isWorkerRuntime: AbstractedWorkerAPI["isWorkerRuntime"] = function isWorke return !WorkerThreads().isMainThread } -const postMessageToMaster: AbstractedWorkerAPI["postMessageToMaster"] = function postMessageToMaster(data, transferList) { +const postMessage: AbstractedWorkerAPI["postMessage"] = function postMessage(data, transferList) { assertMessagePort(WorkerThreads().parentPort).postMessage(data, transferList as any) } -const subscribeToMasterMessages: AbstractedWorkerAPI["subscribeToMasterMessages"] = function subscribeToMasterMessages(onMessage) { - const parentPort = WorkerThreads().parentPort - - if (!parentPort) { - throw Error("Invariant violation: MessagePort to parent is not available.") - } - const messageHandler = (message: any) => { - onMessage(message) - } - const unsubscribe = () => { - assertMessagePort(parentPort).off("message", messageHandler) - } - assertMessagePort(parentPort).on("message", messageHandler) - return unsubscribe -} - function testImplementation() { // Will throw if `worker_threads` are not available WorkerThreads() } -export default { +const Implementation: AbstractedWorkerAPI & { testImplementation: typeof testImplementation } = { + addEventListener(event: string, listener: (arg: any) => any) { + const port = assertMessagePort(WorkerThreads().parentPort) + return event === "message" + ? port.on(event, (data) => listener({ data })) + : port.on(event, listener) + }, + removeEventListener(event, listener) { + const port = assertMessagePort(WorkerThreads().parentPort) + return port.off(event, listener) + }, isWorkerRuntime, - postMessageToMaster, - subscribeToMasterMessages, + postMessage, testImplementation } + +export default Implementation diff --git a/src/worker/index.ts b/src/worker/index.ts index 79bbf0be..fa010f02 100644 --- a/src/worker/index.ts +++ b/src/worker/index.ts @@ -1,50 +1,22 @@ -import isSomeObservable from "is-observable" -import { Observable, Subscription } from "observable-fns" -import { deserialize, serialize } from "../common" -import { isTransferDescriptor, TransferDescriptor } from "../transferable" -import { - MasterJobCancelMessage, - MasterJobRunMessage, - MasterMessageType, - SerializedError, - WorkerInitMessage, - WorkerJobErrorMessage, - WorkerJobResultMessage, - WorkerJobStartMessage, - WorkerMessageType, - WorkerUncaughtErrorMessage -} from "../types/messages" +import DebugLogger from "debug" +import { handleCallCancellations, handleFunctionInvocations } from "../common/call-proxy" +import { Callback, SingleExposedCallback } from "../common/callbacks" +import { getRegisteredSerializer, serialize } from "../common/serialization" +import { WorkerInitMessage, WorkerMessageType, WorkerUncaughtErrorMessage } from "../types/messages" +import { SerializedError } from "../types/serializers" import { WorkerFunction, WorkerModule } from "../types/worker" import Implementation from "./implementation" -export { registerSerializer } from "../common" -export { Transfer } from "../transferable" +export { Callback } from "../common/callbacks" +export { registerSerializer } from "../common/serialization" +export { Transfer } from "../common/transferable" /** Returns `true` if this code is currently running in a worker. */ export const isWorkerRuntime = Implementation.isWorkerRuntime let exposeCalled = false -const activeSubscriptions = new Map>() - -const isMasterJobCancelMessage = (thing: any): thing is MasterJobCancelMessage => thing && thing.type === MasterMessageType.cancel -const isMasterJobRunMessage = (thing: any): thing is MasterJobRunMessage => thing && thing.type === MasterMessageType.run - -/** - * There are issues with `is-observable` not recognizing zen-observable's instances. - * We are using `observable-fns`, but it's based on zen-observable, too. - */ -const isObservable = (thing: any): thing is Observable => isSomeObservable(thing) || isZenObservable(thing) - -function isZenObservable(thing: any): thing is Observable { - return thing && typeof thing === "object" && typeof thing.subscribe === "function" -} - -function deconstructTransfer(thing: any) { - return isTransferDescriptor(thing) - ? { payload: thing.send, transferables: thing.transferables } - : { payload: thing, transferables: undefined } -} +const debugIncomingMessages = DebugLogger("threads:worker:messages") function postFunctionInitMessage() { const initMessage: WorkerInitMessage = { @@ -53,48 +25,18 @@ function postFunctionInitMessage() { type: "function" } } - Implementation.postMessageToMaster(initMessage) + Implementation.postMessage(initMessage) } -function postModuleInitMessage(methodNames: string[]) { +function postModuleInitMessage(methods: Record) { const initMessage: WorkerInitMessage = { type: WorkerMessageType.init, exposed: { type: "module", - methods: methodNames + methods } } - Implementation.postMessageToMaster(initMessage) -} - -function postJobErrorMessage(uid: number, rawError: Error | TransferDescriptor) { - const { payload: error, transferables } = deconstructTransfer(rawError) - const errorMessage: WorkerJobErrorMessage = { - type: WorkerMessageType.error, - uid, - error: serialize(error) as any as SerializedError - } - Implementation.postMessageToMaster(errorMessage, transferables) -} - -function postJobResultMessage(uid: number, completed: boolean, resultValue?: any) { - const { payload, transferables } = deconstructTransfer(resultValue) - const resultMessage: WorkerJobResultMessage = { - type: WorkerMessageType.result, - uid, - complete: completed ? true : undefined, - payload - } - Implementation.postMessageToMaster(resultMessage, transferables) -} - -function postJobStartMessage(uid: number, resultType: WorkerJobStartMessage["resultType"]) { - const startMessage: WorkerJobStartMessage = { - type: WorkerMessageType.running, - uid, - resultType - } - Implementation.postMessageToMaster(startMessage) + Implementation.postMessage(initMessage) } function postUncaughtErrorMessage(error: Error) { @@ -103,7 +45,7 @@ function postUncaughtErrorMessage(error: Error) { type: WorkerMessageType.uncaughtError, error: serialize(error) as any as SerializedError } - Implementation.postMessageToMaster(errorMessage) + Implementation.postMessage(errorMessage) } catch (subError) { // tslint:disable-next-line no-console console.error( @@ -115,41 +57,6 @@ function postUncaughtErrorMessage(error: Error) { } } -async function runFunction(jobUID: number, fn: WorkerFunction, args: any[]) { - let syncResult: any - - try { - syncResult = fn(...args) - } catch (error) { - return postJobErrorMessage(jobUID, error) - } - - const resultType = isObservable(syncResult) ? "observable" : "promise" - postJobStartMessage(jobUID, resultType) - - if (isObservable(syncResult)) { - const subscription = syncResult.subscribe( - value => postJobResultMessage(jobUID, false, serialize(value)), - error => { - postJobErrorMessage(jobUID, serialize(error) as any) - activeSubscriptions.delete(jobUID) - }, - () => { - postJobResultMessage(jobUID, true) - activeSubscriptions.delete(jobUID) - } - ) - activeSubscriptions.set(jobUID, subscription) - } else { - try { - const result = await syncResult - postJobResultMessage(jobUID, true, serialize(result)) - } catch (error) { - postJobErrorMessage(jobUID, serialize(error) as any) - } - } -} - /** * Expose a function or a module (an object whose values are functions) * to the main thread. Must be called exactly once in every worker thread @@ -167,36 +74,24 @@ export function expose(exposed: WorkerFunction | WorkerModule) { exposeCalled = true if (typeof exposed === "function") { - Implementation.subscribeToMasterMessages(messageData => { - if (isMasterJobRunMessage(messageData) && !messageData.method) { - runFunction(messageData.uid, exposed, messageData.args.map(deserialize)) - } - }) + SingleExposedCallback(exposed) + handleFunctionInvocations(Implementation, getRegisteredSerializer(), debugIncomingMessages) postFunctionInitMessage() } else if (typeof exposed === "object" && exposed) { - Implementation.subscribeToMasterMessages(messageData => { - if (isMasterJobRunMessage(messageData) && messageData.method) { - runFunction(messageData.uid, exposed[messageData.method], messageData.args.map(deserialize)) - } - }) - - const methodNames = Object.keys(exposed).filter(key => typeof exposed[key] === "function") - postModuleInitMessage(methodNames) + const methods = Object.keys(exposed).reduce>( + (reduced, methodName) => { + const callback = Callback(exposed[methodName]) + return { ...reduced, [methodName]: callback.id } + }, + {} + ) + handleFunctionInvocations(Implementation, getRegisteredSerializer(), debugIncomingMessages) + postModuleInitMessage(methods) } else { throw Error(`Invalid argument passed to expose(). Expected a function or an object, got: ${exposed}`) } - Implementation.subscribeToMasterMessages(messageData => { - if (isMasterJobCancelMessage(messageData)) { - const jobUID = messageData.uid - const subscription = activeSubscriptions.get(jobUID) - - if (subscription) { - subscription.unsubscribe() - activeSubscriptions.delete(jobUID) - } - } - }) + handleCallCancellations(Implementation, debugIncomingMessages) } if (typeof self !== "undefined" && typeof self.addEventListener === "function" && Implementation.isWorkerRuntime()) { diff --git a/test/callbacks.test.ts b/test/callbacks.test.ts new file mode 100644 index 00000000..0e977505 --- /dev/null +++ b/test/callbacks.test.ts @@ -0,0 +1,16 @@ +import test from "ava" +import { spawn, Callback, Thread, Worker } from "../src/index" +import { MapWorker } from "./workers/map" + +test("can register, use and release a callback", async t => { + const callback = Callback((x: number) => x * 2) + const map = await spawn(new Worker("./workers/map")) + + try { + const mapped = await map([1, 2, 3], callback) + t.deepEqual(mapped, [2, 4, 6]) + callback.release() + } finally { + await Thread.terminate(map) + } +}) diff --git a/test/observable-promise.test.ts b/test/observable-promise.test.ts index 25b42eb5..9537cf81 100644 --- a/test/observable-promise.test.ts +++ b/test/observable-promise.test.ts @@ -1,6 +1,6 @@ import test from "ava" import { Observable } from "observable-fns" -import { ObservablePromise } from "../src/observable-promise" +import { ObservablePromise } from "../src/common/observable-promise" const delay = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)) diff --git a/test/observable.test.ts b/test/observable.test.ts index 7a38de1b..61eec263 100644 --- a/test/observable.test.ts +++ b/test/observable.test.ts @@ -1,5 +1,5 @@ import test from "ava" -import { Observable, Subject } from "../src/observable" +import { Observable, Subject } from "../src/util/observable" test("Observable subject emits values and completion event", async t => { let completed1 = false diff --git a/test/transferables.test.ts b/test/transferables.test.ts index cee0cd7f..39c305f7 100644 --- a/test/transferables.test.ts +++ b/test/transferables.test.ts @@ -50,8 +50,8 @@ test("can pass transferable objects on thread call", async t => { t.is(postMessageCalls[0].length, 2) t.deepEqual(postMessageCalls[0][0], { args: [arrayBufferPlaceholder, 15], - method: undefined, - type: "run", + fid: 0, + type: "call:invoke", uid: postMessageCalls[0][0].uid }) t.deepEqual(postMessageCalls[0][1], [arrayBufferPlaceholder]) diff --git a/test/webpack/app.ts b/test/webpack/app.ts index 7f651e58..bff0fa78 100644 --- a/test/webpack/app.ts +++ b/test/webpack/app.ts @@ -38,7 +38,7 @@ async function test3() { return } - const hello = await spawn(new Worker("https://infallible-turing-115958.netlify.com/hello-worker.js")) + const hello = await spawn(new Worker("https://infallible-turing-115958.netlify.com/hello-worker-2020-06-21-callbacks.js")) const result = await hello("World") if (result !== "Hello, World") { diff --git a/test/webpack/webpack.node.config.js b/test/webpack/webpack.node.config.js index 8db142f8..95359950 100644 --- a/test/webpack/webpack.node.config.js +++ b/test/webpack/webpack.node.config.js @@ -29,7 +29,7 @@ module.exports = { ] }, plugins: [ - new ThreadsPlugin() + new ThreadsPlugin({ target: "node" }) ], resolve: { extensions: [".js", ".ts"] diff --git a/test/workers/map.ts b/test/workers/map.ts new file mode 100644 index 00000000..c0d82f5b --- /dev/null +++ b/test/workers/map.ts @@ -0,0 +1,7 @@ +import { expose } from "../../src/worker" + +export type MapWorker = (input: number[], mapper: (source: number) => number | Promise) => number[] + +expose(function map(input: number[], mapper: (source: number) => number) { + return Promise.all(input.map(mapper)) +}) diff --git a/test/workers/minmax.ts b/test/workers/minmax.ts index 303a9304..2a856877 100644 --- a/test/workers/minmax.ts +++ b/test/workers/minmax.ts @@ -1,4 +1,4 @@ -import { Observable, Subject } from "../../src/observable" +import { Observable, Subject } from "../../src/util/observable" import { expose } from "../../src/worker" let max = -Infinity diff --git a/tsconfig.json b/tsconfig.json index a855de20..fc4bbb0a 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -13,8 +13,8 @@ }, "include": [ "./src/index.ts", - "./src/observable.ts", "./src/master/*", + "./src/util/observable.ts", "./src/worker/*", "./types/tiny-worker.d.ts", "./types/is-observable.d.ts" From 67a09d4ecb5d69d0d648e6fa6086538c92a13168 Mon Sep 17 00:00:00 2001 From: Andy Wermke Date: Sun, 21 Jun 2020 20:46:18 +0200 Subject: [PATCH 2/3] Support exposing generator functions --- src/serializers/index.ts | 6 +++++ src/serializers/iterators.ts | 47 +++++++++++++++++++++++++++++++++ src/types/master.ts | 11 ++++++-- src/types/serializers.ts | 9 +++++-- test/iterators.test.ts | 36 +++++++++++++++++++++++++ test/workers/async-generator.ts | 12 +++++++++ test/workers/generator.ts | 9 +++++++ 7 files changed, 126 insertions(+), 4 deletions(-) create mode 100644 src/serializers/iterators.ts create mode 100644 test/iterators.test.ts create mode 100644 test/workers/async-generator.ts create mode 100644 test/workers/generator.ts diff --git a/src/serializers/index.ts b/src/serializers/index.ts index e5999582..7739fe39 100644 --- a/src/serializers/index.ts +++ b/src/serializers/index.ts @@ -3,6 +3,7 @@ import { MessageRelay } from "../types/common" import { JsonSerializable, Serializer, SerializerImplementation } from "../types/serializers" import { isSerializedCallback, DefaultCallbackSerializer } from "./callbacks" import { isSerializedError, DefaultErrorSerializer } from "./errors" +import { isIterator, isSerializedIterator, DefaultIteratorSerializer } from "./iterators" export { JsonSerializable, @@ -36,6 +37,8 @@ export const DefaultSerializer = (): Serializer => { return errorSerializer.deserialize(message, sender) } else if (isSerializedCallback(message)) { return callbackSerializer.deserialize(message, sender) + } else if (isSerializedIterator(message)) { + return iteratorSerializer.deserialize(message, sender) } else { return message } @@ -45,6 +48,8 @@ export const DefaultSerializer = (): Serializer => { return errorSerializer.serialize(input) as any as JsonSerializable } else if (isCallback(input)) { return callbackSerializer.serialize(input) as any as JsonSerializable + } else if (isIterator(input)) { + return iteratorSerializer.serialize(input) as any as JsonSerializable } else { return input } @@ -53,6 +58,7 @@ export const DefaultSerializer = (): Serializer => { const callbackSerializer = DefaultCallbackSerializer(serializer) const errorSerializer = DefaultErrorSerializer() + const iteratorSerializer = DefaultIteratorSerializer(serializer) return serializer } diff --git a/src/serializers/iterators.ts b/src/serializers/iterators.ts new file mode 100644 index 00000000..510a8821 --- /dev/null +++ b/src/serializers/iterators.ts @@ -0,0 +1,47 @@ +import DebugLogger from "debug" +import { createProxyFunction } from "../common/call-proxy" +import { Callback, RemoteCallback } from "../common/callbacks" +import { MessageRelay } from "../types/common" +import { SerializedIterator, Serializer } from "../types/serializers" + +const debug = DebugLogger("threads:callback:messages") + +export const DefaultIteratorSerializer = (rootSerializer: Serializer): Serializer | AsyncIterator, AsyncIterator> => ({ + deserialize(message: SerializedIterator, origin: MessageRelay): AsyncIterator & AsyncIterable { + const remoteNext = createProxyFunction<[], IteratorResult>(origin, rootSerializer, message.next_fid, debug) + const remoteCallback = RemoteCallback<() => Promise>>(remoteNext) + + const next = async () => { + const result = await remoteCallback() + if (result.done) { + remoteCallback.release() + } + return result + } + + const asyncIterator = { + [Symbol.asyncIterator]: () => asyncIterator, + next + } + return asyncIterator + }, + serialize(iter: Iterator | AsyncIterator): SerializedIterator { + const next = Callback(async () => { + const result = await iter.next() + if (result.done) { + next.release() + } + return result + }) + return { + __iterator_marker: "$$iterator", + next_fid: next.id + } + } +}) + +export const isIterator = (thing: any): thing is Iterator | AsyncIterator => + thing && typeof thing === "object" && "next" in thing && typeof thing.next === "function" + +export const isSerializedIterator = (thing: any): thing is SerializedIterator => + thing && typeof thing === "object" && "__iterator_marker" in thing && thing.__iterator_marker === "$$iterator" diff --git a/src/types/master.ts b/src/types/master.ts index 166b78f9..e261576e 100644 --- a/src/types/master.ts +++ b/src/types/master.ts @@ -28,12 +28,19 @@ export type StripAsync = ? ObservableBaseType : Type +export type AsyncifyIterator = + Type extends Iterator + ? AsyncIterator & AsyncIterable + : Type extends AsyncIterator + ? AsyncIterator & AsyncIterable + : Type + export type ModuleMethods = { [methodName: string]: (...args: any) => any } export type ProxyableFunction = Args extends [] - ? () => ObservablePromise> - : (...args: Args) => ObservablePromise> + ? () => ObservablePromise>> + : (...args: Args) => ObservablePromise>> export type ModuleProxy = { [method in keyof Methods]: ProxyableFunction, ReturnType> diff --git a/src/types/serializers.ts b/src/types/serializers.ts index 528e3938..6716c49a 100644 --- a/src/types/serializers.ts +++ b/src/types/serializers.ts @@ -17,8 +17,8 @@ export type JsonSerializable = | JsonSerializableObject | JsonSerializableObject[] -export interface Serializer { - deserialize(message: Msg, sender: MessageRelay | null): Input +export interface Serializer { + deserialize(message: Msg, sender: MessageRelay | null): Deserialized serialize(input: Input): Msg } @@ -38,3 +38,8 @@ export interface SerializedError { name: string stack?: string } + +export interface SerializedIterator { + __iterator_marker: "$$iterator" + next_fid: number +} diff --git a/test/iterators.test.ts b/test/iterators.test.ts new file mode 100644 index 00000000..cc0d9d79 --- /dev/null +++ b/test/iterators.test.ts @@ -0,0 +1,36 @@ +import test from "ava" +import { spawn, Callback, Thread, Worker } from "../src/index" +import { AsyncGenerator } from "./workers/async-generator" +import { Generator } from "./workers/generator" + +test("can use a generator function exposed by a worker", async t => { + const generate = await spawn(new Worker("./workers/generator")) + + try { + const results: number[] = [] + + for await (const i of await generate(3)) { + results.push(i) + } + + t.deepEqual(results, [1, 2, 3]) + } finally { + await Thread.terminate(generate) + } +}) + +test("can use an async generator function exposed by a worker", async t => { + const generate = await spawn(new Worker("./workers/async-generator")) + + try { + const results: number[] = [] + + for await (const i of await generate(3)) { + results.push(i) + } + + t.deepEqual(results, [1, 2, 3]) + } finally { + await Thread.terminate(generate) + } +}) diff --git a/test/workers/async-generator.ts b/test/workers/async-generator.ts new file mode 100644 index 00000000..3c8d8e20 --- /dev/null +++ b/test/workers/async-generator.ts @@ -0,0 +1,12 @@ +import { expose } from "../../src/worker" + +export type AsyncGenerator = (count: number) => AsyncIterator + +const delay = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)) + +expose(async function* generator(count: number) { + for (let i = 1; i <= count; i++) { + await delay(2) + yield i + } +}) diff --git a/test/workers/generator.ts b/test/workers/generator.ts new file mode 100644 index 00000000..bafc5853 --- /dev/null +++ b/test/workers/generator.ts @@ -0,0 +1,9 @@ +import { expose } from "../../src/worker" + +export type Generator = (count: number) => Iterator + +expose(function *generator(count: number) { + for (let i = 1; i <= count; i++) { + yield i + } +}) From 18968f96c3cac0528b4196370001f0b86e24f799 Mon Sep 17 00:00:00 2001 From: Andy Wermke Date: Tue, 23 Jun 2020 17:00:01 +0200 Subject: [PATCH 3/3] Improve iterable detection Thanks, @kimamula! (See #256) --- src/serializers/iterators.ts | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/serializers/iterators.ts b/src/serializers/iterators.ts index 510a8821..35279119 100644 --- a/src/serializers/iterators.ts +++ b/src/serializers/iterators.ts @@ -41,7 +41,10 @@ export const DefaultIteratorSerializer = (rootSerializer: Serializer): Serialize }) export const isIterator = (thing: any): thing is Iterator | AsyncIterator => - thing && typeof thing === "object" && "next" in thing && typeof thing.next === "function" + thing && typeof thing === "object" && ( + typeof thing.next === "function" || + typeof (thing as AsyncIterable)[Symbol.asyncIterator] === "function" + ) export const isSerializedIterator = (thing: any): thing is SerializedIterator => thing && typeof thing === "object" && "__iterator_marker" in thing && thing.__iterator_marker === "$$iterator"