diff --git a/torchci/components/benchmark/llms/components/LLMsSummaryPanel.tsx b/torchci/components/benchmark/llms/components/LLMsSummaryPanel.tsx index 1a132a6a45..5c6773a348 100644 --- a/torchci/components/benchmark/llms/components/LLMsSummaryPanel.tsx +++ b/torchci/components/benchmark/llms/components/LLMsSummaryPanel.tsx @@ -1,4 +1,4 @@ -import { Grid2 } from "@mui/material"; +import { Grid2, styled, Tooltip } from "@mui/material"; import { GridCellParams, GridRenderCellParams } from "@mui/x-data-grid"; import styles from "components/metrics.module.css"; import { TablePanelWithData } from "components/metrics/panels/TablePanel"; @@ -12,6 +12,23 @@ import { UNIT_FOR_METRIC, } from "lib/benchmark/llms/common"; import { combineLeftAndRight } from "lib/benchmark/llms/utils/llmUtils"; +import { MdError } from "react-icons/md"; +import { VscError } from "react-icons/vsc"; + +const FlexDiv = styled("div")({ + display: "flex", + flexDirection: "row", + justifyContent: "flex-start", + alignItems: "center", +}); + +const FlexDivCenter = styled("div")({ + display: "flex", + flexDirection: "row", + justifyContent: "center", + alignItems: "center", + margin: "3px", +}); const getDeviceArch = ( device: string | undefined, @@ -62,6 +79,7 @@ export default function LLMsSummaryPanel({ lPerfData, rPerfData ); + const columns: any[] = [ { field: "metadata", @@ -121,9 +139,18 @@ export default function LLMsSummaryPanel({ ? `${model} (${metadata.origins.join(",")})` : model; return ( - - {displayName} - + + {params.row.FAILURE_REPORT && ( + + )} + + {displayName} + + ); }, }, @@ -233,6 +260,7 @@ export default function LLMsSummaryPanel({ return params.value; }, }, + // add all other metrics as columns ...metricNames .filter((metric: string) => { // TODO (huydhn): Just a temp fix, remove this after a few weeks @@ -242,6 +270,9 @@ export default function LLMsSummaryPanel({ (metric !== "speedup" && metric !== "Speedup") ); }) + .filter((metric: string) => { + return metric !== "FAILURE_REPORT"; + }) .map((metric: string) => { return { field: metric, @@ -252,6 +283,12 @@ export default function LLMsSummaryPanel({ flex: 1, cellClassName: (params: GridCellParams) => { const v = params.value; + + // If the row data has failure, we render it in grey color + if (params.row.FAILURE_REPORT) { + return styles.failure; + } + if (v === undefined) { return ""; } @@ -307,6 +344,9 @@ export default function LLMsSummaryPanel({ renderCell: (params: GridRenderCellParams) => { const v = params.value; if (v === undefined) { + if (params.row.FAILURE_REPORT) { + return "N/A"; + } return ""; } @@ -329,6 +369,21 @@ export default function LLMsSummaryPanel({ const showTarget = target && target != 0 ? `[target = ${target}]` : ""; + // A Failure is detected for a model and backend + if (params.row.FAILURE_REPORT) { + return handleModelBackendFailure( + params.row, + lCommit, + rCommit, + unit, + showTarget, + l, + r, + lPercent, + rPercent + ); + } + if (lCommit === rCommit || !v.highlight) { return `${r}${unit} ${rPercent} ${showTarget}`; } else { @@ -359,3 +414,115 @@ export default function LLMsSummaryPanel({ ); } + +// handle failure report for a row. +const handleModelBackendFailure = ( + row: any, + lCommit: string, + rCommit: string, + unit: string, + showTarget: string, + lactual: number, + ractual: number, + lPercent: string, + rPercent: string +) => { + const isLFailure = + row.FAILURE_REPORT?.l.actual == Number.MAX_SAFE_INTEGER ? true : false; + const isRFailure = + row.FAILURE_REPORT?.r.actual == Number.MAX_SAFE_INTEGER ? true : false; + + // render the row's value in other metric columns + if (isLFailure && isRFailure) { + if (lCommit === rCommit) { + return ( + + ); + } + return ( +
+ + ; +
+ ); + } else if (isLFailure) { + return ( + + + + + {ractual} + {unit} + {rPercent} {showTarget} + + + ); + } else if (isRFailure) { + return ( + + + {lactual} + {unit} + {lPercent} + + + + + ); + } +}; + +const RenderWarningOnNameForFailure = ({ + lCommit, + rCommit, + row, +}: { + lCommit: string; + rCommit: string; + row: any; +}) => { + const isLFailure = + row.FAILURE_REPORT?.l.actual == Number.MAX_SAFE_INTEGER ? true : false; + const isRFailure = + row.FAILURE_REPORT?.r.actual == Number.MAX_SAFE_INTEGER ? true : false; + // Indicate the failure details in Failure Report column + if (lCommit === rCommit) { + return ( + + ); + } + if (isLFailure && isRFailure) { + return ( + + ); + } + if (isLFailure) { + return ( + + ); + } + if (isRFailure) { + return ( + + ); + } + return <>; +}; + +const FailureElementWithTooltip = ({ message = "" }) => ( + +
+ +
+
+); + +const WarningElementWithTooltip = ({ message = "" }) => ( + + +
+ +
+
+
+); diff --git a/torchci/components/metrics.module.css b/torchci/components/metrics.module.css index 1a7d9ff914..bf775b2fab 100644 --- a/torchci/components/metrics.module.css +++ b/torchci/components/metrics.module.css @@ -8,6 +8,11 @@ color: var(--text-color); } +.failure { + background-color: var(--workflow-box-none-bg, lightgray); + color: var(--text-color); +} + .error { background-color: var(--workflow-box-fail-bg, lightpink); color: var(--text-color); diff --git a/torchci/lib/benchmark/llms/utils/llmUtils.ts b/torchci/lib/benchmark/llms/utils/llmUtils.ts index b8ebc04b47..2c12f85bb2 100644 --- a/torchci/lib/benchmark/llms/utils/llmUtils.ts +++ b/torchci/lib/benchmark/llms/utils/llmUtils.ts @@ -95,6 +95,7 @@ export function combineLeftAndRight( const rData = rPerfData.data; const dataGroupedByModel: { [k: string]: any } = {}; + rData.forEach((record: LLMsBenchmarkData) => { const model = record.model; const backend = record.backend; @@ -144,33 +145,6 @@ export function combineLeftAndRight( }); } - // NB: This is a hack to keep track of valid devices. The problem is that the records - // in the benchmark database alone don't have the information to differentiate between - // benchmarks that are failed to run and benchmarks that are not run. Both show up as - // 0 on the dashboard. Note that we can do a join with workflow_job table to get this - // information, but it's a rather slow and expensive route - const validDevices = new Set(); - const validBackends = new Set(); - // First round to get all the valid devices - Object.keys(dataGroupedByModel).forEach((key: string) => { - const [model, backend, mode, dtype, device, arch, extra] = key.split(";"); - const row: { [k: string]: any } = { - // Keep the name as as the row ID as DataGrid requires it - name: `${model} ${backend} (${mode} / ${dtype} / ${device} / ${arch})`, - }; - - for (const metric in dataGroupedByModel[key]) { - const record = dataGroupedByModel[key][metric]; - const hasL = "l" in record; - const hasR = "r" in record; - - if (hasL && hasR) { - validDevices.add(device); - validBackends.add(`${model} ${backend}`); - } - } - }); - // Transform the data into a displayable format const data: { [k: string]: any }[] = []; Object.keys(dataGroupedByModel).forEach((key: string) => { @@ -182,27 +156,14 @@ export function combineLeftAndRight( for (const metric in dataGroupedByModel[key]) { const record = dataGroupedByModel[key][metric]; + const hasL = "l" in record; const hasR = "r" in record; - // Skip devices and models that weren't run in this commit - if ( - (validDevices.size !== 0 && !validDevices.has(device)) || - (validBackends.size !== 0 && !validBackends.has(`${model} ${backend}`)) - ) { - continue; - } - - // No overlapping between left and right commits, just show what it's on the - // right commit instead of showing a blank page - if (!hasR) { - continue; - } - if (!("metadata" in row)) { row["metadata"] = { model: model, - origins: record["r"].origins, + origins: hasR ? record["r"].origins : [], backend: backend, mode: mode, dtype: dtype, @@ -280,32 +241,51 @@ export function combineLeftAndRight( const extraInfo = JSON.parse(extra); row["is_dynamic"] = extraInfo["is_dynamic"]; } - - row[metric] = { - l: hasL - ? { - actual: record["l"].actual, - target: record["l"].target, - } - : { - actual: 0, - target: 0, - }, - r: hasR - ? { - actual: record["r"].actual, - target: record["r"].target, - } - : { - actual: 0, - target: 0, - }, - highlight: - validDevices.size !== 0 && - validBackends.has(`${model} ${backend}`) && - hasL && - hasR, - }; + if (metric == "FAILURE_REPORT") { + row[metric] = { + l: hasL + ? { + actual: Number.MAX_SAFE_INTEGER, // indicate the failure on left side + target: 0, + } + : { + actual: 0, + target: 0, + }, + r: hasR + ? { + actual: Number.MAX_SAFE_INTEGER, // indicate the failure on right side + target: 0, + } + : { + actual: 0, + target: 0, + }, + highlight: hasL && hasR, + }; + } else { + row[metric] = { + l: hasL + ? { + actual: record["l"].actual, + target: record["l"].target, + } + : { + actual: 0, + target: 0, + }, + r: hasR + ? { + actual: record["r"].actual, + target: record["r"].target, + } + : { + actual: 0, + target: 0, + }, + highlight: hasL && hasR, + }; + } } if ("metadata" in row) {