Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
oliviamiller committed Apr 17, 2024
2 parents 5b3e871 + 513136a commit 28ff656
Show file tree
Hide file tree
Showing 10 changed files with 365 additions and 29 deletions.
14 changes: 13 additions & 1 deletion .github/workflows/staticbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,20 @@ on:
required: true

jobs:
build_timestamp:
name: Set Build Timestamp
runs-on: ubuntu-latest
outputs:
date: ${{ steps.build_time.outputs.date }}
steps:
- name: Build Time
id: build_time
if: inputs.release_type == 'latest'
run: echo "date=`date +%Y%m%d%H%M%S`" >> $GITHUB_OUTPUT

static:
name: Antique Build
needs: build_timestamp
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -76,7 +88,7 @@ jobs:
- name: Build (Latest)
if: inputs.release_type == 'latest'
run: |
sudo -Hu testbot bash -lc 'make BUILD_CHANNEL="latest" static-release'
sudo -Hu testbot bash -lc 'make RELEASE_TYPE="latest" BUILD_CHANNEL="${{needs.build_timestamp.outputs.date}}" static-release'
- name: Build (Tagged)
if: inputs.release_type == 'stable' || inputs.release_type == 'rc'
Expand Down
59 changes: 58 additions & 1 deletion cli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ const (
loginFlagKeyID = "key-id"
loginFlagKey = "key"

// Flags shared by api-key, module and data subcommands.
// Flags shared by api-key, module, ml-training and data subcommands.
generalFlagOrgID = "org-id"
generalFlagLocationID = "location-id"
generalFlagMachineID = "machine-id"
Expand All @@ -57,6 +57,13 @@ const (
moduleBuildFlagPlatform = "platform"
moduleBuildFlagWait = "wait"

mlTrainingFlagPath = "path"
mlTrainingFlagName = "name"
mlTrainingFlagVersion = "version"
mlTrainingFlagFramework = "framework"
mlTrainingFlagType = "type"
mlTrainingFlagDraft = "draft"

dataFlagDestination = "destination"
dataFlagDataType = "data-type"
dataFlagOrgIDs = "org-ids"
Expand Down Expand Up @@ -1417,6 +1424,56 @@ Example:
},
},
},
{
Name: "training-script",
Usage: "manage training scripts for custom ML training",
Subcommands: []*cli.Command{
{
Name: "upload",
Usage: "upload ML training scripts for custom ML training",
UsageText: createUsageText("training-script upload", []string{mlTrainingFlagPath, mlTrainingFlagName}, true),
Flags: []cli.Flag{
&cli.StringFlag{
Name: mlTrainingFlagPath,
Usage: "path to ML training scripts for upload",
Required: true,
},
&cli.StringFlag{
Name: generalFlagOrgID,
Required: true,
Usage: "organization ID that will host the scripts",
},
&cli.StringFlag{
Name: mlTrainingFlagName,
Usage: "name of the ML training script to upload",
Required: true,
},
&cli.StringFlag{
Name: mlTrainingFlagVersion,
Usage: "version of the ML training script to upload",
Required: false,
},
&cli.StringFlag{
Name: mlTrainingFlagFramework,
Usage: "framework of the ML training script to upload, can be: " + strings.Join(modelFrameworks, ", "),
Required: false,
},
&cli.StringFlag{
Name: mlTrainingFlagType,
Usage: "task type of the ML training script to upload, can be: " + strings.Join(modelTypes, ", "),
Required: false,
},
&cli.BoolFlag{
Name: mlTrainingFlagDraft,
Usage: "indicate draft mode, drafts will not be viewable in the registry",
Required: false,
},
},
// Upload action
Action: MLTrainingUploadAction,
},
},
},
{
Name: "version",
Usage: "print version info for this program",
Expand Down
133 changes: 133 additions & 0 deletions cli/ml_training.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package cli

import (
"strings"

"github.com/pkg/errors"
"github.com/urfave/cli/v2"
"go.uber.org/multierr"
"google.golang.org/protobuf/types/known/structpb"
)

// MLTrainingUploadAction retrieves the logs for a specific build step.
func MLTrainingUploadAction(c *cli.Context) error {
client, err := newViamClient(c)
if err != nil {
return err
}

metadata, err := createMetadata(c.Bool(mlTrainingFlagDraft), c.String(mlTrainingFlagType),
c.String(mlTrainingFlagFramework))
if err != nil {
return err
}
metadataStruct, err := convertMetadataToStruct(*metadata)
if err != nil {
return err
}

if _, err := client.uploadPackage(c.String(generalFlagOrgID),
c.String(mlTrainingFlagName),
c.String(mlTrainingFlagVersion),
string(PackageTypeMLTraining),
c.Path(mlTrainingFlagPath),
metadataStruct,
); err != nil {
return err
}

moduleID := moduleID{
prefix: c.String(generalFlagOrgID),
name: c.String(mlTrainingFlagName),
}
url := moduleID.ToDetailURL(client.baseURL.Hostname(), PackageTypeMLTraining)
printf(c.App.Writer, "Version successfully uploaded! you can view your changes online here: %s", url)
return nil
}

// ModelType refers to the type of the model.
type ModelType string

// ModelType enumeration.
const (
ModelTypeUnspecified = ModelType("unspecified")
ModelTypeSingleLabelClassification = ModelType("single_label_classification")
ModelTypeMultiLabelClassification = ModelType("multi_label_classification")
ModelTypeObjectDetection = ModelType("object_detection")
)

var modelTypes = []string{
string(ModelTypeUnspecified), string(ModelTypeSingleLabelClassification),
string(ModelTypeMultiLabelClassification), string(ModelTypeObjectDetection),
}

// ModelFramework refers to the backend framework of the model.
type ModelFramework string

// ModelFramework enumeration.
const (
ModelFrameworkUnspecified = ModelFramework("unspecified")
ModelFrameworkTFLite = ModelFramework("tflite")
ModelFrameworkTensorFlow = ModelFramework("tensorflow")
ModelFrameworkPyTorch = ModelFramework("py_torch")
ModelFrameworkONNX = ModelFramework("onnx")
)

var modelFrameworks = []string{
string(ModelFrameworkUnspecified), string(ModelFrameworkTFLite), string(ModelFrameworkTensorFlow),
string(ModelFrameworkPyTorch), string(ModelFrameworkONNX),
}

// MLMetadata struct stores package info for ML training packages.
type MLMetadata struct {
Draft bool
ModelType string
Framework string
}

func createMetadata(draft bool, modelType, framework string) (*MLMetadata, error) {
t, typeErr := findValueOrSetDefault(modelTypes, modelType, string(ModelTypeUnspecified))
f, frameWorkErr := findValueOrSetDefault(modelFrameworks, framework, string(ModelFrameworkUnspecified))

if typeErr != nil || frameWorkErr != nil {
return nil, errors.Wrap(multierr.Combine(typeErr, frameWorkErr), "failed to set metadata")
}

return &MLMetadata{
Draft: draft,
ModelType: t,
Framework: f,
}, nil
}

// findValueOrSetDefault either finds the matching value from all possible values,
// sets a default if the value is not present, or errors if the value is not permissible.
func findValueOrSetDefault(arr []string, val, defaultVal string) (string, error) {
if val == "" {
return defaultVal, nil
}
for _, str := range arr {
if str == val {
return val, nil
}
}
return "", errors.New("value must be one of: " + strings.Join(arr, ", "))
}

var (
modelTypeKey = "model_type"
modelFrameworkKey = "model_framework"
draftKey = "draft"
)

func convertMetadataToStruct(metadata MLMetadata) (*structpb.Struct, error) {
metadataMap := make(map[string]interface{})
metadataMap[modelTypeKey] = metadata.ModelType
metadataMap[modelFrameworkKey] = metadata.Framework
metadataMap[draftKey] = metadata.Draft
metadataStruct, err := structpb.NewStruct(metadataMap)
if err != nil {
return nil, err
}
return metadataStruct, nil
}
42 changes: 34 additions & 8 deletions cli/module_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/pkg/errors"
"github.com/urfave/cli/v2"
"go.uber.org/multierr"
packagespb "go.viam.com/api/app/packages/v1"
apppb "go.viam.com/api/app/v1"
vutils "go.viam.com/utils"

Expand Down Expand Up @@ -392,7 +393,7 @@ func (c *viamClient) uploadModuleFile(
var errs error
// We do not add the EOF as an error because all server-side errors trigger an EOF on the stream
// This results in extra clutter to the error msg
if err := sendModuleUploadRequests(ctx, stream, file, c.c.App.Writer); err != nil && !errors.Is(err, io.EOF) {
if err := sendUploadRequests(ctx, stream, nil, file, c.c.App.Writer); err != nil && !errors.Is(err, io.EOF) {
errs = multierr.Combine(errs, errors.Wrapf(err, "could not upload %s", file.Name()))
}

Expand Down Expand Up @@ -732,7 +733,12 @@ func sameModels(a, b []ModuleComponent) bool {
return true
}

func sendModuleUploadRequests(ctx context.Context, stream apppb.AppService_UploadModuleFileClient, file *os.File, stdout io.Writer) error {
func sendUploadRequests(ctx context.Context, moduleStream apppb.AppService_UploadModuleFileClient,
pkgStream packagespb.PackageService_CreatePackageClient, file *os.File, stdout io.Writer,
) error {
if moduleStream != nil && pkgStream != nil {
return errors.New("can use either module or package client, not both")
}
stat, err := file.Stat()
if err != nil {
return err
Expand All @@ -742,15 +748,26 @@ func sendModuleUploadRequests(ctx context.Context, stream apppb.AppService_Uploa
// Close the line with the progress reading
defer printf(stdout, "")

//nolint:errcheck
defer stream.CloseSend()
if moduleStream != nil {
defer vutils.UncheckedErrorFunc(moduleStream.CloseSend)
}
if pkgStream != nil {
defer vutils.UncheckedErrorFunc(pkgStream.CloseSend)
}
// Loop until there is no more content to be read from file or the context expires.
for {
if ctx.Err() != nil {
return ctx.Err()
}
// Get the next UploadRequest from the file.
uploadReq, err := getNextModuleUploadRequest(file)
var moduleUploadReq *apppb.UploadModuleFileRequest
if moduleStream != nil {
moduleUploadReq, err = getNextModuleUploadRequest(file)
}
var pkgUploadReq *packagespb.CreatePackageRequest
if pkgStream != nil {
pkgUploadReq, err = getNextPackageUploadRequest(file)
}

// EOF means we've completed successfully.
if errors.Is(err, io.EOF) {
Expand All @@ -761,10 +778,19 @@ func sendModuleUploadRequests(ctx context.Context, stream apppb.AppService_Uploa
return errors.Wrap(err, "could not read file")
}

if err = stream.Send(uploadReq); err != nil {
return err
if moduleUploadReq != nil {
if err = moduleStream.Send(moduleUploadReq); err != nil {
return err
}
uploadedBytes += len(moduleUploadReq.GetFile())
}
uploadedBytes += len(uploadReq.GetFile())
if pkgUploadReq != nil {
if err = pkgStream.Send(pkgUploadReq); err != nil {
return err
}
uploadedBytes += len(pkgUploadReq.GetContents())
}

// Simple progress reading until we have a proper tui library
uploadPercent := int(math.Ceil(100 * float64(uploadedBytes) / float64(fileSize)))
fmt.Fprintf(stdout, "\rUploading... %d%% (%d/%d bytes)", uploadPercent, uploadedBytes, fileSize) // no newline
Expand Down
Loading

0 comments on commit 28ff656

Please sign in to comment.