Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DATA-3461 - Refactoring code for classifications #4764

Merged
merged 8 commits into from
Feb 10, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Move proto conversion to ML
tahiyasalam committed Jan 30, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 134bae68924a407eed786475445554eeab45d555
117 changes: 117 additions & 0 deletions ml/ml.go
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you split this file into two, one with the functions that just have to do with classification, and the other the generic Tensor to protobuf conversion functions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

of course! thank you for the first pass. I will update the PR description with all the manual testing

Original file line number Diff line number Diff line change
@@ -6,11 +6,15 @@ import (
"strconv"
"strings"
"sync"
"unsafe"

"github.com/montanaflynn/stats"
"github.com/pkg/errors"
pb "go.viam.com/api/service/mlmodel/v1"
"go.viam.com/rdk/vision/classification"
"golang.org/x/exp/constraints"

"gorgonia.org/tensor"
)

const classifierProbabilityName = "probability"
@@ -65,6 +69,119 @@ func FormatClassificationOutputs(
return classifications, nil
}

// ProtoToTensors takes pb.FlatTensors and turns it into a Tensors map.
func ProtoToTensors(pbft *pb.FlatTensors) (Tensors, error) {
if pbft == nil {
return nil, errors.New("protobuf FlatTensors is nil")
}
tensors := Tensors{}
for name, ftproto := range pbft.Tensors {
t, err := createNewTensor(ftproto)
if err != nil {
return nil, err
}
tensors[name] = t
}
return tensors, nil
}

