Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clarifai/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,7 +1211,7 @@ def local_runner(ctx, model_path, pool_size, suppress_toolkit_logs, mode, keep_i
model_path = os.path.abspath(model_path)
_ensure_hf_token(ctx, model_path)
builder = ModelBuilder(model_path, download_validation_only=True)
manager = ModelRunLocally(model_path)
manager = ModelRunLocally(model_path, model_builder=builder)

port = 8080
if mode == "env":
Expand Down Expand Up @@ -1674,7 +1674,7 @@ def print_code_snippet():
else:
print_code_snippet()
# This reads the config.yaml from the model_path so we alter it above first.
server = ModelServer(model_path=model_path, model_runner_local=None)
server = ModelServer(model_path=model_path, model_runner_local=None, model_builder=builder)
server.serve(**serving_args)


Expand Down
24 changes: 17 additions & 7 deletions clarifai/runners/models/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def __init__(
platform: Optional[str] = None,
pat: Optional[str] = None,
base_url: Optional[str] = None,
compute_info_required: bool = False,
):
"""
:param folder: The folder containing the model.py, config.yaml, requirements.txt and
Expand All @@ -194,6 +195,7 @@ def __init__(
:param platform: Target platform(s) for Docker image build (e.g., "linux/amd64" or "linux/amd64,linux/arm64"). This overrides the platform specified in config.yaml.
:param pat: Personal access token for authentication. If None, will use environment variables.
:param base_url: Base URL for the API. If None, will use environment variables.
:param compute_info_required: Whether inference compute info is required. This affects certain validation and behavior.
"""
assert app_not_found_action in ["auto_create", "prompt", "error"], ValueError(
f"Expected one of {['auto_create', 'prompt', 'error']}, got {app_not_found_action=}"
Expand All @@ -214,7 +216,9 @@ def __init__(
self.model_proto = self._get_model_proto()
self.model_id = self.model_proto.id
self.model_version_id = None
self.inference_compute_info = self._get_inference_compute_info()
self.inference_compute_info = self._get_inference_compute_info(
compute_info_required=compute_info_required
)
self.is_v3 = True # Do model build for v3

def create_model_instance(self, load_model=True, mocking=False) -> ModelClass:
Expand Down Expand Up @@ -944,11 +948,12 @@ def _get_model_proto(self):

return model_proto

def _get_inference_compute_info(self):
assert "inference_compute_info" in self.config, (
"inference_compute_info not found in the config file"
)
inference_compute_info = self.config.get('inference_compute_info')
def _get_inference_compute_info(self, compute_info_required=False):
if compute_info_required:
assert "inference_compute_info" in self.config, (
"inference_compute_info not found in the config file"
)
inference_compute_info = self.config.get('inference_compute_info') or {}
# Ensure cpu_limit is a string if it exists and is an int
if 'cpu_limit' in inference_compute_info and isinstance(
inference_compute_info['cpu_limit'], int
Expand Down Expand Up @@ -1945,7 +1950,12 @@ def upload_model(
:param base_url: Base URL for the API. If None, will use environment variables.
"""
builder = ModelBuilder(
folder, app_not_found_action="prompt", platform=platform, pat=pat, base_url=base_url
folder,
app_not_found_action="prompt",
platform=platform,
pat=pat,
base_url=base_url,
compute_info_required=True,
)
builder.download_checkpoints(stage=stage)

Expand Down
8 changes: 6 additions & 2 deletions clarifai/runners/models/model_run_locally.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@


class ModelRunLocally:
def __init__(self, model_path):
def __init__(self, model_path, model_builder: ModelBuilder = None):
self.model_path = os.path.abspath(model_path)
self.requirements_file = os.path.join(self.model_path, "requirements.txt")

# ModelBuilder contains multiple useful methods to interact with the model
self.builder = ModelBuilder(self.model_path, download_validation_only=True)
self.builder = (
model_builder
if model_builder
else ModelBuilder(self.model_path, download_validation_only=True)
)
self.config = self.builder.config

def _get_method_signatures(self):
Expand Down
13 changes: 11 additions & 2 deletions clarifai/runners/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,12 @@ def main():


class ModelServer:
def __init__(self, model_path, model_runner_local: ModelRunLocally = None):
def __init__(
self,
model_path,
model_runner_local: ModelRunLocally = None,
model_builder: ModelBuilder = None,
):
"""Initialize the ModelServer.
Args:
model_path: Path to the model directory
Expand All @@ -158,7 +163,11 @@ def __init__(self, model_path, model_runner_local: ModelRunLocally = None):
self._initialize_secrets_system()

# Build model after secrets are loaded
self._builder = ModelBuilder(model_path, download_validation_only=True)
self._builder = (
model_builder
if model_builder
else ModelBuilder(model_path, download_validation_only=True)
)
self._current_model = self._builder.create_model_instance()

logger.info("ModelServer initialized successfully")
Expand Down
2 changes: 1 addition & 1 deletion tests/runners/test_model_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def test_model_uploader_flow(dummy_models_path, client):
7. Delete the deployment
"""
# Initialize
builder = ModelBuilder(folder=str(dummy_models_path))
builder = ModelBuilder(folder=str(dummy_models_path), compute_info_required=True)
assert builder.folder == str(dummy_models_path), "Uploader folder mismatch"

# Basic checks on config
Expand Down
Loading