Skip to content

Commit cd465f1

Browse files
authored
feat(pkg/modelfile): add support for CODE and DATASET in modelfile (#66)
Signed-off-by: Gaius <[email protected]>
1 parent ca91651 commit cd465f1

14 files changed

+276
-28
lines changed

docs/getting-started.md

+6
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ CONFIG generation_config.json
5454

5555
# Model weight.
5656
MODEL \.safetensors$
57+
58+
# Model code.
59+
CODE \.py$
60+
61+
# Model dataset.
62+
DATASET \.csv$
5763
```
5864

5965
Then run the following command to build the model artifact:

pkg/backend/build_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ import (
2020
"testing"
2121
"time"
2222

23-
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
2423
"github.com/CloudNativeAI/modctl/test/mocks/modelfile"
24+
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
2525

2626
"github.com/stretchr/testify/assert"
2727
)

pkg/backend/processor/license_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ import (
2121
"testing"
2222
"testing/fstest"
2323

24-
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
2524
"github.com/CloudNativeAI/modctl/test/mocks/storage"
25+
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
2626

2727
"github.com/stretchr/testify/assert"
2828
"github.com/stretchr/testify/mock"

pkg/backend/processor/model_config_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ import (
2121
"testing"
2222
"testing/fstest"
2323

24-
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
2524
"github.com/CloudNativeAI/modctl/test/mocks/storage"
25+
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
2626

2727
"github.com/stretchr/testify/assert"
2828
"github.com/stretchr/testify/mock"

pkg/backend/processor/model_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ import (
2121
"testing"
2222
"testing/fstest"
2323

24-
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
2524
"github.com/CloudNativeAI/modctl/test/mocks/storage"
25+
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
2626

2727
"github.com/stretchr/testify/assert"
2828
"github.com/stretchr/testify/mock"

pkg/backend/processor/readme_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ import (
2121
"testing"
2222
"testing/fstest"
2323

24-
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
2524
"github.com/CloudNativeAI/modctl/test/mocks/storage"
25+
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
2626

2727
"github.com/stretchr/testify/assert"
2828
"github.com/stretchr/testify/mock"

pkg/modelfile/command/command.go

+18-2
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,26 @@ const (
2525
CONFIG = "config"
2626

2727
// MODEL is the command to set the model file path. The value of this command
28-
// is the regex of the model file path to match the model file name.
28+
// is the glob of the model file path to match the model file name.
2929
// The MODEL command can be used multiple times in a modelfile, it will scan
30-
// the model file path by the regex and copy each model file to the artifact
30+
// the model file path by the glob and copy each model file to the artifact
3131
// package, and each model file will be a layer.
3232
MODEL = "model"
3333

34+
// CODE is the command to set the code file path. The value of this command
35+
// is the glob of the code file path to match the code file name.
36+
// The CODE command can be used multiple times in a modelfile, it will scan
37+
// the code file path by the glob and copy each code file to the artifact
38+
// package, and each code file will be a layer.
39+
CODE = "code"
40+
41+
// DATASET is the command to set the dataset file path. The value of this command
42+
// is the glob of the dataset file path to match the dataset file name.
43+
// The DATASET command can be used multiple times in a modelfile, it will scan
44+
// the dataset file path by the glob and copy each dataset file to the artifact
45+
// package, and each dataset file will be a layer.
46+
DATASET = "dataset"
47+
3448
// NAME is the command to set the model name, such as llama3-8b-instruct, gpt2-xl,
3549
// qwen2-vl-72b-instruct, etc.
3650
NAME = "name"
@@ -59,6 +73,8 @@ const (
5973
var Commands = []string{
6074
CONFIG,
6175
MODEL,
76+
CODE,
77+
DATASET,
6278
NAME,
6379
ARCH,
6480
FAMILY,

pkg/modelfile/modelfile.go

+55-3
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,20 @@ type Modelfile interface {
3535
GetConfigs() []string
3636

3737
// GetModels returns the args of the model command in the modelfile,
38-
// and deduplicates the args. The order of the args is the same as The
38+
// and deduplicates the args. The order of the args is the same as the
3939
// order in the modelfile.
4040
GetModels() []string
4141

42+
// GetCodes returns the args of the code command in the modelfile,
43+
// and deduplicates the args. The order of the args is the same as the
44+
// order in the modelfile.
45+
GetCodes() []string
46+
47+
// GetDatasets returns the args of the dataset command in the modelfile,
48+
// and deduplicates the args. The order of the args is the same as the
49+
// order in the modelfile.
50+
GetDatasets() []string
51+
4252
// GetName returns the value of the name command in the modelfile.
4353
GetName() string
4454

@@ -65,6 +75,8 @@ type Modelfile interface {
6575
type modelfile struct {
6676
config *hashset.Set
6777
model *hashset.Set
78+
code *hashset.Set
79+
dataset *hashset.Set
6880
name string
6981
arch string
7082
family string
@@ -78,8 +90,10 @@ type modelfile struct {
7890
// It parses the modelfile and returns the modelfile interface.
7991
func NewModelfile(path string) (Modelfile, error) {
8092
mf := &modelfile{
81-
config: hashset.New(),
82-
model: hashset.New(),
93+
config: hashset.New(),
94+
model: hashset.New(),
95+
code: hashset.New(),
96+
dataset: hashset.New(),
8397
}
8498
if err := mf.parseFile(path); err != nil {
8599
return nil, err
@@ -107,6 +121,10 @@ func (mf *modelfile) parseFile(path string) error {
107121
mf.config.Add(child.GetNext().GetValue())
108122
case modefilecommand.MODEL:
109123
mf.model.Add(child.GetNext().GetValue())
124+
case modefilecommand.CODE:
125+
mf.code.Add(child.GetNext().GetValue())
126+
case modefilecommand.DATASET:
127+
mf.dataset.Add(child.GetNext().GetValue())
110128
case modefilecommand.NAME:
111129
if mf.name != "" {
112130
return fmt.Errorf("duplicate name command on line %d", child.GetStartLine())
@@ -184,6 +202,40 @@ func (mf *modelfile) GetModels() []string {
184202
return models
185203
}
186204

205+
// GetCodes returns the args of the code command in the modelfile,
206+
// and deduplicates the args. The order of the args is the same as the
207+
// order in the modelfile.
208+
func (mf *modelfile) GetCodes() []string {
209+
var codes []string
210+
for _, rawCode := range mf.code.Values() {
211+
code, ok := rawCode.(string)
212+
if !ok {
213+
continue
214+
}
215+
216+
codes = append(codes, code)
217+
}
218+
219+
return codes
220+
}
221+
222+
// GetDatasets returns the args of the dataset command in the modelfile,
223+
// and deduplicates the args. The order of the args is the same as the
224+
// order in the modelfile.
225+
func (mf *modelfile) GetDatasets() []string {
226+
var datasets []string
227+
for _, rawDataset := range mf.dataset.Values() {
228+
dataset, ok := rawDataset.(string)
229+
if !ok {
230+
continue
231+
}
232+
233+
datasets = append(datasets, dataset)
234+
}
235+
236+
return datasets
237+
}
238+
187239
// GetName returns the value of the name command in the modelfile.
188240
func (mf *modelfile) GetName() string {
189241
return mf.name

pkg/modelfile/modelfile_test.go

+35-5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ func TestNewModelfile(t *testing.T) {
3131
expectErr error
3232
configs []string
3333
models []string
34+
codes []string
35+
datasets []string
3436
name string
3537
arch string
3638
family string
@@ -44,6 +46,8 @@ func TestNewModelfile(t *testing.T) {
4446
# This is a comment
4547
config config1
4648
model model1
49+
code code1
50+
dataset dataset1
4751
name name1
4852
arch arch1
4953
family family1
@@ -55,6 +59,8 @@ quantization quantization1
5559
expectErr: nil,
5660
configs: []string{"config1"},
5761
models: []string{"model1"},
62+
codes: []string{"code1"},
63+
datasets: []string{"dataset1"},
5864
name: "name1",
5965
arch: "arch1",
6066
family: "family1",
@@ -68,6 +74,8 @@ quantization quantization1
6874
# This is a comment
6975
config config1
7076
model model1
77+
code code1
78+
dataset dataset1
7179
name name1
7280
arch arch1
7381
family family1
@@ -79,6 +87,8 @@ quantization quantization1
7987
expectErr: nil,
8088
configs: []string{"config1"},
8189
models: []string{"model1"},
90+
codes: []string{"code1"},
91+
datasets: []string{"dataset1"},
8292
name: "name1",
8393
arch: "arch1",
8494
family: "family1",
@@ -89,10 +99,14 @@ quantization quantization1
8999
},
90100
{
91101
input: `
92-
model model1
93-
model model2
94102
config config1
95103
config config2
104+
model model1
105+
model model2
106+
code code1
107+
code code2
108+
dataset dataset1
109+
dataset dataset2
96110
name name1
97111
arch arch1
98112
family family1
@@ -104,6 +118,8 @@ quantization quantization1
104118
expectErr: nil,
105119
configs: []string{"config1", "config2"},
106120
models: []string{"model1", "model2"},
121+
codes: []string{"code1", "code2"},
122+
datasets: []string{"dataset1", "dataset2"},
107123
name: "name1",
108124
arch: "arch1",
109125
family: "family1",
@@ -114,12 +130,18 @@ quantization quantization1
114130
},
115131
{
116132
input: `
117-
model model1
118-
model model1
119-
model model2
120133
config config1
121134
config config1
122135
config config2
136+
model model1
137+
model model1
138+
model model2
139+
code code1
140+
code code1
141+
code code2
142+
dataset dataset1
143+
dataset dataset1
144+
dataset dataset2
123145
name name1
124146
arch arch1
125147
family family1
@@ -131,6 +153,8 @@ quantization quantization1
131153
expectErr: nil,
132154
configs: []string{"config1", "config2"},
133155
models: []string{"model1", "model2"},
156+
codes: []string{"code1", "code2"},
157+
datasets: []string{"dataset1", "dataset2"},
134158
name: "name1",
135159
arch: "arch1",
136160
family: "family1",
@@ -202,10 +226,16 @@ name bar
202226
assert.NotNil(mf)
203227
configs := mf.GetConfigs()
204228
models := mf.GetModels()
229+
codes := mf.GetCodes()
230+
datasets := mf.GetDatasets()
205231
sort.Strings(configs)
206232
sort.Strings(models)
233+
sort.Strings(codes)
234+
sort.Strings(datasets)
207235
assert.Equal(tc.configs, configs)
208236
assert.Equal(tc.models, models)
237+
assert.Equal(tc.codes, codes)
238+
assert.Equal(tc.datasets, datasets)
209239
assert.Equal(tc.name, mf.GetName())
210240
assert.Equal(tc.arch, mf.GetArch())
211241
assert.Equal(tc.family, mf.GetFamily())

pkg/modelfile/parser/parser.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func parseCommandLine(line string, start, end int) (Node, error) {
102102
}
103103

104104
switch cmd {
105-
case command.CONFIG, command.MODEL, command.NAME, command.ARCH, command.FAMILY, command.FORMAT, command.PARAMSIZE, command.PRECISION, command.QUANTIZATION:
105+
case command.CONFIG, command.MODEL, command.CODE, command.DATASET, command.NAME, command.ARCH, command.FAMILY, command.FORMAT, command.PARAMSIZE, command.PRECISION, command.QUANTIZATION:
106106
argsNode, err := parseStringArgs(args, start, end)
107107
if err != nil {
108108
return nil, err

pkg/modelfile/parser/parser_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ func TestParseCommandLine(t *testing.T) {
149149
{"config foo", 1, 2, false, "config", []string{"foo"}},
150150
{"CONFIG foo", 1, 2, false, "config", []string{"foo"}},
151151
{"model foo", 1, 2, false, "model", []string{"foo"}},
152+
{"code foo", 1, 2, false, "code", []string{"foo"}},
153+
{"dataset foo", 1, 2, false, "dataset", []string{"foo"}},
152154
{"name bar", 3, 4, false, "name", []string{"bar"}},
153155
{"arch transformer", 5, 6, false, "arch", []string{"transformer"}},
154156
{"family llama3", 7, 8, false, "family", []string{"llama3"}},

test/mocks/backend/backend.go

+49-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)