forked from viamrobotics/rdk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmltraining_client.go
310 lines (287 loc) · 9.32 KB
/
mltraining_client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
package app
import (
"context"
"errors"
"time"
pb "go.viam.com/api/app/mltraining/v1"
"go.viam.com/utils/rpc"
status "google.golang.org/genproto/googleapis/rpc/status"
)
// TrainingStatus represents the status of a training job.
type TrainingStatus int
// The TrainingStatus values are numbered with iota in declaration order,
// so TrainingStatusUnspecified is the zero value.
const (
// TrainingStatusUnspecified represents an unspecified training status.
TrainingStatusUnspecified TrainingStatus = iota
// TrainingStatusPending represents a pending training job.
TrainingStatusPending
// TrainingStatusInProgress represents a training job that is in progress.
TrainingStatusInProgress
// TrainingStatusCompleted represents a completed training job.
TrainingStatusCompleted
// TrainingStatusFailed represents a failed training job.
TrainingStatusFailed
// TrainingStatusCanceled represents a canceled training job.
TrainingStatusCanceled
// TrainingStatusCanceling represents a training job that is being canceled.
TrainingStatusCanceling
)
// TrainingJobMetadata contains the metadata for a training job.
// Values of this type are built from the proto TrainingJobMetadata message.
type TrainingJobMetadata struct {
ID string
DatasetID string
OrganizationID string
ModelName string
ModelVersion string
ModelType ModelType
ModelFramework ModelFramework
IsCustomJob bool
RegistryItemID string
RegistryItemVersion string
Status TrainingStatus
// ErrorStatus is carried over unchanged from the proto message.
// NOTE(review): presumably only populated for failed jobs — confirm against the API.
ErrorStatus *status.Status
// The four timestamps below are nil when the corresponding proto
// timestamp is not set.
CreatedOn *time.Time
LastModified *time.Time
TrainingStarted *time.Time
TrainingEnded *time.Time
SyncedModelID string
Tags []string
}
// GetTrainingJobLogsOptions contains optional parameters for GetTrainingJobLogs.
type GetTrainingJobLogsOptions struct {
// PageToken is forwarded as the request's page token; it may be nil.
PageToken *string
}
// TrainingJobLogEntry is a log entry from a training job.
type TrainingJobLogEntry struct {
Level string
// Time is nil when the proto log entry has no timestamp set.
Time *time.Time
Message string
}
// SubmitTrainingJobArgs contains the necessary training job information to submit the job.
// SubmitTrainingJob and SubmitCustomTrainingJob return an error when any of
// the four fields is an empty string.
type SubmitTrainingJobArgs struct {
DatasetID string
OrganizationID string
ModelName string
ModelVersion string
}
// MLTrainingClient is a gRPC client for method calls to the ML Training API.
type MLTrainingClient struct {
// client is the generated service client that the methods delegate to.
client pb.MLTrainingServiceClient
}
// newMLTrainingClient builds the generated ML Training service client on top
// of conn and returns an MLTrainingClient that wraps it.
func newMLTrainingClient(conn rpc.ClientConn) *MLTrainingClient {
	service := pb.NewMLTrainingServiceClient(conn)
	return &MLTrainingClient{client: service}
}
// SubmitTrainingJob submits a training job request and returns its ID.
//
// args is validated first; if validation fails, that error is returned and no
// request is sent. modelType and tags are forwarded in the request.
func (c *MLTrainingClient) SubmitTrainingJob(
	ctx context.Context, args SubmitTrainingJobArgs, modelType ModelType, tags []string,
) (string, error) {
	if err := args.isValid(); err != nil {
		return "", err
	}

	req := &pb.SubmitTrainingJobRequest{
		DatasetId:      args.DatasetID,
		OrganizationId: args.OrganizationID,
		ModelName:      args.ModelName,
		ModelVersion:   args.ModelVersion,
		ModelType:      modelTypeToProto(modelType),
		Tags:           tags,
	}
	resp, err := c.client.SubmitTrainingJob(ctx, req)
	if err != nil {
		return "", err
	}
	return resp.Id, nil
}
// SubmitCustomTrainingJob submits a custom training job request and returns its ID.
//
// args is validated first; if validation fails, that error is returned and no
// request is sent. registryItemID, registryItemVersion and arguments are
// forwarded in the request.
func (c *MLTrainingClient) SubmitCustomTrainingJob(
	ctx context.Context, args SubmitTrainingJobArgs, registryItemID, registryItemVersion string, arguments map[string]string,
) (string, error) {
	if err := args.isValid(); err != nil {
		return "", err
	}

	req := &pb.SubmitCustomTrainingJobRequest{
		DatasetId:           args.DatasetID,
		RegistryItemId:      registryItemID,
		RegistryItemVersion: registryItemVersion,
		OrganizationId:      args.OrganizationID,
		ModelName:           args.ModelName,
		ModelVersion:        args.ModelVersion,
		Arguments:           arguments,
	}
	resp, err := c.client.SubmitCustomTrainingJob(ctx, req)
	if err != nil {
		return "", err
	}
	return resp.Id, nil
}
// GetTrainingJob retrieves a training job by its ID and returns its metadata
// converted from the proto response.
func (c *MLTrainingClient) GetTrainingJob(ctx context.Context, id string) (*TrainingJobMetadata, error) {
	req := &pb.GetTrainingJobRequest{Id: id}
	resp, err := c.client.GetTrainingJob(ctx, req)
	if err != nil {
		return nil, err
	}
	job := trainingJobMetadataFromProto(resp.Metadata)
	return job, nil
}
// ListTrainingJobs lists training jobs for a given organization ID and training status.
// The returned slice is nil when the response contains no jobs.
func (c *MLTrainingClient) ListTrainingJobs(
	ctx context.Context, organizationID string, status TrainingStatus,
) ([]*TrainingJobMetadata, error) {
	req := &pb.ListTrainingJobsRequest{
		OrganizationId: organizationID,
		Status:         trainingStatusToProto(status),
	}
	resp, err := c.client.ListTrainingJobs(ctx, req)
	if err != nil {
		return nil, err
	}

	// Left as a nil slice on purpose so an empty response still yields nil.
	var converted []*TrainingJobMetadata
	for i := range resp.Jobs {
		converted = append(converted, trainingJobMetadataFromProto(resp.Jobs[i]))
	}
	return converted, nil
}
// CancelTrainingJob cancels a training job that has not yet completed.
// The response body is discarded; only the error from the call is returned.
func (c *MLTrainingClient) CancelTrainingJob(ctx context.Context, id string) error {
	req := &pb.CancelTrainingJobRequest{Id: id}
	_, err := c.client.CancelTrainingJob(ctx, req)
	return err
}
// DeleteCompletedTrainingJob removes a completed training job from the database, whether the job succeeded or failed.
// The response body is discarded; only the error from the call is returned.
func (c *MLTrainingClient) DeleteCompletedTrainingJob(ctx context.Context, id string) error {
	req := &pb.DeleteCompletedTrainingJobRequest{Id: id}
	_, err := c.client.DeleteCompletedTrainingJob(ctx, req)
	return err
}
// GetTrainingJobLogs gets the logs and the next page token for a given custom training job.
//
// opts may be nil, in which case no page token is sent. The returned slice is
// nil when the response contains no log entries.
func (c *MLTrainingClient) GetTrainingJobLogs(
	ctx context.Context, id string, opts *GetTrainingJobLogsOptions,
) ([]*TrainingJobLogEntry, string, error) {
	req := &pb.GetTrainingJobLogsRequest{Id: id}
	if opts != nil {
		req.PageToken = opts.PageToken
	}

	resp, err := c.client.GetTrainingJobLogs(ctx, req)
	if err != nil {
		return nil, "", err
	}

	var entries []*TrainingJobLogEntry
	for _, raw := range resp.Logs {
		entries = append(entries, trainingJobLogEntryFromProto(raw))
	}
	return entries, resp.NextPageToken, nil
}
// isValid reports the first empty required field as an error, checking
// DatasetID, OrganizationID, ModelName and ModelVersion in that order.
// It returns nil when all four are non-empty.
func (s *SubmitTrainingJobArgs) isValid() error {
	switch {
	case s.DatasetID == "":
		return errors.New("DatasetID should not be empty")
	case s.OrganizationID == "":
		return errors.New("OrganizationID should not be empty")
	case s.ModelName == "":
		return errors.New("ModelName should not be empty")
	case s.ModelVersion == "":
		return errors.New("ModelVersion should not be empty")
	}
	return nil
}
// trainingJobLogEntryFromProto converts a proto log entry into a
// TrainingJobLogEntry. It returns nil for a nil input, and leaves Time nil
// when the proto entry has no timestamp set.
func trainingJobLogEntryFromProto(log *pb.TrainingJobLogEntry) *TrainingJobLogEntry {
	if log == nil {
		return nil
	}
	// Named loggedAt rather than time so the imported time package is not
	// shadowed for the rest of the function.
	var loggedAt *time.Time
	if log.Time != nil {
		t := log.Time.AsTime()
		loggedAt = &t
	}
	return &TrainingJobLogEntry{
		Level:   log.Level,
		Time:    loggedAt,
		Message: log.Message,
	}
}
// trainingJobMetadataFromProto converts proto training job metadata into a
// TrainingJobMetadata. It returns nil for a nil input. Each of the four
// timestamp fields stays nil unless the matching proto timestamp is set.
func trainingJobMetadataFromProto(metadata *pb.TrainingJobMetadata) *TrainingJobMetadata {
	if metadata == nil {
		return nil
	}

	job := &TrainingJobMetadata{
		ID:                  metadata.Id,
		DatasetID:           metadata.DatasetId,
		OrganizationID:      metadata.OrganizationId,
		ModelName:           metadata.ModelName,
		ModelVersion:        metadata.ModelVersion,
		ModelType:           modelTypeFromProto(metadata.ModelType),
		ModelFramework:      modelFrameworkFromProto(metadata.ModelFramework),
		IsCustomJob:         metadata.IsCustomJob,
		RegistryItemID:      metadata.RegistryItemId,
		RegistryItemVersion: metadata.RegistryItemVersion,
		Status:              trainingStatusFromProto(metadata.Status),
		ErrorStatus:         metadata.ErrorStatus,
		SyncedModelID:       metadata.SyncedModelId,
		Tags:                metadata.Tags,
	}

	// Optional timestamps: allocate a time.Time only when the proto field is set.
	if ts := metadata.CreatedOn; ts != nil {
		job.CreatedOn = new(time.Time)
		*job.CreatedOn = ts.AsTime()
	}
	if ts := metadata.LastModified; ts != nil {
		job.LastModified = new(time.Time)
		*job.LastModified = ts.AsTime()
	}
	if ts := metadata.TrainingStarted; ts != nil {
		job.TrainingStarted = new(time.Time)
		*job.TrainingStarted = ts.AsTime()
	}
	if ts := metadata.TrainingEnded; ts != nil {
		job.TrainingEnded = new(time.Time)
		*job.TrainingEnded = ts.AsTime()
	}
	return job
}
// trainingStatusFromProto maps a proto training status onto the matching
// TrainingStatus. Any value without an explicit case maps to
// TrainingStatusUnspecified.
func trainingStatusFromProto(status pb.TrainingStatus) TrainingStatus {
	converted := TrainingStatusUnspecified
	switch status {
	case pb.TrainingStatus_TRAINING_STATUS_UNSPECIFIED:
		// Already the fallback value.
	case pb.TrainingStatus_TRAINING_STATUS_PENDING:
		converted = TrainingStatusPending
	case pb.TrainingStatus_TRAINING_STATUS_IN_PROGRESS:
		converted = TrainingStatusInProgress
	case pb.TrainingStatus_TRAINING_STATUS_COMPLETED:
		converted = TrainingStatusCompleted
	case pb.TrainingStatus_TRAINING_STATUS_FAILED:
		converted = TrainingStatusFailed
	case pb.TrainingStatus_TRAINING_STATUS_CANCELED:
		converted = TrainingStatusCanceled
	case pb.TrainingStatus_TRAINING_STATUS_CANCELING:
		converted = TrainingStatusCanceling
	}
	return converted
}
// trainingStatusToProto maps a TrainingStatus onto the matching proto
// training status. Any value without an explicit case maps to
// TRAINING_STATUS_UNSPECIFIED.
func trainingStatusToProto(status TrainingStatus) pb.TrainingStatus {
	switch status {
	case TrainingStatusPending:
		return pb.TrainingStatus_TRAINING_STATUS_PENDING
	case TrainingStatusInProgress:
		return pb.TrainingStatus_TRAINING_STATUS_IN_PROGRESS
	case TrainingStatusCompleted:
		return pb.TrainingStatus_TRAINING_STATUS_COMPLETED
	case TrainingStatusFailed:
		return pb.TrainingStatus_TRAINING_STATUS_FAILED
	case TrainingStatusCanceled:
		return pb.TrainingStatus_TRAINING_STATUS_CANCELED
	case TrainingStatusCanceling:
		return pb.TrainingStatus_TRAINING_STATUS_CANCELING
	case TrainingStatusUnspecified:
		fallthrough
	default:
		return pb.TrainingStatus_TRAINING_STATUS_UNSPECIFIED
	}
}