Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DATA-3467: Cloud Inference CLI #4748

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions cli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -2748,6 +2748,52 @@ This won't work unless you have an existing installation of our GitHub app on yo
},
},
},
{
Name: "inference",
Usage: "work with cloud hosted inference service",
UsageText: createUsageText("inference", nil, false, true),
HideHelpCommand: true,
Subcommands: []*cli.Command{
{
Name: "infer",
Usage: "run inference on an image",
UsageText: createUsageText("inference infer", []string{generalFlagOrgID, inferenceFlagFileOrgID, inferenceFlagFileID, inferenceFlagFileLocationID, inferenceFlagModelID, inferenceFlagModelVersionID}, true, false),
Flags: []cli.Flag{
&cli.StringFlag{
Name: generalFlagOrgID,
Usage: "organization ID that is executing the inference job",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagFileOrgID,
Usage: "organization ID that owns the file to run inference on",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagFileID,
Usage: "file ID of the file to run inference on",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagFileLocationID,
Usage: "location ID of the file to run inference on",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagModelID,
Usage: "ID of the model to use to run inference",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagModelVersionID,
Usage: "version ID of the model to use to run inference",
Required: true,
},
},
Action: createCommandWithT[mlInferenceInferArgs](MLInferenceInferAction),
},
},
},
{
Name: "version",
Usage: "print version info for this program",
Expand Down
2 changes: 2 additions & 0 deletions cli/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
buildpb "go.viam.com/api/app/build/v1"
datapb "go.viam.com/api/app/data/v1"
datasetpb "go.viam.com/api/app/dataset/v1"
mlinferencepb "go.viam.com/api/app/mlinference/v1"
mltrainingpb "go.viam.com/api/app/mltraining/v1"
packagepb "go.viam.com/api/app/packages/v1"
apppb "go.viam.com/api/app/v1"
Expand Down Expand Up @@ -544,6 +545,7 @@ func (c *viamClient) ensureLoggedInInner() error {
c.packageClient = packagepb.NewPackageServiceClient(conn)
c.datasetClient = datasetpb.NewDatasetServiceClient(conn)
c.mlTrainingClient = mltrainingpb.NewMLTrainingServiceClient(conn)
c.mlInferenceClient = mlinferencepb.NewMLInferenceServiceClient(conn)
c.buildClient = buildpb.NewBuildServiceClient(conn)

return nil
Expand Down
22 changes: 12 additions & 10 deletions cli/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
buildpb "go.viam.com/api/app/build/v1"
datapb "go.viam.com/api/app/data/v1"
datasetpb "go.viam.com/api/app/dataset/v1"
mlinferencepb "go.viam.com/api/app/mlinference/v1"
mltrainingpb "go.viam.com/api/app/mltraining/v1"
packagepb "go.viam.com/api/app/packages/v1"
apppb "go.viam.com/api/app/v1"
Expand Down Expand Up @@ -69,16 +70,17 @@ var errNoShellService = errors.New("shell service is not enabled on this machine
// viamClient wraps a cli.Context and provides all the CLI command functionality
// needed to talk to the app and data services but not directly to robot parts.
type viamClient struct {
c *cli.Context
conf *Config
client apppb.AppServiceClient
dataClient datapb.DataServiceClient
packageClient packagepb.PackageServiceClient
datasetClient datasetpb.DatasetServiceClient
mlTrainingClient mltrainingpb.MLTrainingServiceClient
buildClient buildpb.BuildServiceClient
baseURL *url.URL
authFlow *authFlow
c *cli.Context
conf *Config
client apppb.AppServiceClient
dataClient datapb.DataServiceClient
packageClient packagepb.PackageServiceClient
datasetClient datasetpb.DatasetServiceClient
mlTrainingClient mltrainingpb.MLTrainingServiceClient
mlInferenceClient mlinferencepb.MLInferenceServiceClient
buildClient buildpb.BuildServiceClient
baseURL *url.URL
authFlow *authFlow

selectedOrg *apppb.Organization
selectedLoc *apppb.Location
Expand Down
114 changes: 114 additions & 0 deletions cli/ml_inference.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package cli

import (
"context"
"fmt"

"github.com/pkg/errors"
"github.com/urfave/cli/v2"
v1 "go.viam.com/api/app/data/v1"
mlinferencepb "go.viam.com/api/app/mlinference/v1"
)

// Flag names used by the "inference infer" subcommand. generalFlagOrgID
// (declared elsewhere in the package) supplies the executing org; the
// flags below identify the input file and the registry model to run.
const (
	inferenceFlagFileOrgID      = "file-org-id"      // org that owns the input file
	inferenceFlagFileID         = "file-id"          // ID of the input file
	inferenceFlagFileLocationID = "file-location-id" // location of the input file
	inferenceFlagModelID        = "model-id"         // registry model to run
	inferenceFlagModelVersionID = "model-version"    // version of that model
)

// mlInferenceInferArgs holds the parsed flag values for 'inference infer'.
// Field names are bound to the CLI flags by createCommandWithT, so they must
// stay in sync with the flag-name constants above.
type mlInferenceInferArgs struct {
	OrgID          string // organization executing the inference job
	FileOrgID      string // organization that owns the file
	FileID         string // file to run inference on
	FileLocationID string // location of the file
	ModelID        string // registry model ID
	ModelVersion   string // registry model version
}

// MLInferenceInferAction is the corresponding action for 'inference infer'.
// It builds a viam client from the CLI context and forwards every flag value
// to mlRunInference, propagating any error to the CLI framework.
func MLInferenceInferAction(c *cli.Context, args mlInferenceInferArgs) error {
	client, err := newViamClient(c)
	if err != nil {
		return err
	}

	// The response is printed by mlRunInference itself; only the error
	// matters to the CLI action.
	_, err = client.mlRunInference(
		args.OrgID, args.FileOrgID, args.FileID, args.FileLocationID,
		args.ModelID, args.ModelVersion)
	return err
}

// mlRunInference runs inference on a binary data file (an image) with the
// specified registry model, prints the response to stdout, and returns it.
//
// orgID is the organization executing the job; fileOrgID, fileID, and
// fileLocation identify the binary data to run inference on; modelID and
// modelVersion select the registry model. Requires a logged-in session.
func (c *viamClient) mlRunInference(orgID, fileOrgID, fileID, fileLocation, modelID, modelVersion string) (*mlinferencepb.GetInferenceResponse, error) {
	if err := c.ensureLoggedIn(); err != nil {
		return nil, err
	}

	req := &mlinferencepb.GetInferenceRequest{
		OrganizationId: orgID,
		BinaryId: &v1.BinaryID{
			FileId:         fileID,
			OrganizationId: fileOrgID,
			LocationId:     fileLocation,
		},
		RegistryItemId:      modelID,
		RegistryItemVersion: modelVersion,
	}

	resp, err := c.mlInferenceClient.GetInference(context.Background(), req)
	if err != nil {
		// errors.Wrap rather than Wrapf: the message contains no format verbs.
		return nil, errors.Wrap(err, "received error from server")
	}
	printInferenceResponse(resp)
	return resp, nil
}

// printInferenceResponse prints a neat representation of the
// GetInferenceResponse: each output tensor's name, shape, and values,
// followed by any bounding-box and classification annotations.
//
// It uses the generated nil-safe proto getters throughout, so a nil resp or
// nil sub-message prints the "No ..." fallbacks instead of panicking.
func printInferenceResponse(resp *mlinferencepb.GetInferenceResponse) {
	fmt.Println("Inference Response:")
	fmt.Println("Output Tensors:")
	if resp.GetOutputTensors() != nil {
		for name, tensor := range resp.GetOutputTensors().GetTensors() {
			fmt.Printf(" Tensor Name: %s\n", name)
			fmt.Printf(" Shape: %v\n", tensor.GetShape())
			if tensor.GetTensor() != nil {
				// NOTE(review): only double tensors are rendered; a tensor of
				// any other data type prints an empty list — confirm whether
				// the service can return non-double tensors.
				fmt.Print(" Values: [")
				for i, value := range tensor.GetDoubleTensor().GetData() {
					if i > 0 {
						fmt.Print(", ")
					}
					fmt.Printf("%.4f", value)
				}
				fmt.Println("]")
			} else {
				fmt.Println(" No values available.")
			}
		}
	} else {
		fmt.Println(" No output tensors.")
	}

	fmt.Println("Annotations:")
	if resp.GetAnnotations() != nil {
		for _, bbox := range resp.GetAnnotations().GetBboxes() {
			fmt.Printf(" Bounding Box ID: %s, Label: %s\n", bbox.GetId(), bbox.GetLabel())
			fmt.Printf(" Coordinates: [%f, %f, %f, %f]\n", bbox.GetXMinNormalized(), bbox.GetYMinNormalized(), bbox.GetXMaxNormalized(), bbox.GetYMaxNormalized())
			// Confidence is an optional field; the pointer check preserves the
			// presence test (the getter alone cannot distinguish unset from 0).
			if bbox.Confidence != nil {
				fmt.Printf(" Confidence: %.4f\n", *bbox.Confidence)
			}
		}
		for _, classification := range resp.GetAnnotations().GetClassifications() {
			fmt.Printf(" Classification Label: %s\n", classification.GetLabel())
			if classification.Confidence != nil {
				fmt.Printf(" Confidence: %.4f\n", *classification.Confidence)
			}
		}
	} else {
		fmt.Println(" No annotations.")
	}
}
Loading