
Commit 4af87c7

update
1 parent 1872783 commit 4af87c7

9 files changed: +312, -127 lines


package-lock.json

Lines changed: 16 additions & 0 deletions
(Generated file; diff not rendered.)

package.json

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@
     "ldapts": "^7.1.0",
     "looks-same": "^9.0.0",
     "odiff-bin": "^2.6.1",
+    "ollama": "^0.6.3",
     "passport": "^0.6.0",
     "passport-jwt": "^4.0.1",
     "passport-local": "^1.0.0",

src/compare/libs/vlm/ollama.controller.ts

Lines changed: 7 additions & 2 deletions
@@ -50,9 +50,14 @@ export class OllamaController {

     return this.ollamaService.generate({
       model,
-      prompt,
+      messages: [
+        {
+          role: 'user',
+          content: prompt,
+          images: files.map((f) => new Uint8Array(f.buffer)),
+        },
+      ],
       format: 'json',
-      images: files.map((f) => f.buffer.toString('base64')),
       options: { temperature: Number(temperature) },
     });
   }
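The controller change tracks the service's move from Ollama's legacy generate shape (top-level `prompt` plus base64-encoded `images`) to the chat shape, where the prompt and images travel inside a user message. It also switches to raw bytes: the SDK base64-encodes `Uint8Array` images itself. A minimal sketch of the payload being assembled, assuming Multer memory-storage uploads; `UploadedImage` and `buildChatRequest` are illustrative names, not part of this commit:

import { ChatRequest } from 'ollama';

// Stand-in for Express.Multer.File with memory storage (exposes a Node Buffer).
interface UploadedImage {
  buffer: Buffer;
}

function buildChatRequest(
  model: string,
  prompt: string,
  files: UploadedImage[],
  temperature: number,
): ChatRequest {
  return {
    model,
    messages: [
      {
        role: 'user',
        content: prompt,
        // The SDK accepts raw bytes directly; the manual base64 encoding from
        // the fetch-based version is no longer needed.
        images: files.map((f) => new Uint8Array(f.buffer)),
      },
    ],
    format: 'json',
    options: { temperature },
  };
}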

src/compare/libs/vlm/ollama.service.spec.ts

Lines changed: 159 additions & 38 deletions
@@ -2,10 +2,28 @@ import { Test, TestingModule } from '@nestjs/testing';
 import { ConfigService } from '@nestjs/config';
 import { OllamaService } from './ollama.service';

+// Mock the ollama module
+const mockChat = jest.fn();
+const mockList = jest.fn();
+
+jest.mock('ollama', () => {
+  const MockOllama = jest.fn().mockImplementation(() => ({
+    chat: mockChat,
+    list: mockList,
+  }));
+  return {
+    Ollama: MockOllama,
+  };
+});
+
+
 describe('OllamaService', () => {
   let service: OllamaService;

   beforeEach(async () => {
+    // Reset mocks
+    jest.clearAllMocks();
+
     const module: TestingModule = await Test.createTestingModule({
       providers: [
         OllamaService,
@@ -22,39 +40,106 @@ describe('OllamaService', () => {
   });

   describe('generate', () => {
-    it('should call Ollama API with correct parameters', async () => {
-      const mockResponse = { response: 'YES', done: true };
-      globalThis.fetch = jest.fn().mockResolvedValue({
-        ok: true,
-        json: () => Promise.resolve(mockResponse),
-      });
+    it('should call Ollama SDK with correct parameters for Uint8Array', async () => {
+      const mockResponse = {
+        model: 'llava',
+        created_at: new Date(),
+        message: { content: 'YES', role: 'assistant' },
+        done: true,
+        done_reason: 'stop',
+        total_duration: 1000,
+        load_duration: 100,
+        prompt_eval_count: 10,
+        prompt_eval_duration: 200,
+        eval_count: 5,
+        eval_duration: 300,
+      };
+      mockChat.mockResolvedValue(mockResponse);

+      const testBytes = new Uint8Array([1, 2, 3, 4]);
       const result = await service.generate({
         model: 'llava',
-        prompt: 'Test prompt',
-        images: ['base64img'],
+        messages: [
+          {
+            role: 'user',
+            content: 'Test prompt',
+            images: [testBytes],
+          },
+        ],
       });

-      expect(fetch).toHaveBeenCalledWith(
-        'http://localhost:11434/api/generate',
-        expect.objectContaining({
-          method: 'POST',
-          headers: { 'Content-Type': 'application/json' },
-        })
-      );
-      expect(result).toEqual(mockResponse);
+      expect(mockChat).toHaveBeenCalledWith({
+        model: 'llava',
+        messages: [
+          {
+            role: 'user',
+            content: 'Test prompt',
+            images: [testBytes],
+          },
+        ],
+        stream: false,
+        format: undefined,
+        options: undefined,
+      });
+      expect(result.message.content).toBe('YES');
+      expect(result.done).toBe(true);
     });

-    it('should throw error when API returns non-ok status', async () => {
-      globalThis.fetch = jest.fn().mockResolvedValue({
-        ok: false,
-        status: 500,
-        text: () => Promise.resolve('Internal Server Error'),
+    it('should call Ollama SDK with correct parameters for base64 strings', async () => {
+      const mockResponse = {
+        model: 'llava',
+        created_at: new Date(),
+        message: { content: 'YES', role: 'assistant' },
+        done: true,
+        done_reason: 'stop',
+        total_duration: 1000,
+        load_duration: 100,
+        prompt_eval_count: 10,
+        prompt_eval_duration: 200,
+        eval_count: 5,
+        eval_duration: 300,
+      };
+      mockChat.mockResolvedValue(mockResponse);
+
+      // Use a longer base64 string
+      const longBase64 = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==';
+      const result = await service.generate({
+        model: 'llava',
+        messages: [
+          {
+            role: 'user',
+            content: 'Test prompt',
+            images: [longBase64], // base64 string - passed through as-is
+          },
+        ],
       });

-      await expect(service.generate({ model: 'llava', prompt: 'Test' })).rejects.toThrow(
-        'Ollama API returned status 500'
-      );
+      expect(mockChat).toHaveBeenCalledWith({
+        model: 'llava',
+        messages: [
+          {
+            role: 'user',
+            content: 'Test prompt',
+            images: [longBase64],
+          },
+        ],
+        stream: false,
+        format: undefined,
+        options: undefined,
+      });
+      expect(result.message.content).toBe('YES');
+      expect(result.done).toBe(true);
+    });
+
+    it('should throw error when SDK call fails', async () => {
+      mockChat.mockRejectedValue(new Error('Connection refused'));
+
+      await expect(
+        service.generate({
+          model: 'llava',
+          messages: [{ role: 'user', content: 'Test' }],
+        })
+      ).rejects.toThrow('Connection refused');
     });

     it('should throw error when OLLAMA_BASE_URL is not configured', async () => {
@@ -65,32 +150,68 @@ describe('OllamaService', () => {
       } as any;
       const newService = new OllamaService(mockConfigService);

-      await expect(newService.generate({ model: 'llava', prompt: 'Test' })).rejects.toThrow('OLLAMA_BASE_URL');
+      await expect(
+        newService.generate({
+          model: 'llava',
+          messages: [{ role: 'user', content: 'Test' }],
+        })
+      ).rejects.toThrow('OLLAMA_BASE_URL');
     });
   });

   describe('listModels', () => {
     it('should return list of models', async () => {
-      const mockModels = { models: [{ name: 'llava:7b' }, { name: 'moondream' }] };
-      globalThis.fetch = jest.fn().mockResolvedValue({
-        ok: true,
-        json: () => Promise.resolve(mockModels),
-      });
+      const mockDate = new Date('2024-01-01');
+      const mockResponse = {
+        models: [
+          {
+            name: 'llava:7b',
+            model: 'llava:7b',
+            size: 1000,
+            digest: 'abc123',
+            modified_at: mockDate,
+            expires_at: mockDate,
+            size_vram: 500,
+            details: {
+              parent_model: '',
+              format: 'gguf',
+              family: 'llama',
+              families: ['llama'],
+              parameter_size: '7B',
+              quantization_level: 'Q4_0',
+            },
+          },
+          {
+            name: 'moondream',
+            model: 'moondream',
+            size: 2000,
+            digest: 'def456',
+            modified_at: mockDate,
+            expires_at: mockDate,
+            size_vram: 1000,
+            details: {
+              parent_model: '',
+              format: 'gguf',
+              family: 'moondream',
+              families: ['moondream'],
+              parameter_size: '1.6B',
+              quantization_level: 'Q4_0',
+            },
+          },
+        ],
+      };
+      mockList.mockResolvedValue(mockResponse);

       const result = await service.listModels();

-      expect(fetch).toHaveBeenCalledWith('http://localhost:11434/api/tags');
-      expect(result).toEqual(mockModels.models);
+      expect(mockList).toHaveBeenCalled();
+      expect(result).toEqual(mockResponse.models);
     });

     it('should throw error when API fails', async () => {
-      globalThis.fetch = jest.fn().mockResolvedValue({
-        ok: false,
-        status: 503,
-        text: () => Promise.resolve('Service Unavailable'),
-      });
+      mockList.mockRejectedValue(new Error('Service Unavailable'));

-      await expect(service.listModels()).rejects.toThrow('Failed to list models');
+      await expect(service.listModels()).rejects.toThrow('Service Unavailable');
     });
   });
 });
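A note on the mocking pattern above: `jest.mock()` factory calls are hoisted above the module's imports, and Jest only allows a hoisted factory to reference out-of-scope variables whose names begin with `mock`, which is why the stubs are named `mockChat` and `mockList`. A stripped-down, self-contained sketch of the same pattern (illustrative, not part of the diff):

import { Ollama } from 'ollama';

const mockChat = jest.fn();

jest.mock('ollama', () => ({
  // The arrow defers the mockChat reference until `new Ollama()` is actually
  // called, avoiding a temporal-dead-zone error while modules are loading.
  Ollama: jest.fn().mockImplementation(() => ({ chat: mockChat })),
}));

it('returns the stubbed reply', async () => {
  mockChat.mockResolvedValue({ message: { role: 'assistant', content: 'YES' } });

  const client = new Ollama({ host: 'http://localhost:11434' });
  const res = await client.chat({
    model: 'llava',
    messages: [{ role: 'user', content: 'hi' }],
  });

  expect(res.message.content).toBe('YES');
});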

src/compare/libs/vlm/ollama.service.ts

Lines changed: 18 additions & 29 deletions
@@ -1,54 +1,43 @@
 import { Injectable, Logger } from '@nestjs/common';
 import { ConfigService } from '@nestjs/config';
-import { OllamaGenerateRequest, OllamaGenerateResponse, OllamaModel, OllamaModelsResponse } from './ollama.types';
+import { Ollama, ChatRequest, ChatResponse, ListResponse, ModelResponse } from 'ollama';

 @Injectable()
 export class OllamaService {
   private readonly logger: Logger = new Logger(OllamaService.name);
-  private baseUrl: string | null = null;
+  private ollamaClient: Ollama | null = null;

   constructor(private readonly configService: ConfigService) {}

-  private getBaseUrl(): string {
-    if (!this.baseUrl) {
-      this.baseUrl = this.configService.getOrThrow<string>('OLLAMA_BASE_URL');
+  private getOllamaClient(): Ollama {
+    if (!this.ollamaClient) {
+      const baseUrl = this.configService.getOrThrow<string>('OLLAMA_BASE_URL');
+      this.ollamaClient = new Ollama({ host: baseUrl });
     }
-    return this.baseUrl;
+    return this.ollamaClient;
   }

-  async generate(request: OllamaGenerateRequest): Promise<OllamaGenerateResponse> {
-    const baseUrl = this.getBaseUrl();
+  async generate(request: ChatRequest): Promise<ChatResponse> {
+    const client = this.getOllamaClient();
+
     try {
-      const response = await fetch(`${baseUrl}/api/generate`, {
-        method: 'POST',
-        headers: { 'Content-Type': 'application/json' },
-        body: JSON.stringify({ ...request, stream: request.stream ?? false }),
+      const response = await client.chat({
+        ...request,
+        stream: false,
       });

-      if (!response.ok) {
-        const errorText = await response.text();
-        throw new Error(`Ollama API returned status ${response.status}: ${errorText}`);
-      }
-
-      return await response.json();
+      return response;
     } catch (error) {
       this.logger.error(`Ollama generate request failed: ${error.message}`);
       throw error;
     }
   }

-  async listModels(): Promise<OllamaModel[]> {
-    const baseUrl = this.getBaseUrl();
+  async listModels(): Promise<ModelResponse[]> {
+    const client = this.getOllamaClient();
     try {
-      const response = await fetch(`${baseUrl}/api/tags`);
-
-      if (!response.ok) {
-        const errorText = await response.text();
-        throw new Error(`Failed to list models: ${response.status} ${errorText}`);
-      }
-
-      const data: OllamaModelsResponse = await response.json();
-      return data.models;
+      const response: ListResponse = await client.list();
+      return response.models;
     } catch (error) {
       this.logger.error(`Failed to list models: ${error.message}`);
       throw error;