Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"use client";

import { BranchPickerPrimitive, useMessage } from "@assistant-ui/react";
import { BranchPickerPrimitive, useAuiState } from "@assistant-ui/react";
import { ChevronLeftIcon, ChevronRightIcon } from "@radix-ui/react-icons";
import {
DEFAULT_BUTTON_CLASSNAME,
Expand Down Expand Up @@ -42,11 +42,8 @@ export function BranchNavigation({
hideWhenSingleBranch?: boolean;
renderCounter?: (current: number, total: number) => React.ReactNode;
} = {}) {
// TODO: migrate to store — useMessage to useAuiState(({ message }) => message)
const message = useMessage();

const branchCount = message.branchCount ?? 1;
const branchNumber = message.branchNumber ?? 1;
const branchCount = useAuiState((s) => s.message.branchCount ?? 1);
const branchNumber = useAuiState((s) => s.message.branchNumber ?? 1);

// Hide when single branch (if enabled)
if (hideWhenSingleBranch && branchCount <= 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { describe, it, expect, vi, beforeEach } from "vitest";
import { render, screen } from "@testing-library/react";

vi.mock("@assistant-ui/react", () => ({
useMessage: vi.fn(),
useAuiState: vi.fn(),
BranchPickerPrimitive: {
Root: ({ children, className }: any) => (
<div data-testid="branch-root" className={className}>
Expand All @@ -15,48 +15,58 @@ vi.mock("@assistant-ui/react", () => ({
}));

import { BranchNavigation } from "../BranchNavigation";
import { useMessage } from "@assistant-ui/react";
import { useAuiState } from "@assistant-ui/react";

const mockMessage = useMessage as ReturnType<typeof vi.fn>;
const mockUseAuiState = useAuiState as ReturnType<typeof vi.fn>;

function setMockState(message: {
branchCount?: number;
branchNumber?: number;
}) {
const state = { message };
mockUseAuiState.mockImplementation((selector: (s: any) => any) =>
selector(state),
);
}

describe("BranchNavigation", () => {
beforeEach(() => {
vi.clearAllMocks();
});

it("returns null when single branch and hideWhenSingleBranch is true", () => {
mockMessage.mockReturnValue({ branchCount: 1, branchNumber: 1 });
setMockState({ branchCount: 1, branchNumber: 1 });
const { container } = render(<BranchNavigation />);
expect(container.innerHTML).toBe("");
});

it("renders when multiple branches exist", () => {
mockMessage.mockReturnValue({ branchCount: 3, branchNumber: 2 });
setMockState({ branchCount: 3, branchNumber: 2 });
render(<BranchNavigation />);
expect(screen.getByTestId("branch-root")).toBeInTheDocument();
});

it("shows counter with branch number and count", () => {
mockMessage.mockReturnValue({ branchCount: 3, branchNumber: 2 });
setMockState({ branchCount: 3, branchNumber: 2 });
render(<BranchNavigation />);
expect(screen.getByText("2 / 3")).toBeInTheDocument();
});

it("renders when single branch and hideWhenSingleBranch is false", () => {
mockMessage.mockReturnValue({ branchCount: 1, branchNumber: 1 });
setMockState({ branchCount: 1, branchNumber: 1 });
render(<BranchNavigation hideWhenSingleBranch={false} />);
expect(screen.getByTestId("branch-root")).toBeInTheDocument();
});

it("renders prev and next buttons", () => {
mockMessage.mockReturnValue({ branchCount: 3, branchNumber: 2 });
setMockState({ branchCount: 3, branchNumber: 2 });
render(<BranchNavigation />);
expect(screen.getByLabelText("Previous branch")).toBeInTheDocument();
expect(screen.getByLabelText("Next branch")).toBeInTheDocument();
});

it("uses custom renderCounter", () => {
mockMessage.mockReturnValue({ branchCount: 5, branchNumber: 3 });
setMockState({ branchCount: 5, branchNumber: 3 });
render(
<BranchNavigation
renderCounter={(current, total) => (
Expand All @@ -70,13 +80,13 @@ describe("BranchNavigation", () => {
});

it("uses custom className", () => {
mockMessage.mockReturnValue({ branchCount: 2, branchNumber: 1 });
setMockState({ branchCount: 2, branchNumber: 1 });
render(<BranchNavigation className="custom-root" />);
expect(screen.getByTestId("branch-root").className).toBe("custom-root");
});

it("defaults branchCount and branchNumber to 1 when undefined", () => {
mockMessage.mockReturnValue({});
setMockState({});
const { container } = render(<BranchNavigation />);
expect(container.innerHTML).toBe("");
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
ComposerActionState,
deriveComposerActionState,
} from "./deriveComposerActionState";
import { ComposerPrimitive, useComposer, useThread } from "@assistant-ui/react";
import { ComposerPrimitive, useAuiState } from "@assistant-ui/react";

export function ComposerActionStatus({
/**
Expand All @@ -36,16 +36,13 @@ export function ComposerActionStatus({
idleButtonClassName?: string;
renderVisual?: (state: ComposerActionState) => React.ReactNode;
}) {
// TODO: migrate to store — useThread to useAuiState(({ thread }) => thread)
const thread = useThread();
// TODO: migrate to store — useComposer to useAuiState(({ composer }) => composer)
const composer = useComposer();

const state = deriveComposerActionState({
isRunning: thread.isRunning,
isEditing: composer.isEditing,
isEmpty: composer.isEmpty,
});
const state = useAuiState((s) =>
deriveComposerActionState({
isRunning: s.thread.isRunning,
isEditing: s.composer.isEditing,
isEmpty: s.composer.isEmpty,
})
);

const visual = renderVisual ? (
renderVisual(state)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ import { describe, it, expect, vi, beforeEach } from "vitest";
import { render, screen } from "@testing-library/react";

vi.mock("@assistant-ui/react", () => ({
useThread: vi.fn(),
useComposer: vi.fn(),
useAuiState: vi.fn(),
ComposerPrimitive: {
Send: ({ children, className }: any) => (
<button data-testid="send" className={className}>
Expand All @@ -19,27 +18,29 @@ vi.mock("@assistant-ui/react", () => ({
}));

import { ComposerActionStatus } from "../ComposerActionStatus";
import { useThread, useComposer } from "@assistant-ui/react";
import { useAuiState } from "@assistant-ui/react";
import {
DEFAULT_BUTTON_CLASSNAME,
DEFAULT_IDLE_BUTTON_CLASSNAME,
} from "../defaults";

const mockThread = useThread as ReturnType<typeof vi.fn>;
const mockComposer = useComposer as ReturnType<typeof vi.fn>;
const mockUseAuiState = useAuiState as ReturnType<typeof vi.fn>;

function setMockState(overrides: {
isRunning?: boolean;
isEditing?: boolean;
isEmpty?: boolean;
}) {
mockThread.mockReturnValue({
isRunning: overrides.isRunning ?? false,
});
mockComposer.mockReturnValue({
isEditing: overrides.isEditing ?? false,
isEmpty: overrides.isEmpty ?? true,
});
const state = {
thread: { isRunning: overrides.isRunning ?? false },
composer: {
isEditing: overrides.isEditing ?? false,
isEmpty: overrides.isEmpty ?? true,
},
};
mockUseAuiState.mockImplementation((selector: (s: any) => any) =>
selector(state),
);
}

describe("ComposerActionStatus", () => {
Expand Down
10 changes: 5 additions & 5 deletions packages/core/src/primitives/edit-composer/EditComposer.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"use client";

import { ComposerPrimitive, useComposer } from "@assistant-ui/react";
import { ComposerPrimitive, useAuiState } from "@assistant-ui/react";
import {
DEFAULT_ROOT_CLASSNAME,
DEFAULT_INPUT_CLASSNAME,
Expand Down Expand Up @@ -46,8 +46,8 @@ export function EditComposer({
cancelLabel?: string;
saveLabel?: string;
} = {}) {
// TODO: migrate to store — useComposer to useAuiState(({ composer }) => composer)
const composer = useComposer();
const canCancel = useAuiState((s) => s.composer.canCancel);
const isEmpty = useAuiState((s) => s.composer.isEmpty);

return (
<ComposerPrimitive.Root className={className ?? DEFAULT_ROOT_CLASSNAME}>
Expand All @@ -58,13 +58,13 @@ export function EditComposer({
<div className={actionsClassName ?? DEFAULT_ACTIONS_CLASSNAME}>
<ComposerPrimitive.Cancel
className={buttonClassName ?? DEFAULT_BUTTON_CLASSNAME}
disabled={!composer.canCancel}
disabled={!canCancel}
>
{cancelLabel}
</ComposerPrimitive.Cancel>
<ComposerPrimitive.Send
className={buttonClassName ?? DEFAULT_BUTTON_CLASSNAME}
disabled={composer.isEmpty}
disabled={isEmpty}
>
{saveLabel}
</ComposerPrimitive.Send>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { describe, it, expect, vi, beforeEach } from "vitest";
import { render, screen } from "@testing-library/react";

vi.mock("@assistant-ui/react", () => ({
useComposer: vi.fn(),
useAuiState: vi.fn(),
ComposerPrimitive: {
Root: ({ children, className }: any) => (
<div data-testid="composer-root" className={className}>
Expand Down Expand Up @@ -30,14 +30,26 @@ vi.mock("@assistant-ui/react", () => ({
}));

import { EditComposer } from "../EditComposer";
import { useComposer } from "@assistant-ui/react";
import { useAuiState } from "@assistant-ui/react";

const mockComposer = useComposer as ReturnType<typeof vi.fn>;
const mockUseAuiState = useAuiState as ReturnType<typeof vi.fn>;

function setMockState(overrides: { canCancel?: boolean; isEmpty?: boolean }) {
const state = {
composer: {
canCancel: overrides.canCancel ?? true,
isEmpty: overrides.isEmpty ?? false,
},
};
mockUseAuiState.mockImplementation((selector: (s: any) => any) =>
selector(state),
);
}

describe("EditComposer", () => {
beforeEach(() => {
vi.clearAllMocks();
mockComposer.mockReturnValue({ canCancel: true, isEmpty: false });
setMockState({ canCancel: true, isEmpty: false });
});

it("renders input with default placeholder", () => {
Expand Down Expand Up @@ -69,13 +81,13 @@ describe("EditComposer", () => {
});

it("disables cancel when canCancel is false", () => {
mockComposer.mockReturnValue({ canCancel: false, isEmpty: false });
setMockState({ canCancel: false, isEmpty: false });
render(<EditComposer />);
expect(screen.getByTestId("cancel")).toBeDisabled();
});

it("disables send when isEmpty is true", () => {
mockComposer.mockReturnValue({ canCancel: true, isEmpty: true });
setMockState({ canCancel: true, isEmpty: true });
render(<EditComposer />);
expect(screen.getByTestId("send")).toBeDisabled();
});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"use client";

import { ThreadPrimitive, useThread } from "@assistant-ui/react";
import { ThreadPrimitive, useAuiState } from "@assistant-ui/react";
import {
DEFAULT_CONTAINER_CLASSNAME,
DEFAULT_CHIP_CLASSNAME,
Expand Down Expand Up @@ -31,12 +31,10 @@ export function FollowUpSuggestions({
autoSend?: boolean;
renderChip?: (prompt: string, index: number) => React.ReactNode;
} = {}) {
// TODO: migrate to store — useThread to useAuiState(({ thread }) => thread)
const thread = useThread();
const suggestions = useAuiState((s) => s.thread.suggestions);
const isRunning = useAuiState((s) => s.thread.isRunning);

const suggestions = thread.suggestions;

if (!suggestions?.length || thread.isRunning) return null;
if (!suggestions?.length || isRunning) return null;

return (
<div className={className ?? DEFAULT_CONTAINER_CLASSNAME}>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { describe, it, expect, vi, beforeEach } from "vitest";
import { render, screen } from "@testing-library/react";

vi.mock("@assistant-ui/react", () => ({
useThread: vi.fn(),
useAuiState: vi.fn(),
ThreadPrimitive: {
Suggestion: ({ children, prompt, className }: any) => (
<button data-testid={`suggestion-${prompt}`} className={className}>
Expand All @@ -13,29 +13,39 @@ vi.mock("@assistant-ui/react", () => ({
}));

import { FollowUpSuggestions } from "../FollowUpSuggestions";
import { useThread } from "@assistant-ui/react";
import { useAuiState } from "@assistant-ui/react";

const mockThread = useThread as ReturnType<typeof vi.fn>;
const mockUseAuiState = useAuiState as ReturnType<typeof vi.fn>;

function setMockState(thread: {
suggestions?: any[];
isRunning?: boolean;
}) {
const state = { thread };
mockUseAuiState.mockImplementation((selector: (s: any) => any) =>
selector(state),
);
}

describe("FollowUpSuggestions", () => {
beforeEach(() => {
vi.clearAllMocks();
});

it("returns null when no suggestions", () => {
mockThread.mockReturnValue({ suggestions: [], isRunning: false });
setMockState({ suggestions: [], isRunning: false });
const { container } = render(<FollowUpSuggestions />);
expect(container.innerHTML).toBe("");
});

it("returns null when suggestions is undefined", () => {
mockThread.mockReturnValue({ isRunning: false });
setMockState({ isRunning: false });
const { container } = render(<FollowUpSuggestions />);
expect(container.innerHTML).toBe("");
});

it("returns null when thread is running", () => {
mockThread.mockReturnValue({
setMockState({
suggestions: [{ prompt: "Hello" }],
isRunning: true,
});
Expand All @@ -44,7 +54,7 @@ describe("FollowUpSuggestions", () => {
});

it("renders suggestions when available and not running", () => {
mockThread.mockReturnValue({
setMockState({
suggestions: [{ prompt: "Tell me more" }, { prompt: "Give an example" }],
isRunning: false,
});
Expand All @@ -54,7 +64,7 @@ describe("FollowUpSuggestions", () => {
});

it("uses custom className", () => {
mockThread.mockReturnValue({
setMockState({
suggestions: [{ prompt: "Test" }],
isRunning: false,
});
Expand All @@ -65,7 +75,7 @@ describe("FollowUpSuggestions", () => {
});

it("uses custom renderChip", () => {
mockThread.mockReturnValue({
setMockState({
suggestions: [{ prompt: "Test prompt" }],
isRunning: false,
});
Expand Down
Loading