Commit fa5c985

chore(refactor): track grpcProcess in the model structure (#3663)
* chore(refactor): track grpcProcess in the model structure

  This avoids having to handle the data relative to the same model in two
  places. It makes the state easier to track and to guard with a mutex, and
  it also fixes race conditions when accessing the model.

  Signed-off-by: Ettore Di Giacinto <[email protected]>

* chore(tests): run protogen-go before starting aio tests

  Signed-off-by: Ettore Di Giacinto <[email protected]>

* chore(tests): install protoc in aio tests

  Signed-off-by: Ettore Di Giacinto <[email protected]>

---------

Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent: 3d12d20
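
The heart of the refactor, as a minimal Go sketch (stand-in types: procHandle plays the role of *process.Process from github.com/mudler/go-processmanager, and stopAndForget is a hypothetical helper, not a LocalAI API): the child process becomes a field of the model, so the loader keeps a single map under one mutex instead of two maps that had to be kept in sync by hand.

package model

import "sync"

// procHandle stands in for the go-processmanager process handle.
type procHandle struct{ PID string }

func (p *procHandle) Stop() error { return nil } // stand-in

// Model owns everything belonging to one loaded backend: its identity
// and the child process serving it. Keeping the process here removes
// the loader-level grpcProcesses map that previously had to stay
// consistent with the models map.
type Model struct {
	ID      string
	process *procHandle // nil for external backends reached by URI
}

func (m *Model) Process() *procHandle { return m.process }

// ModelLoader now has a single source of truth guarded by one mutex.
type ModelLoader struct {
	mu     sync.Mutex
	models map[string]*Model
}

// stopAndForget shows why one map is simpler: the process is reached
// through the model, so deleting the entry cannot strand a process
// record in a second structure.
func (ml *ModelLoader) stopAndForget(name string) {
	ml.mu.Lock()
	defer ml.mu.Unlock()
	if m, ok := ml.models[name]; ok {
		if p := m.Process(); p != nil {
			_ = p.Stop()
		}
		delete(ml.models, name)
	}
}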

File tree (7 files changed, +71 -44 lines):

.github/workflows/test.yml
Makefile
pkg/model/initializers.go
pkg/model/loader.go
pkg/model/loader_test.go
pkg/model/model.go
pkg/model/process.go

.github/workflows/test.yml

Lines changed: 10 additions & 1 deletion
@@ -178,13 +178,22 @@ jobs:
         uses: actions/checkout@v4
         with:
           submodules: true
+      - name: Dependencies
+        run: |
+          # Install protoc
+          curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
+          unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
+          rm protoc.zip
+          go install google.golang.org/protobuf/cmd/[email protected]
+          go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
+          PATH="$PATH:$HOME/go/bin" make protogen-go
       - name: Build images
         run: |
           docker build --build-arg FFMPEG=true --build-arg IMAGE_TYPE=extras --build-arg EXTRA_BACKENDS=rerankers --build-arg MAKEFLAGS="--jobs=5 --output-sync=target" -t local-ai:tests -f Dockerfile .
           BASE_IMAGE=local-ai:tests DOCKER_AIO_IMAGE=local-ai-aio:test make docker-aio
       - name: Test
         run: |
-          LOCALAI_MODELS_DIR=$PWD/models LOCALAI_IMAGE_TAG=test LOCALAI_IMAGE=local-ai-aio \
+          PATH="$PATH:$HOME/go/bin" LOCALAI_MODELS_DIR=$PWD/models LOCALAI_IMAGE_TAG=test LOCALAI_IMAGE=local-ai-aio \
             make run-e2e-aio
       - name: Setup tmate session if tests fail
         if: ${{ failure() }}

Makefile

Lines changed: 1 addition & 1 deletion
@@ -468,7 +468,7 @@ run-e2e-image:
 	ls -liah $(abspath ./tests/e2e-fixtures)
 	docker run -p 5390:8080 -e MODELS_PATH=/models -e THREADS=1 -e DEBUG=true -d --rm -v $(TEST_DIR):/models --gpus all --name e2e-tests-$(RANDOM) localai-tests
 
-run-e2e-aio:
+run-e2e-aio: protogen-go
 	@echo 'Running e2e AIO tests'
 	$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 -v -r ./tests/e2e-aio

pkg/model/initializers.go

Lines changed: 10 additions & 5 deletions
@@ -304,18 +304,19 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
 					return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
 				}
 				// Make sure the process is executable
-				if err := ml.startProcess(uri, o.model, serverAddress); err != nil {
+				process, err := ml.startProcess(uri, o.model, serverAddress)
+				if err != nil {
 					log.Error().Err(err).Str("path", uri).Msg("failed to launch ")
 					return nil, err
 				}
 
 				log.Debug().Msgf("GRPC Service Started")
 
-				client = NewModel(modelName, serverAddress)
+				client = NewModel(modelName, serverAddress, process)
 			} else {
 				log.Debug().Msg("external backend is uri")
 				// address
-				client = NewModel(modelName, uri)
+				client = NewModel(modelName, uri, nil)
 			}
 		} else {
 			grpcProcess := backendPath(o.assetDir, backend)
@@ -346,13 +347,14 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
 			args, grpcProcess = library.LoadLDSO(o.assetDir, args, grpcProcess)
 
 			// Make sure the process is executable in any circumstance
-			if err := ml.startProcess(grpcProcess, o.model, serverAddress, args...); err != nil {
+			process, err := ml.startProcess(grpcProcess, o.model, serverAddress, args...)
+			if err != nil {
 				return nil, err
 			}
 
 			log.Debug().Msgf("GRPC Service Started")
 
-			client = NewModel(modelName, serverAddress)
+			client = NewModel(modelName, serverAddress, process)
 		}
 
 		log.Debug().Msgf("Wait for the service to start up")
@@ -374,6 +376,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
 
 		if !ready {
 			log.Debug().Msgf("GRPC Service NOT ready")
+			ml.deleteProcess(o.model)
 			return nil, fmt.Errorf("grpc service not ready")
 		}
 
@@ -385,9 +388,11 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
 
 		res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options)
 		if err != nil {
+			ml.deleteProcess(o.model)
 			return nil, fmt.Errorf("could not load model: %w", err)
 		}
 		if !res.Success {
+			ml.deleteProcess(o.model)
 			return nil, fmt.Errorf("could not load model (no success): %s", res.Message)
 		}
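
The pattern behind the three ml.deleteProcess(o.model) additions above, sketched with hypothetical stand-ins (loader, startProcess, and waitReady are simplifications, not the real LocalAI signatures): once the backend process has been spawned, every failure path must tear it down before returning, otherwise a backend that never became ready is leaked.

package main

import (
	"errors"
	"fmt"
)

// loader is a toy stand-in for ModelLoader tracking spawned processes.
type loader struct{ procs map[string]bool }

func (l *loader) startProcess(name string) error { l.procs[name] = true; return nil }
func (l *loader) deleteProcess(name string)      { delete(l.procs, name) }
func (l *loader) waitReady(name string) bool     { return false } // pretend startup failed

// loadModel mirrors the commit's error handling: each early return after
// the process exists first deletes it, so nothing is left orphaned.
func (l *loader) loadModel(name string) error {
	if err := l.startProcess(name); err != nil {
		return err
	}
	if !l.waitReady(name) {
		l.deleteProcess(name) // cleanup on the "grpc service not ready" path
		return errors.New("grpc service not ready")
	}
	return nil
}

func main() {
	l := &loader{procs: map[string]bool{}}
	err := l.loadModel("foo")
	fmt.Println(err, "| leaked processes:", len(l.procs)) // prints 0 leaked
}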

pkg/model/loader.go

Lines changed: 14 additions & 18 deletions
@@ -13,28 +13,25 @@ import (
 
 	"github.com/mudler/LocalAI/pkg/utils"
 
-	process "github.com/mudler/go-processmanager"
 	"github.com/rs/zerolog/log"
 )
 
 // new idea: what if we declare a struct of these here, and use a loop to check?
 
 // TODO: Split ModelLoader and TemplateLoader? Just to keep things more organized. Left together to share a mutex until I look into that. Would split if we seperate directories for .bin/.yaml and .tmpl
 type ModelLoader struct {
-	ModelPath     string
-	mu            sync.Mutex
-	models        map[string]*Model
-	grpcProcesses map[string]*process.Process
-	templates     *templates.TemplateCache
-	wd            *WatchDog
+	ModelPath string
+	mu        sync.Mutex
+	models    map[string]*Model
+	templates *templates.TemplateCache
+	wd        *WatchDog
 }
 
 func NewModelLoader(modelPath string) *ModelLoader {
 	nml := &ModelLoader{
-		ModelPath:     modelPath,
-		models:        make(map[string]*Model),
-		templates:     templates.NewTemplateCache(modelPath),
-		grpcProcesses: make(map[string]*process.Process),
+		ModelPath: modelPath,
+		models:    make(map[string]*Model),
+		templates: templates.NewTemplateCache(modelPath),
 	}
 
 	return nml
@@ -127,6 +124,8 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
 	modelFile := filepath.Join(ml.ModelPath, modelName)
 	log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
 
+	ml.mu.Lock()
+	defer ml.mu.Unlock()
 	model, err := loader(modelName, modelFile)
 	if err != nil {
 		return nil, err
@@ -136,8 +135,6 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
 		return nil, fmt.Errorf("loader didn't return a model")
 	}
 
-	ml.mu.Lock()
-	defer ml.mu.Unlock()
 	ml.models[modelName] = model
 
 	return model, nil
@@ -146,14 +143,13 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
 func (ml *ModelLoader) ShutdownModel(modelName string) error {
 	ml.mu.Lock()
 	defer ml.mu.Unlock()
-
-	_, ok := ml.models[modelName]
+	model, ok := ml.models[modelName]
 	if !ok {
 		return fmt.Errorf("model %s not found", modelName)
 	}
 
 	retries := 1
-	for ml.models[modelName].GRPC(false, ml.wd).IsBusy() {
+	for model.GRPC(false, ml.wd).IsBusy() {
 		log.Debug().Msgf("%s busy. Waiting.", modelName)
 		dur := time.Duration(retries*2) * time.Second
 		if dur > retryTimeout {
@@ -185,8 +181,8 @@ func (ml *ModelLoader) CheckIsLoaded(s string) *Model {
 		if !alive {
 			log.Warn().Msgf("GRPC Model not responding: %s", err.Error())
 			log.Warn().Msgf("Deleting the process in order to recreate it")
-			process, exists := ml.grpcProcesses[s]
-			if !exists {
+			process := m.Process()
+			if process == nil {
 				log.Error().Msgf("Process not found for '%s' and the model is not responding anymore !", s)
 				return m
 			}
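
A compact sketch of the lock-scope change in LoadModel (simplified stand-in types, not the real signatures): the mutex now wraps the loader callback itself rather than only the map insert, so two goroutines requesting the same model are serialized and cannot each spawn a backend; the apparent trade-off is that loads through one ModelLoader no longer run concurrently.

package main

import "sync"

type Model struct{ ID string } // stand-in for model.Model

type modelLoader struct { // stand-in for ModelLoader
	mu     sync.Mutex
	models map[string]*Model
}

func (ml *modelLoader) LoadModel(name string, loader func(string) (*Model, error)) (*Model, error) {
	ml.mu.Lock() // previously taken only around the ml.models write
	defer ml.mu.Unlock()
	m, err := loader(name)
	if err != nil {
		return nil, err
	}
	ml.models[name] = m
	return m, nil
}

func main() {}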

pkg/model/loader_test.go

Lines changed: 2 additions & 2 deletions
@@ -63,7 +63,7 @@ var _ = Describe("ModelLoader", func() {
 
 	Context("LoadModel", func() {
 		It("should load a model and keep it in memory", func() {
-			mockModel = model.NewModel("foo", "test.model")
+			mockModel = model.NewModel("foo", "test.model", nil)
 
 			mockLoader := func(modelName, modelFile string) (*model.Model, error) {
 				return mockModel, nil
@@ -88,7 +88,7 @@ var _ = Describe("ModelLoader", func() {
 
 	Context("ShutdownModel", func() {
 		It("should shutdown a loaded model", func() {
-			mockModel = model.NewModel("foo", "test.model")
+			mockModel = model.NewModel("foo", "test.model", nil)
 
 			mockLoader := func(modelName, modelFile string) (*model.Model, error) {
 				return mockModel, nil

pkg/model/model.go

Lines changed: 16 additions & 2 deletions
@@ -1,20 +1,32 @@
 package model
 
-import grpc "github.com/mudler/LocalAI/pkg/grpc"
+import (
+	"sync"
+
+	grpc "github.com/mudler/LocalAI/pkg/grpc"
+	process "github.com/mudler/go-processmanager"
+)
 
 type Model struct {
 	ID      string `json:"id"`
 	address string
 	client  grpc.Backend
+	process *process.Process
+	sync.Mutex
 }
 
-func NewModel(ID, address string) *Model {
+func NewModel(ID, address string, process *process.Process) *Model {
 	return &Model{
 		ID:      ID,
 		address: address,
+		process: process,
 	}
 }
 
+func (m *Model) Process() *process.Process {
+	return m.process
+}
+
 func (m *Model) GRPC(parallel bool, wd *WatchDog) grpc.Backend {
 	if m.client != nil {
 		return m.client
@@ -25,6 +37,8 @@ func (m *Model) GRPC(parallel bool, wd *WatchDog) grpc.Backend {
 		enableWD = true
 	}
 
+	m.Lock()
+	defer m.Unlock()
 	m.client = grpc.NewClient(m.address, parallel, wd, enableWD)
 	return m.client
 }
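
A short usage sketch of the new three-argument constructor (hypothetical model name and address; the import path follows the repository layout): local backends pass the handle returned by startProcess, while external URI backends pass nil, which is why Process() is nil-checked throughout the loader.

package main

import (
	"fmt"

	"github.com/mudler/LocalAI/pkg/model"
)

func main() {
	// An external URI backend has no child process, so the constructor
	// takes nil and Process() reports nothing to stop locally.
	remote := model.NewModel("remote-llama", "10.0.0.5:50051", nil)
	if remote.Process() == nil {
		fmt.Println("external backend: no local process to manage")
	}
}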

pkg/model/process.go

Lines changed: 18 additions & 15 deletions
@@ -16,20 +16,22 @@ import (
 )
 
 func (ml *ModelLoader) deleteProcess(s string) error {
-	if _, exists := ml.grpcProcesses[s]; exists {
-		if err := ml.grpcProcesses[s].Stop(); err != nil {
-			log.Error().Err(err).Msgf("(deleteProcess) error while deleting grpc process %s", s)
+	if m, exists := ml.models[s]; exists {
+		process := m.Process()
+		if process != nil {
+			if err := process.Stop(); err != nil {
+				log.Error().Err(err).Msgf("(deleteProcess) error while deleting process %s", s)
+			}
 		}
 	}
-	delete(ml.grpcProcesses, s)
 	delete(ml.models, s)
 	return nil
 }
 
 func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) error {
 	var err error = nil
-	for k, p := range ml.grpcProcesses {
-		if filter(k, p) {
+	for k, m := range ml.models {
+		if filter(k, m.Process()) {
 			e := ml.ShutdownModel(k)
 			err = errors.Join(err, e)
 		}
@@ -44,17 +46,20 @@ func (ml *ModelLoader) StopAllGRPC() error {
 func (ml *ModelLoader) GetGRPCPID(id string) (int, error) {
 	ml.mu.Lock()
 	defer ml.mu.Unlock()
-	p, exists := ml.grpcProcesses[id]
+	p, exists := ml.models[id]
 	if !exists {
 		return -1, fmt.Errorf("no grpc backend found for %s", id)
 	}
-	return strconv.Atoi(p.PID)
+	if p.Process() == nil {
+		return -1, fmt.Errorf("no grpc backend found for %s", id)
+	}
+	return strconv.Atoi(p.Process().PID)
 }
 
-func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string, args ...string) error {
+func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string, args ...string) (*process.Process, error) {
 	// Make sure the process is executable
 	if err := os.Chmod(grpcProcess, 0700); err != nil {
-		return err
+		return nil, err
 	}
 
 	log.Debug().Msgf("Loading GRPC Process: %s", grpcProcess)
@@ -63,7 +68,7 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string
 
 	workDir, err := filepath.Abs(filepath.Dir(grpcProcess))
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	grpcControlProcess := process.New(
@@ -79,10 +84,8 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string
 		ml.wd.AddAddressModelMap(serverAddress, id)
 	}
 
-	ml.grpcProcesses[id] = grpcControlProcess
-
 	if err := grpcControlProcess.Run(); err != nil {
-		return err
+		return grpcControlProcess, err
 	}
 
 	log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir())
@@ -116,5 +119,5 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string
 		}
 	}()
 
-	return nil
+	return grpcControlProcess, nil
 }
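
A hypothetical caller of StopGRPC under the new iteration: the loop now walks ml.models and hands each model's process handle to the filter, so models without a child process (external URI backends) are visited with a nil handle and a filter must tolerate that. The filter signature here is inferred from the filter(k, m.Process()) call site above.

package main

import (
	"github.com/mudler/LocalAI/pkg/model"
	process "github.com/mudler/go-processmanager"
)

// stopEverything shuts down every backend known to the loader; a nil p is
// fine because the filter never dereferences it.
func stopEverything(ml *model.ModelLoader) error {
	return ml.StopGRPC(func(name string, p *process.Process) bool {
		return true // match every backend
	})
}

func main() {
	ml := model.NewModelLoader("/models")
	_ = stopEverything(ml)
}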
