expense_estimation.py — 44 changes: 4 additions & 40 deletions
@@ -1,9 +1,6 @@
from typing import Literal

import numpy as np


-VERTEX_GPU_FACTOR = 1e-11
+from utils import estimate_memory_usage, VERTEX_GPU_FACTOR


def estimate_duration(
@@ -17,39 +14,6 @@ def estimate_duration(
"""
Estimates the duration of a prediction task.
"""

# Logic comes from _estimate_model_usage in base.py of the TabPFN codebase.
CONSTANT_COMPUTE_OVERHEAD = 8000
NUM_SAMPLES_FACTOR = 4
NUM_SAMPLES_PLUS_FEATURES = 6.5
CELLS_FACTOR = 0.25
CELLS_SQUARED_FACTOR = 1.3e-7

EMBEDDING_SIZE = 192
NUM_HEADS = 6
NUM_LAYERS = 12
FEATURES_PER_GROUP = 2

n_estimators = tabpfn_config.get(
"n_estimators", 4 if task == "classification" else 8
)

num_samples = num_rows
num_feature_groups = int(np.ceil(num_features / FEATURES_PER_GROUP))

num_cells = (num_feature_groups + 1) * num_samples
compute_cost = (EMBEDDING_SIZE**2) * NUM_HEADS * NUM_LAYERS

base_duration = (
n_estimators
* compute_cost
* (
CONSTANT_COMPUTE_OVERHEAD
+ num_samples * NUM_SAMPLES_FACTOR
+ (num_samples + num_feature_groups) * NUM_SAMPLES_PLUS_FEATURES
+ num_cells * CELLS_FACTOR
+ num_cells**2 * CELLS_SQUARED_FACTOR
)
)

return round(base_duration * duration_factor + latency_offset, 3)
base_memory_usage = estimate_memory_usage(num_rows, num_features, task, tabpfn_config)

return round(base_memory_usage * duration_factor + latency_offset, 3)
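With this change, estimate_duration delegates the raw cost model to utils.estimate_memory_usage and only applies the duration_factor scaling and latency_offset on top. A minimal usage sketch follows; the full estimate_duration signature is collapsed in this diff, so the keyword names below are inferred from the function body and are an assumption:

```python
# Sketch only: estimate_duration's full signature is not shown in the diff,
# so the parameter names here are inferred from its body.
from expense_estimation import estimate_duration
from utils import VERTEX_GPU_FACTOR

seconds = estimate_duration(
    num_rows=10_000,
    num_features=20,
    task="classification",
    tabpfn_config={"n_estimators": 4},
    duration_factor=VERTEX_GPU_FACTOR,  # 1e-11, defined in utils.py below
    latency_offset=0.5,  # illustrative value
)
print(f"estimated duration: {seconds}s")
```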
utils.py — 42 changes: 42 additions & 0 deletions
@@ -9,6 +9,48 @@
from typing_extensions import override


+VERTEX_GPU_FACTOR = 1e-11
+
+
+def estimate_memory_usage(
+    num_rows: int,
+    num_features: int,
+    task: str,
+    tabpfn_config: typing.Optional[dict] = None,
+) -> float:
+    """
+    Estimates the memory usage for a prediction task.
+    """
+    # Logic comes from _estimate_model_usage in base.py of the TabPFN codebase.
+    CONSTANT_COMPUTE_OVERHEAD = 8000
+    NUM_SAMPLES_FACTOR = 4
+    NUM_SAMPLES_PLUS_FEATURES = 6.5
+    CELLS_FACTOR = 0.25
+    CELLS_SQUARED_FACTOR = 1.3e-7
+
+    EMBEDDING_SIZE = 192
+    NUM_HEADS = 6
+    NUM_LAYERS = 12
+    FEATURES_PER_GROUP = 2
+
+    n_estimators = (tabpfn_config or {}).get(
+        "n_estimators", 4 if task == "classification" else 8
+    )
+
+    num_samples = num_rows
+    num_feature_groups = int(np.ceil(num_features / FEATURES_PER_GROUP))
+
+    num_cells = (num_feature_groups + 1) * num_samples
+    compute_cost = (EMBEDDING_SIZE**2) * NUM_HEADS * NUM_LAYERS
+
+    base_memory_usage = (
+        n_estimators
+        * compute_cost
+        * (
+            CONSTANT_COMPUTE_OVERHEAD
+            + num_samples * NUM_SAMPLES_FACTOR
+            + (num_samples + num_feature_groups) * NUM_SAMPLES_PLUS_FEATURES
+            + num_cells * CELLS_FACTOR
+            + num_cells**2 * CELLS_SQUARED_FACTOR
+        )
+    )
+
+    return base_memory_usage
+
+
def serialize_to_csv_formatted_bytes(
    data: typing.Union[pd.DataFrame, pd.Series, np.ndarray],
) -> bytes:
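As a quick sanity check of the new helper, a small illustrative sketch (the input values are made up; the exact output depends on the constants above):

```python
from utils import estimate_memory_usage

# task="classification" with no explicit n_estimators falls back to 4.
est = estimate_memory_usage(num_rows=1_000, num_features=10, task="classification")

# Intermediate values for these inputs:
#   num_feature_groups = ceil(10 / 2) = 5
#   num_cells = (5 + 1) * 1_000 = 6_000
#   compute_cost = 192**2 * 6 * 12 = 2_654_208
print(est)
```

estimate_duration then turns this raw figure into seconds by multiplying it by duration_factor (e.g. VERTEX_GPU_FACTOR = 1e-11) and adding latency_offset.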