diff --git a/Cargo.lock b/Cargo.lock index 7db559e..c6e3043 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1918,6 +1918,7 @@ dependencies = [ "futures", "log", "object_store", + "openssl", "parquet", "prost", "serde", @@ -3763,6 +3764,15 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" +[[package]] +name = "openssl-src" +version = "300.5.4+3.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a507b3792995dae9b0df8a1c1e3771e8418b7c2d9f0baeba32e6fe8b06c7cb72" +dependencies = [ + "cc", +] + [[package]] name = "openssl-sys" version = "0.9.111" @@ -3771,6 +3781,7 @@ checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" dependencies = [ "cc", "libc", + "openssl-src", "pkg-config", "vcpkg", ] diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 204943e..3c7394a 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -27,6 +27,7 @@ axum = "0.7" object_store = { version = "0.12.4", features = ["aws"] } aws-config = "1" aws-sdk-ec2 = "1" +openssl = { version = "0.10", features = ["vendored"] } [[bin]] name = "dfbench" diff --git a/benchmarks/cdk/README.md b/benchmarks/cdk/README.md index 3d7cabb..37b97a3 100644 --- a/benchmarks/cdk/README.md +++ b/benchmarks/cdk/README.md @@ -119,5 +119,5 @@ Several arguments can be passed for running the benchmarks against different sca for example: ```shell -npm run datafusion-bench -- --sf 10 --files-per-task 4 --query 7 +npm run datafusion-bench -- --dataset tpch_sf10 --files-per-task 4 --query 7 ``` \ No newline at end of file diff --git a/benchmarks/cdk/bin/@bench-common.ts b/benchmarks/cdk/bin/@bench-common.ts index 2d978b6..a37fea6 100644 --- a/benchmarks/cdk/bin/@bench-common.ts +++ b/benchmarks/cdk/bin/@bench-common.ts @@ -1,135 +1,204 @@ import path from "path"; import fs from "fs/promises"; -import { z } from 'zod'; +import {z} from 'zod'; export const ROOT = path.join(__dirname, '../../..') +export const BUCKET = 's3://datafusion-distributed-benchmarks' // hardcoded in CDK code // Simple data structures export type QueryResult = { - query: string; - iterations: { elapsed: number; row_count: number }[]; + query: string; + iterations: { elapsed: number; row_count: number }[]; + failure?: string } export type BenchmarkResults = { - queries: QueryResult[]; + queries: QueryResult[]; } export const BenchmarkResults = z.object({ - queries: z.array(z.object({ - query: z.string(), - iterations: z.array(z.object({ - elapsed: z.number(), - row_count: z.number() + queries: z.array(z.object({ + query: z.string(), + iterations: z.array(z.object({ + elapsed: z.number(), + row_count: z.number() + })), + failure: z.string().optional() })) - })) }) -export const IDS = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22] - export async function writeJson(results: BenchmarkResults, outputPath?: string) { - if (!outputPath) return; - await fs.mkdir(path.dirname(outputPath), { recursive: true }); - await fs.writeFile(outputPath, JSON.stringify(results, null, 2)); + if (!outputPath) return; + await fs.mkdir(path.dirname(outputPath), { recursive: true }); + await fs.writeFile(outputPath, JSON.stringify(results, null, 2)); } export async function compareWithPrevious(results: BenchmarkResults, outputPath: string) { - let prevResults: BenchmarkResults; - try { - const prevContent = await fs.readFile(outputPath, 'utf-8'); - prevResults = 
BenchmarkResults.parse(JSON.parse(prevContent)); - } catch { - return; // No previous results to compare - } - - console.log('\n==== Comparison with previous run ===='); - - for (const query of results.queries) { - const prevQuery = prevResults.queries.find(q => q.query === query.query); - if (!prevQuery || prevQuery.iterations.length === 0 || query.iterations.length === 0) { - continue; + let prevResults: BenchmarkResults; + try { + const prevContent = await fs.readFile(outputPath, 'utf-8'); + prevResults = BenchmarkResults.parse(JSON.parse(prevContent)); + } catch { + return; // No previous results to compare } - const avgPrev = Math.round( - prevQuery.iterations.reduce((sum, i) => sum + i.elapsed, 0) / prevQuery.iterations.length - ); - const avg = Math.round( - query.iterations.reduce((sum, i) => sum + i.elapsed, 0) / query.iterations.length - ); - - const factor = avg < avgPrev ? avgPrev / avg : avg / avgPrev; - const tag = avg < avgPrev ? "faster" : "slower"; - const emoji = factor > 1.2 ? (avg < avgPrev ? "✅" : "❌") : (avg < avgPrev ? "✔" : "✖"); - - console.log( - `${query.query.padStart(8)}: prev=${avgPrev.toString().padStart(4)} ms, new=${avg.toString().padStart(4)} ms, ${factor.toFixed(2)}x ${tag} ${emoji}` - ); - } -} - -export interface BenchmarkRunner { - createTables(sf: number): Promise<void>; - - executeQuery(query: string): Promise<{ rowCount: number }>; -} + console.log('\n==== Comparison with previous run ===='); -export async function runBenchmark( - runner: BenchmarkRunner, - options: { - sf: number; - iterations: number; - specificQuery?: number; - outputPath: string; - } -) { - const { sf, iterations, specificQuery, outputPath } = options; + for (const query of results.queries) { + const prevQuery = prevResults.queries.find(q => q.query === query.query); + if (!prevQuery || prevQuery.iterations.length === 0 || query.iterations.length === 0) { + continue; + } - const results: BenchmarkResults = { queries: [] }; - const queriesPath = path.join(ROOT, "testdata", "tpch", "queries") + const avgPrev = Math.round( + prevQuery.iterations.reduce((sum, i) => sum + i.elapsed, 0) / prevQuery.iterations.length + ); + const avg = Math.round( + query.iterations.reduce((sum, i) => sum + i.elapsed, 0) / query.iterations.length + ); - console.log("Creating tables..."); - await runner.createTables(sf); + const factor = avg < avgPrev ? avgPrev / avg : avg / avgPrev; + const tag = avg < avgPrev ? "faster" : "slower"; + const emoji = factor > 1.2 ? (avg < avgPrev ? "✅" : "❌") : (avg < avgPrev ? "✔" : "✖"); - for (let id of IDS) { - if (specificQuery && specificQuery !== id) { - continue; + console.log( + `${query.query.padStart(8)}: prev=${avgPrev.toString().padStart(4)} ms, new=${avg.toString().padStart(4)} ms, ${factor.toFixed(2)}x ${tag} ${emoji}` + ); } +}
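For a concrete reading of the comparison output above: with a previous average of 1000 ms and a new average of 800 ms, `factor` is 1000/800 = 1.25, which clears the 1.2 threshold and prints as `1.25x faster ✅`. A standalone sketch of the same arithmetic (not part of the diff):

```typescript
// Mirrors the factor/tag/emoji logic in compareWithPrevious above.
function describe(avgPrev: number, avg: number): string {
    const factor = avg < avgPrev ? avgPrev / avg : avg / avgPrev;
    const tag = avg < avgPrev ? "faster" : "slower";
    const emoji = factor > 1.2 ? (avg < avgPrev ? "✅" : "❌") : (avg < avgPrev ? "✔" : "✖");
    return `${factor.toFixed(2)}x ${tag} ${emoji}`;
}

console.log(describe(1000, 800)); // 1.25x faster ✅
console.log(describe(1000, 950)); // 1.05x faster ✔
```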
"✔" : "✖"); - for (let id of IDS) { - if (specificQuery && specificQuery !== id) { - continue; + console.log( + `${query.query.padStart(8)}: prev=${avgPrev.toString().padStart(4)} ms, new=${avg.toString().padStart(4)} ms, ${factor.toFixed(2)}x ${tag} ${emoji}` + ); } +} - const queryId = `q${id}`; - const filePath = path.join(queriesPath, `${queryId}.sql`) - const queryToExecute = await fs.readFile(filePath, 'utf-8') +export interface TableSpec { + schema: string + name: string + s3Path: string +} + +export interface BenchmarkRunner { + createTables(s3Paths: TableSpec[]): Promise; - const queryResult: QueryResult = { - query: queryId, - iterations: [] - }; + executeQuery(query: string): Promise<{ rowCount: number }>; +} - console.log(`Warming up query ${id}...`) - await runner.executeQuery(queryToExecute); +async function tablePathsForDataset(dataset: string): Promise { + const datasetPath = path.join(ROOT, "benchmarks", "data", dataset) + + const result: TableSpec[] = [] + for (const entryName of await fs.readdir(datasetPath)) { + const dir = path.join(datasetPath, entryName) + if (await isDirWithAllParquetFiles(dir)) { + result.push({ + name: entryName, + schema: dataset, + s3Path: `${BUCKET}/${dataset}/${entryName}/` + }) + } + } + return result +} - for (let i = 0; i < iterations; i++) { - const start = new Date() - const response = await runner.executeQuery(queryToExecute); - const elapsed = Math.round(new Date().getTime() - start.getTime()) +async function isDirWithAllParquetFiles(dir: string): Promise { + let readDir + try { + readDir = await fs.readdir(dir) + } catch (e) { + return false + } + for (const file of readDir) { + if (!file.endsWith(".parquet")) { + return false + } + } + return true +} - queryResult.iterations.push({ - elapsed, - row_count: response.rowCount - }); +async function queriesForDataset(dataset: string): Promise<[string, string][]> { + const datasetSuffix = dataset.split("_")[0] + const queriesPath = path.join(ROOT, "testdata", datasetSuffix, "queries") - console.log( - `Query ${id} iteration ${i} took ${elapsed} ms and returned ${response.rowCount} rows` - ); + const queries: [string, string][] = [] + for (const queryName of await fs.readdir(queriesPath)) { + const sql = await fs.readFile(path.join(queriesPath, queryName), 'utf-8'); + queries.push([queryName, sql]) } + queries.sort(([name1], [name2]) => numericId(name1) > numericId(name2) ? 
- const avg = Math.round( - queryResult.iterations.reduce((a, b) => a + b.elapsed, 0) / queryResult.iterations.length - ); - console.log(`Query ${id} avg time: ${avg} ms`); +function numericId(queryName: string): number { + return parseInt([...queryName.matchAll(/(\d+)/g)][0][0]) +} - results.queries.push(queryResult); - } +export async function runBenchmark( + runner: BenchmarkRunner, + options: { + dataset: string + iterations: number; + queries: number[]; + outputPath: string; + } +) { + const { dataset, iterations, queries, outputPath } = options; + + const results: BenchmarkResults = { queries: [] }; + + console.log("Creating tables..."); + const s3Paths = await tablePathsForDataset(dataset) + await runner.createTables(s3Paths); + + for (const [queryName, sql] of await queriesForDataset(dataset)) { + const id = numericId(queryName) + + if (queries.length > 0 && !queries.includes(id)) { + continue; + } + + const queryResult: QueryResult = { + query: queryName, + iterations: [], + }; + + console.log(`Warming up query ${id}...`) + try { + await runner.executeQuery(sql); + } catch (e: any) { + queryResult.failure = e.toString(); + console.error(`Query ${queryResult.query} failed: ${queryResult.failure}`) + continue + } + + for (let i = 0; i < iterations; i++) { + const start = new Date() + let response + try { + response = await runner.executeQuery(sql); + } catch (e: any) { + queryResult.failure = e.toString(); + break + } + const elapsed = Math.round(new Date().getTime() - start.getTime()) + + queryResult.iterations.push({ + elapsed, + row_count: response.rowCount + }); + + console.log( + `Query ${id} iteration ${i} took ${elapsed} ms and returned ${response.rowCount} rows` + ); + } + + const avg = Math.round( + queryResult.iterations.reduce((a, b) => a + b.elapsed, 0) / queryResult.iterations.length + ); + console.log(`Query ${id} avg time: ${avg} ms`); + + if (queryResult.failure) { + console.error(`Query ${queryResult.query} failed: ${queryResult.failure}`) + } + results.queries.push(queryResult); + } - // Write results and compare - await compareWithPrevious(results, outputPath); - await writeJson(results, outputPath); + // Write results and compare + await compareWithPrevious(results, outputPath); + await writeJson(results, outputPath); }
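The reworked `runBenchmark` contract is dataset-driven: callers no longer pass a scale factor, and table locations arrive as `TableSpec`s resolved from `benchmarks/data/<dataset>`. A minimal sketch of a conforming runner (the class name, dataset, and output path are illustrative, not part of the change):

```typescript
import { BenchmarkRunner, TableSpec, runBenchmark } from "./@bench-common";

// Illustrative no-op runner; a real implementation would register each
// s3Path with a query engine and actually execute the SQL.
class NoopRunner implements BenchmarkRunner {
    async createTables(tables: TableSpec[]): Promise<void> {
        for (const t of tables) {
            console.log(`would register ${t.schema}.${t.name} from ${t.s3Path}`);
        }
    }

    async executeQuery(query: string): Promise<{ rowCount: number }> {
        return { rowCount: 0 }; // pretend every query returns zero rows
    }
}

runBenchmark(new NoopRunner(), {
    dataset: "tpch_sf1", // assumes benchmarks/data/tpch_sf1 has been generated locally
    iterations: 3,
    queries: [],         // an empty list means "run every query found"
    outputPath: "/tmp/results.json",
}).catch(console.error);
```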
diff --git a/benchmarks/cdk/bin/datafusion-bench.ts b/benchmarks/cdk/bin/datafusion-bench.ts index 4c5b340..c817a71 100644 --- a/benchmarks/cdk/bin/datafusion-bench.ts +++ b/benchmarks/cdk/bin/datafusion-bench.ts @@ -1,7 +1,7 @@ import path from "path"; import {Command} from "commander"; import {z} from 'zod'; -import {BenchmarkRunner, ROOT, runBenchmark} from "./@bench-common"; +import {BenchmarkRunner, ROOT, runBenchmark, TableSpec} from "./@bench-common"; // Remember to port-forward a worker with // aws ssm start-session --target {host-id} --document-name AWS-StartPortForwardingSession --parameters "portNumber=9000,localPortNumber=9000" @@ -10,7 +10,7 @@ async function main() { const program = new Command(); program - .option('--sf <sf>', 'Scale factor', '1') + .option('--dataset <dataset>', 'Dataset to run queries on') .option('-i, --iterations <iterations>', 'Number of iterations', '3') .option('--files-per-task <files-per-task>', 'Files per task', '4') .option('--cardinality-task-sf <cardinality-task-sf>', 'Cardinality task scale factor', '2') @@ -21,12 +21,12 @@ const options = program.opts(); - const sf = parseInt(options.sf); + const dataset: string = options.dataset const iterations = parseInt(options.iterations); const filesPerTask = parseInt(options.filesPerTask); const cardinalityTaskSf = parseInt(options.cardinalityTaskSf); const shuffleBatchSize = parseInt(options.shuffleBatchSize); - const specificQuery = options.query ? parseInt(options.query) : undefined; + const queries = options.query ? [parseInt(options.query)] : []; const collectMetrics = options.collectMetrics === 'true' || options.collectMetrics === 1 const runner = new DataFusionRunner({ @@ -36,12 +36,13 @@ }); - const outputPath = path.join(ROOT, "benchmarks", "data", `tpch_sf${sf}`, "remote-results.json"); + const datasetPath = path.join(ROOT, "benchmarks", "data", dataset); + const outputPath = path.join(datasetPath, "remote-results.json") await runBenchmark(runner, { - sf, + dataset, iterations, - specificQuery, + queries, outputPath, }); } @@ -75,7 +76,7 @@ response = await this.query(sql) } - return {rowCount: response.count}; + return { rowCount: response.count }; } private async query(sql: string): Promise<QueryResponse> { @@ -93,22 +94,13 @@ return QueryResponse.parse(unparsed); } - async createTables(sf: number): Promise<void> { + async createTables(tables: TableSpec[]): Promise<void> { let stmt = ''; - for (const tbl of [ - "lineitem", - "orders", - "part", - "partsupp", - "customer", - "nation", - "region", - "supplier", - ]) { + for (const table of tables) { // language=SQL format=false stmt += ` - DROP TABLE IF EXISTS ${tbl}; - CREATE EXTERNAL TABLE IF NOT EXISTS ${tbl} STORED AS PARQUET LOCATION 's3://datafusion-distributed-benchmarks/tpch_sf${sf}/${tbl}/'; + DROP TABLE IF EXISTS ${table.name}; + CREATE EXTERNAL TABLE IF NOT EXISTS ${table.name} STORED AS PARQUET LOCATION '${table.s3Path}'; `; } await this.query(stmt); diff --git a/benchmarks/cdk/bin/trino-bench.ts b/benchmarks/cdk/bin/trino-bench.ts index d353637..136b3a0 100644 --- a/benchmarks/cdk/bin/trino-bench.ts +++ b/benchmarks/cdk/bin/trino-bench.ts @@ -1,249 +1,825 @@ import path from "path"; -import { Command } from "commander"; -import { ROOT, runBenchmark, BenchmarkRunner } from "./@bench-common"; +import {Command} from "commander"; +import {ROOT, runBenchmark, BenchmarkRunner, TableSpec} from "./@bench-common"; // Remember to port-forward Trino coordinator with // aws ssm start-session --target {instance-0-id} --document-name AWS-StartPortForwardingSession --parameters "portNumber=8080,localPortNumber=8080" async function main() { - const program = new Command(); + const program = new Command(); - program - .option('--sf <sf>', 'Scale factor', '1') - .option('-i, --iterations <iterations>', 'Number of iterations', '3') - .option('--query <query>', 'A specific query to run', undefined) - .parse(process.argv); + program + .option('--dataset <dataset>', 'Dataset to run queries on') + .option('-i, --iterations <iterations>', 'Number of iterations', '3') + .option('--query <query>', 'A specific query to run', undefined) + .parse(process.argv); - const options = program.opts(); + const options = program.opts(); - const sf = parseInt(options.sf); - const iterations = parseInt(options.iterations); - const specificQuery = options.query ? parseInt(options.query) : undefined; + const dataset: string = options.dataset + const iterations = parseInt(options.iterations); + const queries = options.query ? [parseInt(options.query)] : []; - const runner = new TrinoRunner({ sf }); - const outputPath = path.join(ROOT, "benchmarks", "data", `tpch_sf${sf}`, "remote-results.json"); + const datasetPath = path.join(ROOT, "benchmarks", "data", dataset); + const outputPath = path.join(datasetPath, "remote-results.json") - await runBenchmark(runner, { - sf, - iterations, - specificQuery, - outputPath, - }); + const runner = new TrinoRunner(); + + await runBenchmark(runner, { + dataset, + iterations, + queries, + outputPath, + }); }
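One behavior worth calling out in `TrinoRunner` below: query text is rewritten so that bare date literals get a `DATE` prefix before being sent to Trino (per the comment, TPCH query 4 is the motivating case). The replacement uses a negative lookbehind so literals already carrying the prefix are left alone. A quick illustration (the regex is reconstructed from context, and the predicates are sample TPCH-style snippets):

```typescript
// Adds a DATE prefix to bare '<yyyy-mm-dd>' literals, skipping ones
// already preceded by DATE, as executeQuery does below.
const fixDates = (sql: string) =>
    sql.replace(/(?<!DATE\s)('\d{4}-\d{2}-\d{2}')/g, "DATE $1");

console.log(fixDates("o_orderdate >= '1993-07-01'"));
// -> o_orderdate >= DATE '1993-07-01'
console.log(fixDates("o_orderdate < DATE '1993-10-01'"));
// -> unchanged: o_orderdate < DATE '1993-10-01'
```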
class TrinoRunner implements BenchmarkRunner { - private trinoUrl = 'http://localhost:8080'; - - constructor(private readonly options: { - sf: number - }) { - } - - - async executeQuery(sql: string): Promise<{ rowCount: number }> { - // Fix query 4: Add DATE prefix to date literals that don't have it. - sql = sql.replace(/(?<!DATE\s)('\d{4}-\d{2}-\d{2}')/g, "DATE $1"); - - return this.executeSingleStatement(sql); - } + private trinoUrl = 'http://localhost:8080'; + private schema?: string; + + async executeQuery(sql: string): Promise<{ rowCount: number }> { + // Fix TPCH query 4: Add DATE prefix to date literals that don't have it. + sql = sql.replace(/(?<!DATE\s)('\d{4}-\d{2}-\d{2}')/g, "DATE $1"); + + return this.executeSingleStatement(sql); + } - - private async executeSingleStatement(sql: string): Promise<{ rowCount: number }> { - // Submit query - const submitResponse = await fetch(`${this.trinoUrl}/v1/statement`, { - method: 'POST', - headers: { - 'X-Trino-User': 'benchmark', - 'X-Trino-Catalog': 'hive', - 'X-Trino-Schema': `tpch_sf${this.options.sf}`, - }, - body: sql.trim().replace(/;+$/, ''), - }); + private async executeSingleStatement(sql: string): Promise<{ rowCount: number }> { + if (!this.schema) { + throw new Error("No schema available, were the tables created?") + } + + // Submit query + const submitResponse = await fetch(`${this.trinoUrl}/v1/statement`, { + method: 'POST', + headers: { + 'X-Trino-User': 'benchmark', + 'X-Trino-Catalog': 'hive', + 'X-Trino-Schema': this.schema ?? '', + }, + body: sql.trim().replace(/;+$/, ''), + }); + + if (!submitResponse.ok) { + const msg = await submitResponse.text(); + throw new Error(`Query submission failed: ${submitResponse.status} ${msg}`); + } - if (!submitResponse.ok) { - const msg = await submitResponse.text(); - throw new Error(`Query submission failed: ${submitResponse.status} ${msg}`); + let result: any = await submitResponse.json(); + let rowCount = 0; + + // Poll for results + while (result.nextUri) { + const pollResponse = await fetch(result.nextUri); + + if (!pollResponse.ok) { + const msg = await pollResponse.text(); + throw new Error(`Query polling failed: ${pollResponse.status} ${msg}`); + } + + result = await pollResponse.json(); + + // Count rows if data is present + if (result.data) { + if (typeof result.data?.[0]?.[0] === 'string') { + // Extract row count from EXPLAIN ANALYZE output + const outputMatch = result.data[0][0].match(/Output.*?(\d+)\s+rows/i); + if (outputMatch) { + rowCount = parseInt(outputMatch[1]); + } + } else { + rowCount += result.data.length; + } + } + + // Check for errors + if (result.error) { + throw new Error(`Query failed: ${result.error.message}`); + } + } + + return { rowCount }; + } - let result: any = await submitResponse.json(); - let rowCount = 0; + async createTables(tables: TableSpec[]): Promise<void> { + if (tables.length === 0) { + throw new Error("No table passed") + } + let schema = tables[0].schema + let basePath = tables[0].s3Path.split('/').slice(0, -1).join("/") - // Poll for results - while (result.nextUri) { - const pollResponse = await fetch(result.nextUri); + this.schema = schema - if (!pollResponse.ok) { - const msg = await pollResponse.text(); - throw new Error(`Query polling failed: ${pollResponse.status} ${msg}`); - } + await this.executeSingleStatement(` + CREATE SCHEMA IF NOT EXISTS hive."${schema}" WITH (location = '${basePath}')`); - result = await pollResponse.json(); + for (const table of tables) { + await this.executeSingleStatement(` + DROP TABLE IF EXISTS hive."${table.schema}"."${table.name}"`); - // Count rows if data is present - if (result.data) { - if (typeof result.data?.[0]?.[0] === 'string') { - // Extract row count from EXPLAIN ANALYZE output - const outputMatch = result.data[0][0].match(/Output.*?(\d+)\s+rows/i); - if (outputMatch) { - rowCount = parseInt(outputMatch[1]); - } - } else { - rowCount += result.data.length; + await this.executeSingleStatement(` + CREATE TABLE hive."${table.schema}"."${table.name}" ${getSchema(table)} + WITH (external_location = '${table.s3Path}', format = 'PARQUET')`); + } + } +} - } - }
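For reference, this is roughly the DDL sequence `createTables` above ends up issuing against the Hive catalog for a single `TableSpec` (the dataset and table below are hypothetical; `getSchema` splices in the matching column list from `SCHEMAS`):

```typescript
// Hypothetical input mirroring what tablePathsForDataset produces.
const table = {
    schema: "tpch_sf10",
    name: "lineitem",
    s3Path: "s3://datafusion-distributed-benchmarks/tpch_sf10/lineitem/",
};
// Same computation as createTables above: only the trailing slash is
// dropped, so the schema location ends up being the first table's directory.
const basePath = table.s3Path.split("/").slice(0, -1).join("/");

console.log(`CREATE SCHEMA IF NOT EXISTS hive."${table.schema}" WITH (location = '${basePath}')`);
console.log(`DROP TABLE IF EXISTS hive."${table.schema}"."${table.name}"`);
console.log(`CREATE TABLE hive."${table.schema}"."${table.name}" (...columns from SCHEMAS...)
WITH (external_location = '${table.s3Path}', format = 'PARQUET')`);
```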
+const SCHEMAS: Record<string, Record<string, string>> = { + tpch: { + customer: `( + c_custkey bigint, + c_name varchar(25), + c_address varchar(40), + c_nationkey bigint, + c_phone varchar(15), + c_acctbal decimal(15, 2), + c_mktsegment varchar(10), + c_comment varchar(117) +)`, + lineitem: `( + l_orderkey bigint, + l_partkey bigint, + l_suppkey bigint, + l_linenumber integer, + l_quantity decimal(15, 2), + l_extendedprice decimal(15, 2), + l_discount decimal(15, 2), + l_tax decimal(15, 2), + l_returnflag varchar(1), + l_linestatus varchar(1), + l_shipdate date, + l_commitdate date, + l_receiptdate date, + l_shipinstruct varchar(25), + l_shipmode varchar(10), + l_comment varchar(44) +)`, + nation: `( + n_nationkey bigint, + n_name varchar(25), + n_regionkey bigint, + n_comment varchar(152) +)`, + orders: `( + o_orderkey bigint, + o_custkey bigint, + o_orderstatus varchar(1), + o_totalprice decimal(15, 2), + o_orderdate date, + o_orderpriority varchar(15), + o_clerk varchar(15), + o_shippriority integer, + o_comment varchar(79) +)`, + part: `( + p_partkey bigint, + p_name varchar(55), + p_mfgr varchar(25), + p_brand varchar(10), + p_type varchar(25), + p_size integer, + p_container varchar(10), + p_retailprice decimal(15, 2), + p_comment varchar(23) +)`, + partsupp: `( + ps_partkey bigint, + ps_suppkey bigint, + ps_availqty integer, + ps_supplycost decimal(15, 2), + ps_comment varchar(199) +)`, + region: `( + r_regionkey bigint, + r_name varchar(25), + r_comment varchar(152) +)`, + supplier: `( + s_suppkey bigint, + s_name varchar(25), + s_address varchar(40), + s_nationkey bigint, + s_phone varchar(15), + s_acctbal decimal(15, 2), + s_comment varchar(101) +)` + }, + clickbench: { + hits: `( + WatchID bigint, + JavaEnable smallint, + Title varchar, + GoodEvent smallint, + EventTime bigint, + EventDate date, + CounterID integer, + ClientIP integer, + RegionID integer, + UserID bigint, + CounterClass smallint, + OS smallint, + UserAgent smallint, + URL varchar, + Referer varchar, + IsRefresh smallint, + RefererCategoryID smallint, + RefererRegionID integer, + URLCategoryID smallint, + URLRegionID integer, + ResolutionWidth smallint, + ResolutionHeight smallint, + ResolutionDepth smallint, + FlashMajor smallint, + FlashMinor smallint, + FlashMinor2 varchar, + NetMajor smallint, + NetMinor smallint, + UserAgentMajor smallint, + UserAgentMinor varchar(255), + CookieEnable smallint, + JavascriptEnable smallint, + IsMobile smallint, + MobilePhone smallint, + MobilePhoneModel varchar, + Params varchar, + IPNetworkID integer, + TraficSourceID smallint, + SearchEngineID smallint, + SearchPhrase varchar, + AdvEngineID smallint, + IsArtifical smallint, + WindowClientWidth smallint, + WindowClientHeight smallint, + ClientTimeZone smallint, + 
ClientEventTime bigint, + SilverlightVersion1 smallint, + SilverlightVersion2 smallint, + SilverlightVersion3 integer, + SilverlightVersion4 smallint, + PageCharset varchar, + CodeVersion integer, + IsLink smallint, + IsDownload smallint, + IsNotBounce smallint, + FUniqID bigint, + OriginalURL varchar, + HID integer, + IsOldCounter smallint, + IsEvent smallint, + IsParameter smallint, + DontCountHits smallint, + WithHash smallint, + HitColor varchar(1), + LocalEventTime bigint, + Age smallint, + Sex smallint, + Income smallint, + Interests smallint, + Robotness smallint, + RemoteIP integer, + WindowName integer, + OpenerName integer, + HistoryLength smallint, + BrowserLanguage varchar, + BrowserCountry varchar, + SocialNetwork varchar, + SocialAction varchar, + HTTPError smallint, + SendTiming integer, + DNSTiming integer, + ConnectTiming integer, + ResponseStartTiming integer, + ResponseEndTiming integer, + FetchTiming integer, + SocialSourceNetworkID smallint, + SocialSourcePage varchar, + ParamPrice bigint, + ParamOrderID varchar, + ParamCurrency varchar, + ParamCurrencyID smallint, + OpenstatServiceName varchar, + OpenstatCampaignID varchar, + OpenstatAdID varchar, + OpenstatSourceID varchar, + UTMSource varchar, + UTMMedium varchar, + UTMCampaign varchar, + UTMContent varchar, + UTMTerm varchar, + FromTag varchar, + HasGCLID smallint, + RefererHash bigint, + URLHash bigint, + CLID integer +)` + }, + tpcds: { + call_center: `( + cc_call_center_sk integer, + cc_call_center_id varchar, + cc_rec_start_date date, + cc_rec_end_date date, + cc_closed_date_sk double, + cc_open_date_sk integer, + cc_name varchar, + cc_class varchar, + cc_employees integer, + cc_sq_ft integer, + cc_hours varchar, + cc_manager varchar, + cc_mkt_id integer, + cc_mkt_class varchar, + cc_mkt_desc varchar, + cc_market_manager varchar, + cc_division integer, + cc_division_name varchar, + cc_company integer, + cc_company_name varchar, + cc_street_number varchar, + cc_street_name varchar, + cc_street_type varchar, + cc_suite_number varchar, + cc_city varchar, + cc_county varchar, + cc_state varchar, + cc_zip varchar, + cc_country varchar, + cc_gmt_offset decimal(3, 2), + cc_tax_percentage decimal(2, 2) +)`, + catalog_page: `( + cp_catalog_page_sk integer, + cp_catalog_page_id varchar, + cp_start_date_sk double, + cp_end_date_sk double, + cp_department varchar, + cp_catalog_number double, + cp_catalog_page_number double, + cp_description varchar, + cp_type varchar +)`, + catalog_returns: `( + cr_returned_date_sk integer, + cr_returned_time_sk integer, + cr_item_sk integer, + cr_refunded_customer_sk double, + cr_refunded_cdemo_sk double, + cr_refunded_hdemo_sk double, + cr_refunded_addr_sk double, + cr_returning_customer_sk double, + cr_returning_cdemo_sk double, + cr_returning_hdemo_sk double, + cr_returning_addr_sk double, + cr_call_center_sk double, + cr_catalog_page_sk double, + cr_ship_mode_sk double, + cr_warehouse_sk double, + cr_reason_sk double, + cr_order_number integer, + cr_return_quantity double, + cr_return_amount decimal(7, 2), + cr_return_tax decimal(6, 2), + cr_return_amt_inc_tax decimal(7, 2), + cr_fee decimal(5, 2), + cr_return_ship_cost decimal(7, 2), + cr_refunded_cash decimal(7, 2), + cr_reversed_charge decimal(7, 2), + cr_store_credit decimal(7, 2), + cr_net_loss decimal(7, 2) +)`, + catalog_sales: `( + cs_sold_date_sk double, + cs_sold_time_sk double, + cs_ship_date_sk double, + cs_bill_customer_sk double, + cs_bill_cdemo_sk double, + cs_bill_hdemo_sk double, + cs_bill_addr_sk double, + 
cs_ship_customer_sk double, + cs_ship_cdemo_sk double, + cs_ship_hdemo_sk double, + cs_ship_addr_sk double, + cs_call_center_sk double, + cs_catalog_page_sk double, + cs_ship_mode_sk double, + cs_warehouse_sk double, + cs_item_sk integer, + cs_promo_sk double, + cs_order_number integer, + cs_quantity double, + cs_wholesale_cost decimal(5, 2), + cs_list_price decimal(5, 2), + cs_sales_price decimal(5, 2), + cs_ext_discount_amt decimal(7, 2), + cs_ext_sales_price decimal(7, 2), + cs_ext_wholesale_cost decimal(7, 2), + cs_ext_list_price decimal(7, 2), + cs_ext_tax decimal(6, 2), + cs_coupon_amt decimal(7, 2), + cs_ext_ship_cost decimal(7, 2), + cs_net_paid decimal(7, 2), + cs_net_paid_inc_tax decimal(7, 2), + cs_net_paid_inc_ship decimal(7, 2), + cs_net_paid_inc_ship_tax decimal(7, 2), + cs_net_profit decimal(7, 2) +)`, + customer: `( + c_customer_sk integer, + c_customer_id varchar, + c_current_cdemo_sk double, + c_current_hdemo_sk double, + c_current_addr_sk integer, + c_first_shipto_date_sk double, + c_first_sales_date_sk double, + c_salutation varchar, + c_first_name varchar, + c_last_name varchar, + c_preferred_cust_flag varchar, + c_birth_day double, + c_birth_month double, + c_birth_year double, + c_birth_country varchar, + c_login varchar, + c_email_address varchar, + c_last_review_date double +)`, + customer_address: `( + ca_address_sk integer, + ca_address_id varchar, + ca_street_number varchar, + ca_street_name varchar, + ca_street_type varchar, + ca_suite_number varchar, + ca_city varchar, + ca_county varchar, + ca_state varchar, + ca_zip varchar, + ca_country varchar, + ca_gmt_offset decimal(4, 2), + ca_location_type varchar +)`, + customer_demographics: `( + cd_demo_sk integer, + cd_gender varchar, + cd_marital_status varchar, + cd_education_status varchar, + cd_purchase_estimate integer, + cd_credit_rating varchar, + cd_dep_count integer, + cd_dep_employed_count integer, + cd_dep_college_count integer +)`, + date_dim: `( + d_date_sk integer, + d_date_id varchar, + d_date date, + d_month_seq integer, + d_week_seq integer, + d_quarter_seq integer, + d_year integer, + d_dow integer, + d_moy integer, + d_dom integer, + d_qoy integer, + d_fy_year integer, + d_fy_quarter_seq integer, + d_fy_week_seq integer, + d_day_name varchar, + d_quarter_name varchar, + d_holiday varchar, + d_weekend varchar, + d_following_holiday varchar, + d_first_dom integer, + d_last_dom integer, + d_same_day_ly integer, + d_same_day_lq integer, + d_current_day varchar, + d_current_week varchar, + d_current_month varchar, + d_current_quarter varchar, + d_current_year varchar +)`, + household_demographics: `( + hd_demo_sk integer, + hd_income_band_sk integer, + hd_buy_potential varchar, + hd_dep_count integer, + hd_vehicle_count integer +)`, + income_band: `( + ib_income_band_sk integer, + ib_lower_bound integer, + ib_upper_bound integer +)`, + inventory: `( + inv_date_sk integer, + inv_item_sk integer, + inv_warehouse_sk integer, + inv_quantity_on_hand double +)`, + item: `( + i_item_sk integer, + i_item_id varchar, + i_rec_start_date date, + i_rec_end_date date, + i_item_desc varchar, + i_current_price decimal(4, 2), + i_wholesale_cost decimal(4, 2), + i_brand_id double, + i_brand varchar, + i_class_id double, + i_class varchar, + i_category_id double, + i_category varchar, + i_manufact_id double, + i_manufact varchar, + i_size varchar, + i_formulation varchar, + i_color varchar, + i_units varchar, + i_container varchar, + i_manager_id double, + i_product_name varchar +)`, + promotion: `( + p_promo_sk 
integer, + p_promo_id varchar, + p_start_date_sk double, + p_end_date_sk double, + p_item_sk double, + p_cost decimal(6, 2), + p_response_target double, + p_promo_name varchar, + p_channel_dmail varchar, + p_channel_email varchar, + p_channel_catalog varchar, + p_channel_tv varchar, + p_channel_radio varchar, + p_channel_press varchar, + p_channel_event varchar, + p_channel_demo varchar, + p_channel_details varchar, + p_purpose varchar, + p_discount_active varchar +)`, + reason: `( + r_reason_sk integer, + r_reason_id varchar, + r_reason_desc varchar +)`, + ship_mode: `( + sm_ship_mode_sk integer, + sm_ship_mode_id varchar, + sm_type varchar, + sm_code varchar, + sm_carrier varchar, + sm_contract varchar +)`, + store: `( + s_store_sk integer, + s_store_id varchar, + s_rec_start_date date, + s_rec_end_date date, + s_closed_date_sk double, + s_store_name varchar, + s_number_employees integer, + s_floor_space integer, + s_hours varchar, + s_manager varchar, + s_market_id integer, + s_geography_class varchar, + s_market_desc varchar, + s_market_manager varchar, + s_division_id integer, + s_division_name varchar, + s_company_id integer, + s_company_name varchar, + s_street_number varchar, + s_street_name varchar, + s_street_type varchar, + s_suite_number varchar, + s_city varchar, + s_county varchar, + s_state varchar, + s_zip varchar, + s_country varchar, + s_gmt_offset decimal(3, 2), + s_tax_precentage decimal(2, 2) +)`, + store_returns: `( + sr_returned_date_sk double, + sr_return_time_sk double, + sr_item_sk integer, + sr_customer_sk double, + sr_cdemo_sk double, + sr_hdemo_sk double, + sr_addr_sk double, + sr_store_sk double, + sr_reason_sk double, + sr_ticket_number integer, + sr_return_quantity double, + sr_return_amt decimal(7, 2), + sr_return_tax decimal(6, 2), + sr_return_amt_inc_tax decimal(7, 2), + sr_fee decimal(5, 2), + sr_return_ship_cost decimal(6, 2), + sr_refunded_cash decimal(7, 2), + sr_reversed_charge decimal(7, 2), + sr_store_credit decimal(7, 2), + sr_net_loss decimal(6, 2) +)`, + store_sales: `( + ss_sold_date_sk double, + ss_sold_time_sk double, + ss_item_sk integer, + ss_customer_sk double, + ss_cdemo_sk double, + ss_hdemo_sk double, + ss_addr_sk double, + ss_store_sk double, + ss_promo_sk double, + ss_ticket_number integer, + ss_quantity double, + ss_wholesale_cost decimal(5, 2), + ss_list_price decimal(5, 2), + ss_sales_price decimal(5, 2), + ss_ext_discount_amt decimal(7, 2), + ss_ext_sales_price decimal(7, 2), + ss_ext_wholesale_cost decimal(7, 2), + ss_ext_list_price decimal(7, 2), + ss_ext_tax decimal(6, 2), + ss_coupon_amt decimal(7, 2), + ss_net_paid decimal(7, 2), + ss_net_paid_inc_tax decimal(7, 2), + ss_net_profit decimal(6, 2) +)`, + time_dim: `( + t_time_sk integer, + t_time_id varchar, + t_time integer, + t_hour integer, + t_minute integer, + t_second integer, + t_am_pm varchar, + t_shift varchar, + t_sub_shift varchar, + t_meal_time varchar +)`, + warehouse: `( + w_warehouse_sk integer, + w_warehouse_id varchar, + w_warehouse_name varchar, + w_warehouse_sq_ft integer, + w_street_number varchar, + w_street_name varchar, + w_street_type varchar, + w_suite_number varchar, + w_city varchar, + w_county varchar, + w_state varchar, + w_zip varchar, + w_country varchar, + w_gmt_offset decimal(3, 2) +)`, + web_page: `( + wp_web_page_sk integer, + wp_web_page_id varchar, + wp_rec_start_date date, + wp_rec_end_date date, + wp_creation_date_sk integer, + wp_access_date_sk integer, + wp_autogen_flag varchar, + wp_customer_sk double, + wp_url varchar, + wp_type 
varchar, + wp_char_count integer, + wp_link_count integer, + wp_image_count integer, + wp_max_ad_count integer +)`, + web_returns: `( + wr_returned_date_sk double, + wr_returned_time_sk double, + wr_item_sk integer, + wr_refunded_customer_sk double, + wr_refunded_cdemo_sk double, + wr_refunded_hdemo_sk double, + wr_refunded_addr_sk double, + wr_returning_customer_sk double, + wr_returning_cdemo_sk double, + wr_returning_hdemo_sk double, + wr_returning_addr_sk double, + wr_web_page_sk double, + wr_reason_sk double, + wr_order_number integer, + wr_return_quantity double, + wr_return_amt decimal(7, 2), + wr_return_tax decimal(6, 2), + wr_return_amt_inc_tax decimal(7, 2), + wr_fee decimal(5, 2), + wr_return_ship_cost decimal(7, 2), + wr_refunded_cash decimal(7, 2), + wr_reversed_charge decimal(7, 2), + wr_account_credit decimal(7, 2), + wr_net_loss decimal(7, 2) +)`, + web_sales: `( + ws_sold_date_sk double, + ws_sold_time_sk double, + ws_ship_date_sk double, + ws_item_sk integer, + ws_bill_customer_sk double, + ws_bill_cdemo_sk double, + ws_bill_hdemo_sk double, + ws_bill_addr_sk double, + ws_ship_customer_sk double, + ws_ship_cdemo_sk double, + ws_ship_hdemo_sk double, + ws_ship_addr_sk double, + ws_web_page_sk double, + ws_web_site_sk double, + ws_ship_mode_sk double, + ws_warehouse_sk double, + ws_promo_sk double, + ws_order_number integer, + ws_quantity double, + ws_wholesale_cost decimal(5, 2), + ws_list_price decimal(5, 2), + ws_sales_price decimal(5, 2), + ws_ext_discount_amt decimal(7, 2), + ws_ext_sales_price decimal(7, 2), + ws_ext_wholesale_cost decimal(7, 2), + ws_ext_list_price decimal(7, 2), + ws_ext_tax decimal(6, 2), + ws_coupon_amt decimal(7, 2), + ws_ext_ship_cost decimal(7, 2), + ws_net_paid decimal(7, 2), + ws_net_paid_inc_tax decimal(7, 2), + ws_net_paid_inc_ship decimal(7, 2), + ws_net_paid_inc_ship_tax decimal(7, 2), + ws_net_profit decimal(7, 2) +)`, + web_site: `( + web_site_sk integer, + web_site_id varchar, + web_rec_start_date date, + web_rec_end_date date, + web_name varchar, + web_open_date_sk double, + web_close_date_sk double, + web_class varchar, + web_manager varchar, + web_mkt_id integer, + web_mkt_class varchar, + web_mkt_desc varchar, + web_market_manager varchar, + web_company_id integer, + web_company_name varchar, + web_street_number varchar, + web_street_name varchar, + web_street_type varchar, + web_suite_number varchar, + web_city varchar, + web_county varchar, + web_state varchar, + web_zip varchar, + web_country varchar, + web_gmt_offset decimal(3, 2), + web_tax_percentage decimal(2, 2) +)` } +} - return { rowCount }; - } - - async createTables(sf: number): Promise { - const schema = `tpch_sf${sf}`; - - // Create schema first - await this.executeSingleStatement(`CREATE SCHEMA IF NOT EXISTS hive.${schema} WITH (location = 's3://datafusion-distributed-benchmarks/tpch_sf${sf}/')`); - - // Create customer table - await this.executeSingleStatement(`DROP TABLE IF EXISTS hive.${schema}.customer`); - await this.executeSingleStatement(`CREATE TABLE hive.${schema}.customer - ( - c_custkey bigint, - c_name varchar(25), - c_address varchar(40), - c_nationkey bigint, - c_phone varchar(15), - c_acctbal decimal(15, 2), - c_mktsegment varchar(10), - c_comment varchar(117) - ) - WITH (external_location = 's3://datafusion-distributed-benchmarks/tpch_sf${sf}/customer/', format = 'PARQUET')`); - - // Create lineitem table - await this.executeSingleStatement(`DROP TABLE IF EXISTS hive.${schema}.lineitem`); - await this.executeSingleStatement(`CREATE TABLE 
hive.${schema}.lineitem - ( - l_orderkey bigint, - l_partkey bigint, - l_suppkey bigint, - l_linenumber integer, - l_quantity decimal(15, 2), - l_extendedprice decimal(15, 2), - l_discount decimal(15, 2), - l_tax decimal(15, 2), - l_returnflag varchar(1), - l_linestatus varchar(1), - l_shipdate date, - l_commitdate date, - l_receiptdate date, - l_shipinstruct varchar(25), - l_shipmode varchar(10), - l_comment varchar(44) - ) - WITH (external_location = 's3://datafusion-distributed-benchmarks/tpch_sf${sf}/lineitem/', format = 'PARQUET')`); - - // Create nation table - await this.executeSingleStatement(`DROP TABLE IF EXISTS hive.${schema}.nation`); - await this.executeSingleStatement(`CREATE TABLE hive.${schema}.nation - ( - n_nationkey bigint, - n_name varchar(25), - n_regionkey bigint, - n_comment varchar(152) - ) - WITH (external_location = 's3://datafusion-distributed-benchmarks/tpch_sf${sf}/nation/', format = 'PARQUET')`); - - // Create orders table - await this.executeSingleStatement(`DROP TABLE IF EXISTS hive.${schema}.orders`); - await this.executeSingleStatement(`CREATE TABLE hive.${schema}.orders - ( - o_orderkey bigint, - o_custkey bigint, - o_orderstatus varchar(1), - o_totalprice decimal(15, 2), - o_orderdate date, - o_orderpriority varchar(15), - o_clerk varchar(15), - o_shippriority integer, - o_comment varchar(79) - ) - WITH (external_location = 's3://datafusion-distributed-benchmarks/tpch_sf${sf}/orders/', format = 'PARQUET')`); - - // Create part table - await this.executeSingleStatement(`DROP TABLE IF EXISTS hive.${schema}.part`); - await this.executeSingleStatement(`CREATE TABLE hive.${schema}.part - ( - p_partkey bigint, - p_name varchar(55), - p_mfgr varchar(25), - p_brand varchar(10), - p_type varchar(25), - p_size integer, - p_container varchar(10), - p_retailprice decimal(15, 2), - p_comment varchar(23) - ) - WITH (external_location = 's3://datafusion-distributed-benchmarks/tpch_sf${sf}/part/', format = 'PARQUET')`); - - // Create partsupp table - await this.executeSingleStatement(`DROP TABLE IF EXISTS hive.${schema}.partsupp`); - await this.executeSingleStatement(`CREATE TABLE hive.${schema}.partsupp - ( - ps_partkey bigint, - ps_suppkey bigint, - ps_availqty integer, - ps_supplycost decimal(15, 2), - ps_comment varchar(199) - ) - WITH (external_location = 's3://datafusion-distributed-benchmarks/tpch_sf${sf}/partsupp/', format = 'PARQUET')`); - - // Create region table - await this.executeSingleStatement(`DROP TABLE IF EXISTS hive.${schema}.region`); - await this.executeSingleStatement(`CREATE TABLE hive.${schema}.region - ( - r_regionkey bigint, - r_name varchar(25), - r_comment varchar(152) - ) - WITH (external_location = 's3://datafusion-distributed-benchmarks/tpch_sf${sf}/region/', format = 'PARQUET')`); - - // Create supplier table - await this.executeSingleStatement(`DROP TABLE IF EXISTS hive.${schema}.supplier`); - await this.executeSingleStatement(`CREATE TABLE hive.${schema}.supplier - ( - s_suppkey bigint, - s_name varchar(25), - s_address varchar(40), - s_nationkey bigint, - s_phone varchar(15), - s_acctbal decimal(15, 2), - s_comment varchar(101) - ) - WITH (external_location = 's3://datafusion-distributed-benchmarks/tpch_sf${sf}/supplier/', format = 'PARQUET')`); - } +function getSchema(table: TableSpec): string { + const tableSchema = SCHEMAS[table.schema.split("_")[0]]?.[table.name] + if (!tableSchema) { + throw new Error(`Could not find table ${table.name} in schema ${table.schema}`) + } + return tableSchema } main() - .catch(err => { - 
console.error(err) - process.exit(1) - }) + .catch(err => { + console.error(err) + process.exit(1) + }) diff --git a/benchmarks/cdk/bin/worker.rs b/benchmarks/cdk/bin/worker.rs index ab3e398..45bdfbc 100644 --- a/benchmarks/cdk/bin/worker.rs +++ b/benchmarks/cdk/bin/worker.rs @@ -6,11 +6,11 @@ use datafusion::common::DataFusionError; use datafusion::common::instant::Instant; use datafusion::common::runtime::SpawnedTask; use datafusion::execution::SessionStateBuilder; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::execute_stream; use datafusion::prelude::SessionContext; use datafusion_distributed::{ - DistributedExt, DistributedPhysicalOptimizerRule, Worker, WorkerQueryContext, WorkerResolver, - display_plan_ascii, + DistributedExt, DistributedPhysicalOptimizerRule, Worker, WorkerResolver, display_plan_ascii, }; use futures::{StreamExt, TryFutureExt}; use log::{error, info, warn}; @@ -64,19 +64,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { .with_bucket_name(s3_url.host().unwrap().to_string()) .build()?, ); + let runtime_env = Arc::new(RuntimeEnv::default()); + runtime_env.register_object_store(&s3_url, s3); + let state = SessionStateBuilder::new() .with_default_features() - .with_object_store(&s3_url, Arc::clone(&s3) as _) + .with_runtime_env(Arc::clone(&runtime_env)) .with_distributed_worker_resolver(Ec2WorkerResolver::new()) .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)) .build(); let ctx = SessionContext::from(state); - let arrow_flight_endpoint = Worker::from_session_builder(move |ctx: WorkerQueryContext| { - let s3 = s3.clone(); - let s3_url = s3_url.clone(); - async move { Ok(ctx.builder.with_object_store(&s3_url, s3).build()) } - }); + let worker = Worker::default().with_runtime_env(runtime_env); let http_server = axum::serve( listener, Router::new().route( @@ -137,7 +136,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { ), ); let grpc_server = Server::builder() - .add_service(arrow_flight_endpoint.into_flight_server()) + .add_service(worker.into_flight_server()) .serve(WORKER_ADDR.parse()?); info!("Started listener HTTP server in {LISTENER_ADDR}"); diff --git a/src/flight_service/worker.rs b/src/flight_service/worker.rs index ecd7f8b..2835540 100644 --- a/src/flight_service/worker.rs +++ b/src/flight_service/worker.rs @@ -57,6 +57,13 @@ impl Worker { } } + /// Sets a [RuntimeEnv] to be used in all the queries this [Worker] will handle during + /// its lifetime. + pub fn with_runtime_env(mut self, runtime_env: Arc<RuntimeEnv>) -> Self { + self.runtime = runtime_env; + self + } + /// Adds a callback for when an [ExecutionPlan] is received in the `do_get` call. /// /// The callback takes the plan and returns another plan that must be either the same,