Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions tabularpriors/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import torch

def get_ticl_prior_config(prior_type: str) -> dict:
"""Return the default kwargs for MLPPrior or GPPrior."""
def get_ticl_prior_config(prior_type: str, max_num_classes: int = None) -> dict:
"""Return the default kwargs for MLPPrior, GPPrior, or classification priors."""

if prior_type == "mlp":
return {
Expand Down Expand Up @@ -33,5 +33,34 @@ def get_ticl_prior_config(prior_type: str) -> dict:
"outputscale": 1.0,
"lengthscale": 0.2,
}
elif prior_type == "classification_adapter":
return {
"max_num_classes": 50, # this is a global upper bound for how many output classes the generator or model can handle;
# setting it > 0 keeps classification mode active (0 = regression),
# and actual tasks will use up to this many classes depending on num_classes and random sampling
"num_classes": max_num_classes,
"balanced": False,
"output_multiclass_ordered_p": 0.1,
"multiclass_type": "rank",
"categorical_feature_p": 0.15,
"nan_prob_no_reason": 0.05,
"nan_prob_a_reason": 0.03,
"set_value_to_nan": 0.9,
"num_features_sampler": "uniform",
"pad_zeros": False,
"feature_curriculum": False,
}
elif prior_type == "boolean_conjunctions":
return {
'max_rank': 20,
'max_fraction_uninformative': 0.3,
'p_uninformative': 0.3,
'verbose': False
}
elif prior_type == "step_function":
return {
"max_steps": 1,
"sampling": "uniform",
}
else:
raise ValueError(f"Unsupported TICL prior type: {prior_type}")
10 changes: 7 additions & 3 deletions tabularpriors/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def main():
parser.add_argument("--num_batches", type=int, default=100, help="Number of batches to dump.")
parser.add_argument("--batch_size", type=int, default=8, help="Batch size for dumping.")
parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "cuda"], help="Device to run prior sampling on.")
parser.add_argument("--prior_type", type=str, default="mlp", choices=["mlp", "gp"], help="Which TICL prior to use.")
parser.add_argument("--prior_type", type=str, default="mlp", choices=["mlp", "gp", "classification_adapter", "boolean_conjunctions", "step_function"], help="Which TICL prior to use.")
parser.add_argument("--base_prior_type", type=str, default="mlp", choices=["mlp", "gp"], help="Base regression prior for classification_adapter.")
parser.add_argument("--min_features", type=int, default=1, help="Minimum number of input features.")
parser.add_argument("--max_features", type=int, default=100, help="Maximum number of input features.")
parser.add_argument("--min_seq_len", type=int, default=None, help="Minimum number of data points per function.")
Expand All @@ -42,16 +43,19 @@ def main():
args.save_path = f"prior_{args.lib}{prior_name}_{args.num_batches}x{args.batch_size}_{args.max_seq_len}x{args.max_features}.h5"

if args.lib == "ticl":
# determine if this is a classification prior
is_classification_prior = args.prior_type in ["classification_adapter", "boolean_conjunctions", "step_function"]

prior = TICLPriorDataLoader(
prior=build_ticl_prior(args.prior_type),
prior=build_ticl_prior(args.prior_type, args.base_prior_type, args.max_classes),
num_steps=args.num_batches,
batch_size=args.batch_size,
num_datapoints_max=args.max_seq_len,
num_features=args.max_features,
device=device,
min_eval_pos=args.min_eval_pos,
)
problem_type = "regression"
problem_type = "classification" if is_classification_prior else "regression"
else:
if args.min_seq_len == args.max_seq_len:
args.min_seq_len = None # TabICL prior requires min_seq_len < max_seq_len
Expand Down
19 changes: 15 additions & 4 deletions tabularpriors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,30 @@
import h5py
import numpy as np
import torch
from ticl.priors import GPPrior, MLPPrior
from ticl.priors import GPPrior, MLPPrior, ClassificationAdapterPrior, BooleanConjunctionPrior, StepFunctionPrior

from .config import get_ticl_prior_config


def build_ticl_prior(prior_type: str) -> Union[MLPPrior, GPPrior]:
"""Builds a TICL prior (MLP or GP) based on the prior type string using the defaults in config.py."""
def build_ticl_prior(prior_type: str, base_prior_type: str = None, max_num_classes: int = None) -> Union[MLPPrior, GPPrior, ClassificationAdapterPrior, BooleanConjunctionPrior, StepFunctionPrior]:
"""Builds a TICL prior based on the prior type string using the defaults in config.py."""

cfg = get_ticl_prior_config(prior_type)
cfg = get_ticl_prior_config(prior_type, max_num_classes)

if prior_type == "mlp":
return MLPPrior(cfg)
elif prior_type == "gp":
return GPPrior(cfg)
elif prior_type == "classification_adapter":
if base_prior_type is None:
base_prior_type = "mlp" # default to MLP
# build the base regression prior
base_prior = build_ticl_prior(base_prior_type)
return ClassificationAdapterPrior(base_prior, **cfg)
elif prior_type == "boolean_conjunctions":
return BooleanConjunctionPrior(hyperparameters=cfg)
elif prior_type == "step_function":
return StepFunctionPrior(cfg)
else:
raise ValueError(f"Unsupported TICL prior type: {prior_type}")

Expand Down
456 changes: 432 additions & 24 deletions visualization_demo.ipynb

Large diffs are not rendered by default.