Skip to content

Commit cc80645

Browse files
authored
CLI related fixes and improvements (#85)
* CLI - show image after generation w/ PIL * CLI rename fine-tune -> fine-tuning for consistency with API * Other CLI updates & fixes
1 parent c9a524b commit cc80645

File tree

9 files changed

+45
-37
lines changed

9 files changed

+45
-37
lines changed

src/together/abstract/api_requestor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,9 @@ def handle_error_response(
338338
try:
339339
assert isinstance(resp.data, dict)
340340
error_resp = resp.data.get("error")
341-
assert isinstance(error_resp, dict)
341+
assert isinstance(
342+
error_resp, dict
343+
), f"Unexpected error response {error_resp}"
342344
error_data = TogetherErrorResponse(**(error_resp))
343345
except (KeyError, TypeError):
344346
raise error.JSONError(

src/together/cli/api/chat.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,11 @@ def do_say(self, arg: str) -> None:
8080

8181
token = chunk.choices[0].delta.content
8282

83-
print(token, end="", flush=True)
83+
click.echo(token, nl=False)
8484

8585
output += token
8686

87-
print("\n")
87+
click.echo("\n")
8888

8989
self.messages.append({"role": "assistant", "content": output})
9090

@@ -239,7 +239,7 @@ def chat(
239239
for i, choice in enumerate(response.choices):
240240
if should_print_header:
241241
click.echo(f"===== Completion {i} =====")
242-
click.echo(json.dumps(choice.message, indent=4))
242+
click.echo(choice.message.content) # type: ignore
243243

244244
if should_print_header:
245245
click.echo("\n")

src/together/cli/api/completions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ def completions(
7878
# assertions for type checking
7979
assert isinstance(stream_choice, CompletionChoicesChunk)
8080
assert stream_choice.delta
81-
assert stream_choice.delta.content
8281

8382
if should_print_header:
8483
click.echo(f"\n===== Completion {stream_choice.index} =====\n")

src/together/cli/api/finetune.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@
88
from together.utils import finetune_price_to_dollars, parse_timestamp
99

1010

11-
@click.group(name="fine-tune")
11+
@click.group(name="fine-tuning")
1212
@click.pass_context
13-
def fine_tune(ctx: click.Context) -> None:
13+
def fine_tuning(ctx: click.Context) -> None:
1414
"""Fine-tunes API commands"""
1515
pass
1616

1717

18-
@fine_tune.command()
18+
@fine_tuning.command()
1919
@click.pass_context
2020
@click.option(
21-
"--training_file", type=str, required=True, help="Training file ID from Files API"
21+
"--training-file", type=str, required=True, help="Training file ID from Files API"
2222
)
2323
@click.option("--model", type=str, required=True, help="Base model name")
2424
@click.option("--n-epochs", type=int, default=1, help="Number of epochs to train for")
@@ -30,7 +30,7 @@ def fine_tune(ctx: click.Context) -> None:
3030
@click.option(
3131
"--suffix", type=str, default=None, help="Suffix for the fine-tuned model name"
3232
)
33-
@click.option("--wandb-api-key", prompt=True, hide_input=True, help="Wandb API key")
33+
@click.option("--wandb-api-key", type=str, default=None, help="Wandb API key")
3434
def create(
3535
ctx: click.Context,
3636
training_file: str,
@@ -59,10 +59,10 @@ def create(
5959
click.echo(json.dumps(response.model_dump(), indent=4))
6060

6161

62-
@fine_tune.command()
62+
@fine_tuning.command()
6363
@click.pass_context
6464
def list(ctx: click.Context) -> None:
65-
"""List fine-tuning tasks"""
65+
"""List fine-tuning jobs"""
6666
client: Together = ctx.obj
6767

6868
response = client.fine_tuning.list()
@@ -89,38 +89,34 @@ def list(ctx: click.Context) -> None:
8989
click.echo(table)
9090

9191

92-
@fine_tune.command()
92+
@fine_tuning.command()
9393
@click.pass_context
9494
@click.argument("fine_tune_id", type=str, required=True)
9595
def retrieve(ctx: click.Context, fine_tune_id: str) -> None:
96-
"""Retrieve fine-tuning task"""
96+
"""Retrieve fine-tuning job details"""
9797
client: Together = ctx.obj
9898

9999
response = client.fine_tuning.retrieve(fine_tune_id)
100100

101-
table_data = [
102-
{"Key": key, "Value": value}
103-
for key, value in response.model_dump().items()
104-
if key not in ["events"]
105-
]
106-
table = tabulate(table_data, tablefmt="grid")
101+
# remove events from response for cleaner output
102+
response.events = None
107103

108-
click.echo(table)
104+
click.echo(json.dumps(response.model_dump(), indent=4))
109105

110106

111-
@fine_tune.command()
107+
@fine_tuning.command()
112108
@click.pass_context
113109
@click.argument("fine_tune_id", type=str, required=True)
114110
def cancel(ctx: click.Context, fine_tune_id: str) -> None:
115-
"""Cancel fine-tuning task"""
111+
"""Cancel fine-tuning job"""
116112
client: Together = ctx.obj
117113

118114
response = client.fine_tuning.cancel(fine_tune_id)
119115

120116
click.echo(json.dumps(response.model_dump(), indent=4))
121117

122118

123-
@fine_tune.command()
119+
@fine_tuning.command()
124120
@click.pass_context
125121
@click.argument("fine_tune_id", type=str, required=True)
126122
def list_events(ctx: click.Context, fine_tune_id: str) -> None:
@@ -135,7 +131,7 @@ def list_events(ctx: click.Context, fine_tune_id: str) -> None:
135131
for i in response.data:
136132
display_list.append(
137133
{
138-
"Message": i.message,
134+
"Message": "\n".join(wrap(i.message or "", width=50)),
139135
"Type": i.type,
140136
"Created At": parse_timestamp(i.created_at or ""),
141137
"Hash": i.hash,
@@ -146,14 +142,14 @@ def list_events(ctx: click.Context, fine_tune_id: str) -> None:
146142
click.echo(table)
147143

148144

149-
@fine_tune.command()
145+
@fine_tuning.command()
150146
@click.pass_context
151147
@click.argument("fine_tune_id", type=str, required=True)
152148
@click.option(
153149
"--output_dir",
154150
type=click.Path(exists=True, file_okay=False, resolve_path=True),
155151
required=False,
156-
default=".",
152+
default=None,
157153
help="Output directory",
158154
)
159155
@click.option(

src/together/cli/api/images.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pathlib
33

44
import click
5+
from PIL import Image
56

67
from together import Together
78
from together.types import ImageResponse
@@ -33,6 +34,7 @@ def images(ctx: click.Context) -> None:
3334
help="Output directory",
3435
)
3536
@click.option("--prefix", type=str, required=False, default="image-")
37+
@click.option("--no-show", is_flag=True, help="Do not open images in viewer")
3638
def generate(
3739
ctx: click.Context,
3840
prompt: str,
@@ -45,6 +47,7 @@ def generate(
4547
negative_prompt: str,
4648
output: pathlib.Path,
4749
prefix: str,
50+
no_show: bool,
4851
) -> None:
4952
"""Generate image"""
5053

@@ -63,8 +66,17 @@ def generate(
6366

6467
assert isinstance(response, ImageResponse)
6568
assert isinstance(response.data, list)
66-
for choice in response.data:
69+
70+
for i, choice in enumerate(response.data):
6771
assert isinstance(choice, ImageChoicesData)
6872

6973
with open(f"{output}/{prefix}{choice.index}.png", "wb") as f:
7074
f.write(base64.b64decode(choice.b64_json))
75+
76+
click.echo(
77+
f"Image [{i + 1}/{len(response.data)}] saved to {output}/{prefix}{choice.index}.png"
78+
)
79+
80+
if not no_show:
81+
image = Image.open(f"{output}/{prefix}{choice.index}.png")
82+
image.show()

src/together/cli/cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from together.cli.api.chat import chat, interactive
1010
from together.cli.api.completions import completions
1111
from together.cli.api.files import files
12-
from together.cli.api.finetune import fine_tune
12+
from together.cli.api.finetune import fine_tuning
1313
from together.cli.api.images import images
1414
from together.cli.api.models import models
1515
from together.constants import MAX_RETRIES, TIMEOUT_SECS
@@ -70,7 +70,7 @@ def main(
7070
main.add_command(completions)
7171
main.add_command(images)
7272
main.add_command(files)
73-
main.add_command(fine_tune)
73+
main.add_command(fine_tuning)
7474
main.add_command(models)
7575

7676
if __name__ == "__main__":

src/together/filemanager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,19 @@ def chmod_and_replace(src: Path, dst: Path) -> None:
3636
"""
3737

3838
# Get umask by creating a temporary file in the cache folder.
39-
tmp_file = dst.parent.parent / f"tmp_{uuid.uuid4()}"
39+
tmp_file = dst.parent / f"tmp_{uuid.uuid4()}"
4040

4141
try:
4242
tmp_file.touch()
4343

4444
cache_dir_mode = Path(tmp_file).stat().st_mode
4545

46-
os.chmod(src, stat.S_IMODE(cache_dir_mode))
46+
os.chmod(src.as_posix(), stat.S_IMODE(cache_dir_mode))
4747

4848
finally:
4949
tmp_file.unlink()
5050

51-
shutil.move(src, dst)
51+
shutil.move(src.as_posix(), dst.as_posix())
5252

5353

5454
def _get_file_size(
@@ -237,7 +237,7 @@ def download(
237237

238238
os.remove(lock_path)
239239

240-
return file_path.as_posix(), file_size
240+
return str(file_path.resolve()), file_size
241241

242242

243243
class UploadManager:

src/together/resources/finetune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def download(
219219
output = Path(output)
220220

221221
downloaded_filename, file_size = download_manager.download(
222-
url, output, normalize_key(remote_name or ""), fetch_metadata=True
222+
url, output, normalize_key(remote_name or id), fetch_metadata=True
223223
)
224224

225225
return FinetuneDownloadResult(

src/together/types/finetune.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
from enum import Enum
4-
from pathlib import Path
54
from typing import List, Literal
65

76
from pydantic import Field
@@ -202,6 +201,6 @@ class FinetuneDownloadResult(BaseModel):
202201
# checkpoint step number
203202
checkpoint_step: int | None = None
204203
# local path filename
205-
filename: Path | None = None
204+
filename: str | None = None
206205
# size in bytes
207206
size: int | None = None

0 commit comments

Comments
 (0)