// createNewTensor turns a proto FlatTensor into a *tensor.Dense.
func createNewTensor(pft *pb.FlatTensor) (*tensor.Dense, error) {
shape := make([]int, 0, len(pft.Shape))
for _, s := range pft.Shape {
shape = append(shape, int(s))
}
pt := pft.Tensor
switch t := pt.(type) {
case *pb.FlatTensor_Int8Tensor:
data := t.Int8Tensor
if data == nil {
return nil, errors.New("tensor of type Int8Tensor is nil")
}
dataSlice := data.GetData()
unsafeInt8Slice := *(*[]int8)(unsafe.Pointer(&dataSlice)) //nolint:gosec
int8Slice := make([]int8, 0, len(dataSlice))
int8Slice = append(int8Slice, unsafeInt8Slice...)
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int8Slice)), nil
case *pb.FlatTensor_Uint8Tensor:
data := t.Uint8Tensor
if data == nil {
return nil, errors.New("tensor of type Uint8Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Int16Tensor:
data := t.Int16Tensor
if data == nil {
return nil, errors.New("tensor of type Int16Tensor is nil")
}
int16Data := uint32ToInt16(data.GetData())
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int16Data)), nil
case *pb.FlatTensor_Uint16Tensor:
data := t.Uint16Tensor
if data == nil {
return nil, errors.New("tensor of type Uint16Tensor is nil")
}
uint16Data := uint32ToUint16(data.GetData())
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(uint16Data)), nil
case *pb.FlatTensor_Int32Tensor:
data := t.Int32Tensor
if data == nil {
return nil, errors.New("tensor of type Int32Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Uint32Tensor:
data := t.Uint32Tensor
if data == nil {
return nil, errors.New("tensor of type Uint32Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Int64Tensor:
data := t.Int64Tensor
if data == nil {
return nil, errors.New("tensor of type Int64Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Uint64Tensor:
data := t.Uint64Tensor
if data == nil {
return nil, errors.New("tensor of type Uint64Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_FloatTensor:
data := t.FloatTensor
if data == nil {
return nil, errors.New("tensor of type FloatTensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_DoubleTensor:
data := t.DoubleTensor
if data == nil {
return nil, errors.New("tensor of type DoubleTensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
default:
return nil, errors.Errorf("don't know how to create tensor.Dense from proto type %T", pt)
}
}

func uint32ToInt16(uint32Slice []uint32) []int16 {
int16Slice := make([]int16, len(uint32Slice))

for i, value := range uint32Slice {
int16Slice[i] = int16(value)
}
return int16Slice
}

func uint32ToUint16(uint32Slice []uint32) []uint16 {
uint16Slice := make([]uint16, len(uint32Slice))

for i, value := range uint32Slice {
uint16Slice[i] = uint16(value)
}
return uint16Slice
}

// number interface for converting between numbers.
type number interface {
constraints.Integer | constraints.Float
118 changes: 1 addition & 117 deletions services/mlmodel/client.go
Original file line number Diff line number Diff line change
@@ -2,12 +2,9 @@ package mlmodel

import (
"context"
"unsafe"

"github.com/pkg/errors"
pb "go.viam.com/api/service/mlmodel/v1"
"go.viam.com/utils/rpc"
"gorgonia.org/tensor"

"go.viam.com/rdk/logging"
"go.viam.com/rdk/ml"
@@ -59,126 +56,13 @@ func (c *client) Infer(ctx context.Context, tensors ml.Tensors) (ml.Tensors, err
if err != nil {
return nil, err
}
tensorResp, err := ProtoToTensors(resp.OutputTensors)
tensorResp, err := ml.ProtoToTensors(resp.OutputTensors)
if err != nil {
return nil, err
}
return tensorResp, nil
}

// ProtoToTensors takes pb.FlatTensors and turns it into a Tensors map.
func ProtoToTensors(pbft *pb.FlatTensors) (ml.Tensors, error) {
if pbft == nil {
return nil, errors.New("protobuf FlatTensors is nil")
}
tensors := ml.Tensors{}
for name, ftproto := range pbft.Tensors {
t, err := createNewTensor(ftproto)
if err != nil {
return nil, err
}
tensors[name] = t
}
return tensors, nil
}

// createNewTensor turns a proto FlatTensor into a *tensor.Dense.
func createNewTensor(pft *pb.FlatTensor) (*tensor.Dense, error) {
shape := make([]int, 0, len(pft.Shape))
for _, s := range pft.Shape {
shape = append(shape, int(s))
}
pt := pft.Tensor
switch t := pt.(type) {
case *pb.FlatTensor_Int8Tensor:
data := t.Int8Tensor
if data == nil {
return nil, errors.New("tensor of type Int8Tensor is nil")
}
dataSlice := data.GetData()
unsafeInt8Slice := *(*[]int8)(unsafe.Pointer(&dataSlice)) //nolint:gosec
int8Slice := make([]int8, 0, len(dataSlice))
int8Slice = append(int8Slice, unsafeInt8Slice...)
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int8Slice)), nil
case *pb.FlatTensor_Uint8Tensor:
data := t.Uint8Tensor
if data == nil {
return nil, errors.New("tensor of type Uint8Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Int16Tensor:
data := t.Int16Tensor
if data == nil {
return nil, errors.New("tensor of type Int16Tensor is nil")
}
int16Data := uint32ToInt16(data.GetData())
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int16Data)), nil
case *pb.FlatTensor_Uint16Tensor:
data := t.Uint16Tensor
if data == nil {
return nil, errors.New("tensor of type Uint16Tensor is nil")
}
uint16Data := uint32ToUint16(data.GetData())
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(uint16Data)), nil
case *pb.FlatTensor_Int32Tensor:
data := t.Int32Tensor
if data == nil {
return nil, errors.New("tensor of type Int32Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Uint32Tensor:
data := t.Uint32Tensor
if data == nil {
return nil, errors.New("tensor of type Uint32Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Int64Tensor:
data := t.Int64Tensor
if data == nil {
return nil, errors.New("tensor of type Int64Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Uint64Tensor:
data := t.Uint64Tensor
if data == nil {
return nil, errors.New("tensor of type Uint64Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_FloatTensor:
data := t.FloatTensor
if data == nil {
return nil, errors.New("tensor of type FloatTensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_DoubleTensor:
data := t.DoubleTensor
if data == nil {
return nil, errors.New("tensor of type DoubleTensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
default:
return nil, errors.Errorf("don't know how to create tensor.Dense from proto type %T", pt)
}
}

func uint32ToInt16(uint32Slice []uint32) []int16 {
int16Slice := make([]int16, len(uint32Slice))

for i, value := range uint32Slice {
int16Slice[i] = int16(value)
}
return int16Slice
}

func uint32ToUint16(uint32Slice []uint32) []uint16 {
uint16Slice := make([]uint16, len(uint32Slice))

for i, value := range uint32Slice {
uint16Slice[i] = uint16(value)
}
return uint16Slice
}

func (c *client) Metadata(ctx context.Context) (MLMetadata, error) {
resp, err := c.client.Metadata(ctx, &pb.MetadataRequest{
Name: c.name,
2 changes: 1 addition & 1 deletion services/mlmodel/server.go
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@ func (server *serviceServer) Infer(ctx context.Context, req *pb.InferRequest) (*

var it ml.Tensors
if req.InputTensors != nil {
it, err = ProtoToTensors(req.InputTensors)
it, err = ml.ProtoToTensors(req.InputTensors)
if err != nil {
return nil, err
}