diff --git a/cli/app.go b/cli/app.go index aa4590f1dbc..1ab596dc25b 100644 --- a/cli/app.go +++ b/cli/app.go @@ -2748,6 +2748,52 @@ This won't work unless you have an existing installation of our GitHub app on yo }, }, }, + { + Name: "infer", + Usage: "run cloud hosted inference on an image", + UsageText: createUsageText("inference infer", []string{ + generalFlagOrgID, inferenceFlagFileOrgID, inferenceFlagFileID, + inferenceFlagFileLocationID, inferenceFlagModelOrgID, inferenceFlagModelName, inferenceFlagModelVersion, + }, true, false), + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: generalFlagOrgID, + Usage: "organization ID that is executing the inference job", + Required: true, + }, + &cli.StringFlag{ + Name: inferenceFlagFileOrgID, + Usage: "organization ID that owns the file to run inference on", + Required: true, + }, + &cli.StringFlag{ + Name: inferenceFlagFileID, + Usage: "file ID of the file to run inference on", + Required: true, + }, + &cli.StringFlag{ + Name: inferenceFlagFileLocationID, + Usage: "location ID of the file to run inference on", + Required: true, + }, + &cli.StringFlag{ + Name: inferenceFlagModelOrgID, + Usage: "organization ID that hosts the model to use to run inference", + Required: true, + }, + &cli.StringFlag{ + Name: inferenceFlagModelName, + Usage: "name of the model to use to run inference", + Required: true, + }, + &cli.StringFlag{ + Name: inferenceFlagModelVersion, + Usage: "version of the model to use to run inference", + Required: true, + }, + }, + Action: createCommandWithT[mlInferenceInferArgs](MLInferenceInferAction), + }, { Name: "version", Usage: "print version info for this program", diff --git a/cli/auth.go b/cli/auth.go index c35c04d7bfe..0119598ff44 100644 --- a/cli/auth.go +++ b/cli/auth.go @@ -21,6 +21,7 @@ import ( buildpb "go.viam.com/api/app/build/v1" datapb "go.viam.com/api/app/data/v1" datasetpb "go.viam.com/api/app/dataset/v1" + mlinferencepb "go.viam.com/api/app/mlinference/v1" mltrainingpb 
"go.viam.com/api/app/mltraining/v1" packagepb "go.viam.com/api/app/packages/v1" apppb "go.viam.com/api/app/v1" @@ -544,6 +545,7 @@ func (c *viamClient) ensureLoggedInInner() error { c.packageClient = packagepb.NewPackageServiceClient(conn) c.datasetClient = datasetpb.NewDatasetServiceClient(conn) c.mlTrainingClient = mltrainingpb.NewMLTrainingServiceClient(conn) + c.mlInferenceClient = mlinferencepb.NewMLInferenceServiceClient(conn) c.buildClient = buildpb.NewBuildServiceClient(conn) return nil diff --git a/cli/client.go b/cli/client.go index aae57c46267..8624e2ce3a4 100644 --- a/cli/client.go +++ b/cli/client.go @@ -30,6 +30,7 @@ import ( buildpb "go.viam.com/api/app/build/v1" datapb "go.viam.com/api/app/data/v1" datasetpb "go.viam.com/api/app/dataset/v1" + mlinferencepb "go.viam.com/api/app/mlinference/v1" mltrainingpb "go.viam.com/api/app/mltraining/v1" packagepb "go.viam.com/api/app/packages/v1" apppb "go.viam.com/api/app/v1" @@ -69,16 +70,17 @@ var errNoShellService = errors.New("shell service is not enabled on this machine // viamClient wraps a cli.Context and provides all the CLI command functionality // needed to talk to the app and data services but not directly to robot parts. 
// Flag names accepted by the 'inference infer' CLI command.
const (
	inferenceFlagFileOrgID      = "file-org-id"
	inferenceFlagFileID         = "file-id"
	inferenceFlagFileLocationID = "file-location-id"
	inferenceFlagModelOrgID     = "model-org-id"
	inferenceFlagModelName      = "model-name"
	inferenceFlagModelVersion   = "model-version"
)

// mlInferenceInferArgs holds the parsed flag values for 'inference infer'.
type mlInferenceInferArgs struct {
	OrgID          string // organization executing the inference job
	FileOrgID      string // organization that owns the input file
	FileID         string // ID of the file to run inference on
	FileLocationID string // location ID of the input file
	ModelOrgID     string // organization that hosts the model
	ModelName      string // registry name of the model
	ModelVersion   string // version of the model to use
}
+func MLInferenceInferAction(c *cli.Context, args mlInferenceInferArgs) error { + client, err := newViamClient(c) + if err != nil { + return err + } + + _, err = client.mlRunInference( + args.OrgID, args.FileOrgID, args.FileID, args.FileLocationID, + args.ModelOrgID, args.ModelName, args.ModelVersion) + if err != nil { + return err + } + return nil +} + +// mlRunInference runs inference on an image with the specified parameters. +func (c *viamClient) mlRunInference(orgID, fileOrgID, fileID, fileLocation, modelOrgID, + modelName, modelVersion string, +) (*mlinferencepb.GetInferenceResponse, error) { + if err := c.ensureLoggedIn(); err != nil { + return nil, err + } + + req := &mlinferencepb.GetInferenceRequest{ + OrganizationId: orgID, + BinaryId: &v1.BinaryID{ + FileId: fileID, + OrganizationId: fileOrgID, + LocationId: fileLocation, + }, + RegistryItemId: fmt.Sprintf("%s:%s", modelOrgID, modelName), + RegistryItemVersion: modelVersion, + } + + resp, err := c.mlInferenceClient.GetInference(context.Background(), req) + if err != nil { + return nil, errors.Wrapf(err, "received error from server") + } + c.printInferenceResponse(resp) + return resp, nil +} + +// printInferenceResponse prints a neat representation of the GetInferenceResponse. 
+func (c *viamClient) printInferenceResponse(resp *mlinferencepb.GetInferenceResponse) { + printf(c.c.App.Writer, "Inference Response:") + printf(c.c.App.Writer, "Output Tensors:") + if resp.OutputTensors != nil { + for name, tensor := range resp.OutputTensors.Tensors { + printf(c.c.App.Writer, " Tensor Name: %s", name) + printf(c.c.App.Writer, " Shape: %v", tensor.Shape) + if tensor.Tensor != nil { + var sb strings.Builder + for i, value := range tensor.GetDoubleTensor().GetData() { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(fmt.Sprintf("%.4f", value)) + } + printf(c.c.App.Writer, " Values: [%s]", sb.String()) + } else { + printf(c.c.App.Writer, " No values available.") + } + } + } else { + printf(c.c.App.Writer, " No output tensors.") + } + + printf(c.c.App.Writer, "Annotations:") + printf(c.c.App.Writer, "Bounding Box Format: [x_min, y_min, x_max, y_max]") + if resp.Annotations != nil { + for _, bbox := range resp.Annotations.Bboxes { + printf(c.c.App.Writer, " Bounding Box ID: %s, Label: %s", + bbox.Id, bbox.Label) + printf(c.c.App.Writer, " Coordinates: [%f, %f, %f, %f]", + bbox.XMinNormalized, bbox.YMinNormalized, bbox.XMaxNormalized, bbox.YMaxNormalized) + if bbox.Confidence != nil { + printf(c.c.App.Writer, " Confidence: %.4f", *bbox.Confidence) + } + } + for _, classification := range resp.Annotations.Classifications { + printf(c.c.App.Writer, " Classification Label: %s", classification.Label) + if classification.Confidence != nil { + printf(c.c.App.Writer, " Confidence: %.4f", *classification.Confidence) + } + } + } else { + printf(c.c.App.Writer, " No annotations.") + } +}