diff --git a/torchci/components/benchmark/llms/components/LLMsSummaryPanel.tsx b/torchci/components/benchmark/llms/components/LLMsSummaryPanel.tsx
index 1a132a6a45..5c6773a348 100644
--- a/torchci/components/benchmark/llms/components/LLMsSummaryPanel.tsx
+++ b/torchci/components/benchmark/llms/components/LLMsSummaryPanel.tsx
@@ -1,4 +1,4 @@
-import { Grid2 } from "@mui/material";
+import { Grid2, styled, Tooltip } from "@mui/material";
import { GridCellParams, GridRenderCellParams } from "@mui/x-data-grid";
import styles from "components/metrics.module.css";
import { TablePanelWithData } from "components/metrics/panels/TablePanel";
@@ -12,6 +12,23 @@ import {
UNIT_FOR_METRIC,
} from "lib/benchmark/llms/common";
import { combineLeftAndRight } from "lib/benchmark/llms/utils/llmUtils";
+import { MdError } from "react-icons/md";
+import { VscError } from "react-icons/vsc";
+
+const FlexDiv = styled("div")({ // Horizontal flex row: children packed at the start of the row and centered on the cross axis.
+ display: "flex",
+ flexDirection: "row",
+ justifyContent: "flex-start",
+ alignItems: "center",
+});
+
+const FlexDivCenter = styled("div")({ // Horizontal flex row with children centered on both axes and a 3px outer margin.
+ display: "flex",
+ flexDirection: "row",
+ justifyContent: "center",
+ alignItems: "center",
+ margin: "3px",
+});
const getDeviceArch = (
device: string | undefined,
@@ -62,6 +79,7 @@ export default function LLMsSummaryPanel({
lPerfData,
rPerfData
);
+
const columns: any[] = [
{
field: "metadata",
@@ -121,9 +139,18 @@ export default function LLMsSummaryPanel({
? `${model} (${metadata.origins.join(",")})`
: model;
return (
-
- {displayName}
-
+
+ {params.row.FAILURE_REPORT && (
+
+ )}
+
+ {displayName}
+
+
);
},
},
@@ -233,6 +260,7 @@ export default function LLMsSummaryPanel({
return params.value;
},
},
+ // add all other metrics as columns
...metricNames
.filter((metric: string) => {
// TODO (huydhn): Just a temp fix, remove this after a few weeks
@@ -242,6 +270,9 @@ export default function LLMsSummaryPanel({
(metric !== "speedup" && metric !== "Speedup")
);
})
+ .filter((metric: string) => {
+ return metric !== "FAILURE_REPORT";
+ })
.map((metric: string) => {
return {
field: metric,
@@ -252,6 +283,12 @@ export default function LLMsSummaryPanel({
flex: 1,
cellClassName: (params: GridCellParams) => {
const v = params.value;
+
+ // If the row carries a failure report, style the cell with the grey `failure` class
+ if (params.row.FAILURE_REPORT) {
+ return styles.failure;
+ }
+
if (v === undefined) {
return "";
}
@@ -307,6 +344,9 @@ export default function LLMsSummaryPanel({
renderCell: (params: GridRenderCellParams) => {
const v = params.value;
if (v === undefined) {
+ if (params.row.FAILURE_REPORT) {
+ return "N/A";
+ }
return "";
}
@@ -329,6 +369,21 @@ export default function LLMsSummaryPanel({
const showTarget =
target && target != 0 ? `[target = ${target}]` : "";
+ // The row carries a failure report for this model/backend: delegate to the failure-aware renderer
+ if (params.row.FAILURE_REPORT) {
+ return handleModelBackendFailure(
+ params.row,
+ lCommit,
+ rCommit,
+ unit,
+ showTarget,
+ l,
+ r,
+ lPercent,
+ rPercent
+ );
+ }
+
if (lCommit === rCommit || !v.highlight) {
return `${r}${unit} ${rPercent} ${showTarget}`;
} else {
@@ -359,3 +414,115 @@ export default function LLMsSummaryPanel({
);
}
+
+// Renders a metric cell for a row that carries a FAILURE_REPORT. A side (l or r) counts as failed when its FAILURE_REPORT `actual` equals Number.MAX_SAFE_INTEGER.
+const handleModelBackendFailure = (
+ row: any,
+ lCommit: string,
+ rCommit: string,
+ unit: string,
+ showTarget: string,
+ lactual: number,
+ ractual: number,
+ lPercent: string,
+ rPercent: string
+) => {
+ const isLFailure =
+ row.FAILURE_REPORT?.l.actual == Number.MAX_SAFE_INTEGER ? true : false;
+ const isRFailure =
+ row.FAILURE_REPORT?.r.actual == Number.MAX_SAFE_INTEGER ? true : false;
+
+ // Pick the cell content by which side(s) failed; when neither flag is set no branch matches and the function returns undefined.
+ if (isLFailure && isRFailure) { // both sides failed
+ if (lCommit === rCommit) {
+ return (
+
+ );
+ }
+ return (
+
+
+ ;
+
+ );
+ } else if (isLFailure) { // only the left side failed: the right-side value (ractual, unit, rPercent, showTarget) follows the arrow
+ return (
+
+
+ →
+
+ {ractual}
+ {unit}
+ {rPercent} {showTarget}
+
+
+ );
+ } else if (isRFailure) { // only the right side failed: the left-side value (lactual, unit, lPercent) precedes the arrow
+ return (
+
+
+ {lactual}
+ {unit}
+ {lPercent}
+
+ →
+
+
+ );
+ }
+};
+
+const RenderWarningOnNameForFailure = ({ // Component: picks what to render for a row from its FAILURE_REPORT and the two compared commits.
+ lCommit,
+ rCommit,
+ row,
+}: {
+ lCommit: string;
+ rCommit: string;
+ row: any;
+}) => {
+ const isLFailure =
+ row.FAILURE_REPORT?.l.actual == Number.MAX_SAFE_INTEGER ? true : false;
+ const isRFailure =
+ row.FAILURE_REPORT?.r.actual == Number.MAX_SAFE_INTEGER ? true : false;
+ // Branch order: identical commits first (the per-side flags are not consulted there), then both sides failed, left only, right only; otherwise the final return is used.
+ if (lCommit === rCommit) {
+ return (
+
+ );
+ }
+ if (isLFailure && isRFailure) {
+ return (
+
+ );
+ }
+ if (isLFailure) {
+ return (
+
+ );
+ }
+ if (isRFailure) {
+ return (
+
+ );
+ }
+ return <>>;
+};
+
+const FailureElementWithTooltip = ({ message = "" }) => ( // `message` defaults to ""; NOTE(review): presumably the tooltip text for a failure icon — confirm against the markup
+
+
+
+
+
+);
+
+const WarningElementWithTooltip = ({ message = "" }) => ( // `message` defaults to ""; NOTE(review): presumably the tooltip text for a warning icon — confirm against the markup
+
+
+
+
+
+
+
+);
diff --git a/torchci/components/metrics.module.css b/torchci/components/metrics.module.css
index 1a7d9ff914..bf775b2fab 100644
--- a/torchci/components/metrics.module.css
+++ b/torchci/components/metrics.module.css
@@ -8,6 +8,11 @@
color: var(--text-color);
}
+.failure { /* Cell style for rows that carry a failure report; background falls back to lightgray when --workflow-box-none-bg is not defined. */
+ background-color: var(--workflow-box-none-bg, lightgray);
+ color: var(--text-color);
+}
+
.error {
background-color: var(--workflow-box-fail-bg, lightpink);
color: var(--text-color);
diff --git a/torchci/lib/benchmark/llms/utils/llmUtils.ts b/torchci/lib/benchmark/llms/utils/llmUtils.ts
index b8ebc04b47..2c12f85bb2 100644
--- a/torchci/lib/benchmark/llms/utils/llmUtils.ts
+++ b/torchci/lib/benchmark/llms/utils/llmUtils.ts
@@ -95,6 +95,7 @@ export function combineLeftAndRight(
const rData = rPerfData.data;
const dataGroupedByModel: { [k: string]: any } = {};
+
rData.forEach((record: LLMsBenchmarkData) => {
const model = record.model;
const backend = record.backend;
@@ -144,33 +145,6 @@ export function combineLeftAndRight(
});
}
- // NB: This is a hack to keep track of valid devices. The problem is that the records
- // in the benchmark database alone don't have the information to differentiate between
- // benchmarks that are failed to run and benchmarks that are not run. Both show up as
- // 0 on the dashboard. Note that we can do a join with workflow_job table to get this
- // information, but it's a rather slow and expensive route
- const validDevices = new Set();
- const validBackends = new Set();
- // First round to get all the valid devices
- Object.keys(dataGroupedByModel).forEach((key: string) => {
- const [model, backend, mode, dtype, device, arch, extra] = key.split(";");
- const row: { [k: string]: any } = {
- // Keep the name as as the row ID as DataGrid requires it
- name: `${model} ${backend} (${mode} / ${dtype} / ${device} / ${arch})`,
- };
-
- for (const metric in dataGroupedByModel[key]) {
- const record = dataGroupedByModel[key][metric];
- const hasL = "l" in record;
- const hasR = "r" in record;
-
- if (hasL && hasR) {
- validDevices.add(device);
- validBackends.add(`${model} ${backend}`);
- }
- }
- });
-
// Transform the data into a displayable format
const data: { [k: string]: any }[] = [];
Object.keys(dataGroupedByModel).forEach((key: string) => {
@@ -182,27 +156,14 @@ export function combineLeftAndRight(
for (const metric in dataGroupedByModel[key]) {
const record = dataGroupedByModel[key][metric];
+
const hasL = "l" in record;
const hasR = "r" in record;
- // Skip devices and models that weren't run in this commit
- if (
- (validDevices.size !== 0 && !validDevices.has(device)) ||
- (validBackends.size !== 0 && !validBackends.has(`${model} ${backend}`))
- ) {
- continue;
- }
-
- // No overlapping between left and right commits, just show what it's on the
- // right commit instead of showing a blank page
- if (!hasR) {
- continue;
- }
-
if (!("metadata" in row)) {
row["metadata"] = {
model: model,
- origins: record["r"].origins,
+ origins: hasR ? record["r"].origins : [],
backend: backend,
mode: mode,
dtype: dtype,
@@ -280,32 +241,51 @@ export function combineLeftAndRight(
const extraInfo = JSON.parse(extra);
row["is_dynamic"] = extraInfo["is_dynamic"];
}
-
- row[metric] = {
- l: hasL
- ? {
- actual: record["l"].actual,
- target: record["l"].target,
- }
- : {
- actual: 0,
- target: 0,
- },
- r: hasR
- ? {
- actual: record["r"].actual,
- target: record["r"].target,
- }
- : {
- actual: 0,
- target: 0,
- },
- highlight:
- validDevices.size !== 0 &&
- validBackends.has(`${model} ${backend}`) &&
- hasL &&
- hasR,
- };
+ if (metric == "FAILURE_REPORT") {
+ row[metric] = {
+ l: hasL
+ ? {
+ actual: Number.MAX_SAFE_INTEGER, // indicate the failure on left side
+ target: 0,
+ }
+ : {
+ actual: 0,
+ target: 0,
+ },
+ r: hasR
+ ? {
+ actual: Number.MAX_SAFE_INTEGER, // indicate the failure on right side
+ target: 0,
+ }
+ : {
+ actual: 0,
+ target: 0,
+ },
+ highlight: hasL && hasR,
+ };
+ } else {
+ row[metric] = {
+ l: hasL
+ ? {
+ actual: record["l"].actual,
+ target: record["l"].target,
+ }
+ : {
+ actual: 0,
+ target: 0,
+ },
+ r: hasR
+ ? {
+ actual: record["r"].actual,
+ target: record["r"].target,
+ }
+ : {
+ actual: 0,
+ target: 0,
+ },
+ highlight: hasL && hasR,
+ };
+ }
}
if ("metadata" in row) {