From 134bae68924a407eed786475445554eeab45d555 Mon Sep 17 00:00:00 2001 From: Tahiya Salam Date: Thu, 30 Jan 2025 15:03:41 -0500 Subject: [PATCH] Move proto conversion to ML --- ml/ml.go | 117 ++++++++++++++++++++++++++++++++++++ services/mlmodel/client.go | 118 +------------------------------------ services/mlmodel/server.go | 2 +- 3 files changed, 119 insertions(+), 118 deletions(-) diff --git a/ml/ml.go b/ml/ml.go index c7cf229483f..9fd30b0f96b 100644 --- a/ml/ml.go +++ b/ml/ml.go @@ -6,11 +6,15 @@ import ( "strconv" "strings" "sync" + "unsafe" "github.com/montanaflynn/stats" "github.com/pkg/errors" + pb "go.viam.com/api/service/mlmodel/v1" "go.viam.com/rdk/vision/classification" "golang.org/x/exp/constraints" + + "gorgonia.org/tensor" ) const classifierProbabilityName = "probability" @@ -65,6 +69,119 @@ func FormatClassificationOutputs( return classifications, nil } +// ProtoToTensors takes pb.FlatTensors and turns it into a Tensors map. +func ProtoToTensors(pbft *pb.FlatTensors) (Tensors, error) { + if pbft == nil { + return nil, errors.New("protobuf FlatTensors is nil") + } + tensors := Tensors{} + for name, ftproto := range pbft.Tensors { + t, err := createNewTensor(ftproto) + if err != nil { + return nil, err + } + tensors[name] = t + } + return tensors, nil +} + +// createNewTensor turns a proto FlatTensor into a *tensor.Dense. +func createNewTensor(pft *pb.FlatTensor) (*tensor.Dense, error) { + shape := make([]int, 0, len(pft.Shape)) + for _, s := range pft.Shape { + shape = append(shape, int(s)) + } + pt := pft.Tensor + switch t := pt.(type) { + case *pb.FlatTensor_Int8Tensor: + data := t.Int8Tensor + if data == nil { + return nil, errors.New("tensor of type Int8Tensor is nil") + } + dataSlice := data.GetData() + unsafeInt8Slice := *(*[]int8)(unsafe.Pointer(&dataSlice)) //nolint:gosec + int8Slice := make([]int8, 0, len(dataSlice)) + int8Slice = append(int8Slice, unsafeInt8Slice...) + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int8Slice)), nil + case *pb.FlatTensor_Uint8Tensor: + data := t.Uint8Tensor + if data == nil { + return nil, errors.New("tensor of type Uint8Tensor is nil") + } + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil + case *pb.FlatTensor_Int16Tensor: + data := t.Int16Tensor + if data == nil { + return nil, errors.New("tensor of type Int16Tensor is nil") + } + int16Data := uint32ToInt16(data.GetData()) + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int16Data)), nil + case *pb.FlatTensor_Uint16Tensor: + data := t.Uint16Tensor + if data == nil { + return nil, errors.New("tensor of type Uint16Tensor is nil") + } + uint16Data := uint32ToUint16(data.GetData()) + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(uint16Data)), nil + case *pb.FlatTensor_Int32Tensor: + data := t.Int32Tensor + if data == nil { + return nil, errors.New("tensor of type Int32Tensor is nil") + } + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil + case *pb.FlatTensor_Uint32Tensor: + data := t.Uint32Tensor + if data == nil { + return nil, errors.New("tensor of type Uint32Tensor is nil") + } + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil + case *pb.FlatTensor_Int64Tensor: + data := t.Int64Tensor + if data == nil { + return nil, errors.New("tensor of type Int64Tensor is nil") + } + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil + case *pb.FlatTensor_Uint64Tensor: + data := t.Uint64Tensor + if data == nil { + return nil, errors.New("tensor of type Uint64Tensor is nil") + } + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil + case *pb.FlatTensor_FloatTensor: + data := t.FloatTensor + if data == nil { + return nil, errors.New("tensor of type FloatTensor is nil") + } + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil + case *pb.FlatTensor_DoubleTensor: + data := t.DoubleTensor + if data == nil { + return nil, errors.New("tensor of type DoubleTensor is nil") + } + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil + default: + return nil, errors.Errorf("don't know how to create tensor.Dense from proto type %T", pt) + } +} + +func uint32ToInt16(uint32Slice []uint32) []int16 { + int16Slice := make([]int16, len(uint32Slice)) + + for i, value := range uint32Slice { + int16Slice[i] = int16(value) + } + return int16Slice +} + +func uint32ToUint16(uint32Slice []uint32) []uint16 { + uint16Slice := make([]uint16, len(uint32Slice)) + + for i, value := range uint32Slice { + uint16Slice[i] = uint16(value) + } + return uint16Slice +} + // number interface for converting between numbers. type number interface { constraints.Integer | constraints.Float diff --git a/services/mlmodel/client.go b/services/mlmodel/client.go index 6b112a38658..e5576addca8 100644 --- a/services/mlmodel/client.go +++ b/services/mlmodel/client.go @@ -2,12 +2,9 @@ package mlmodel import ( "context" - "unsafe" - "github.com/pkg/errors" pb "go.viam.com/api/service/mlmodel/v1" "go.viam.com/utils/rpc" - "gorgonia.org/tensor" "go.viam.com/rdk/logging" "go.viam.com/rdk/ml" @@ -59,126 +56,13 @@ func (c *client) Infer(ctx context.Context, tensors ml.Tensors) (ml.Tensors, err if err != nil { return nil, err } - tensorResp, err := ProtoToTensors(resp.OutputTensors) + tensorResp, err := ml.ProtoToTensors(resp.OutputTensors) if err != nil { return nil, err } return tensorResp, nil } -// ProtoToTensors takes pb.FlatTensors and turns it into a Tensors map. -func ProtoToTensors(pbft *pb.FlatTensors) (ml.Tensors, error) { - if pbft == nil { - return nil, errors.New("protobuf FlatTensors is nil") - } - tensors := ml.Tensors{} - for name, ftproto := range pbft.Tensors { - t, err := createNewTensor(ftproto) - if err != nil { - return nil, err - } - tensors[name] = t - } - return tensors, nil -} - -// createNewTensor turns a proto FlatTensor into a *tensor.Dense. -func createNewTensor(pft *pb.FlatTensor) (*tensor.Dense, error) { - shape := make([]int, 0, len(pft.Shape)) - for _, s := range pft.Shape { - shape = append(shape, int(s)) - } - pt := pft.Tensor - switch t := pt.(type) { - case *pb.FlatTensor_Int8Tensor: - data := t.Int8Tensor - if data == nil { - return nil, errors.New("tensor of type Int8Tensor is nil") - } - dataSlice := data.GetData() - unsafeInt8Slice := *(*[]int8)(unsafe.Pointer(&dataSlice)) //nolint:gosec - int8Slice := make([]int8, 0, len(dataSlice)) - int8Slice = append(int8Slice, unsafeInt8Slice...) - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int8Slice)), nil - case *pb.FlatTensor_Uint8Tensor: - data := t.Uint8Tensor - if data == nil { - return nil, errors.New("tensor of type Uint8Tensor is nil") - } - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil - case *pb.FlatTensor_Int16Tensor: - data := t.Int16Tensor - if data == nil { - return nil, errors.New("tensor of type Int16Tensor is nil") - } - int16Data := uint32ToInt16(data.GetData()) - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int16Data)), nil - case *pb.FlatTensor_Uint16Tensor: - data := t.Uint16Tensor - if data == nil { - return nil, errors.New("tensor of type Uint16Tensor is nil") - } - uint16Data := uint32ToUint16(data.GetData()) - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(uint16Data)), nil - case *pb.FlatTensor_Int32Tensor: - data := t.Int32Tensor - if data == nil { - return nil, errors.New("tensor of type Int32Tensor is nil") - } - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil - case *pb.FlatTensor_Uint32Tensor: - data := t.Uint32Tensor - if data == nil { - return nil, errors.New("tensor of type Uint32Tensor is nil") - } - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil - case *pb.FlatTensor_Int64Tensor: - data := t.Int64Tensor - if data == nil { - return nil, errors.New("tensor of type Int64Tensor is nil") - } - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil - case *pb.FlatTensor_Uint64Tensor: - data := t.Uint64Tensor - if data == nil { - return nil, errors.New("tensor of type Uint64Tensor is nil") - } - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil - case *pb.FlatTensor_FloatTensor: - data := t.FloatTensor - if data == nil { - return nil, errors.New("tensor of type FloatTensor is nil") - } - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil - case *pb.FlatTensor_DoubleTensor: - data := t.DoubleTensor - if data == nil { - return nil, errors.New("tensor of type DoubleTensor is nil") - } - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil - default: - return nil, errors.Errorf("don't know how to create tensor.Dense from proto type %T", pt) - } -} - -func uint32ToInt16(uint32Slice []uint32) []int16 { - int16Slice := make([]int16, len(uint32Slice)) - - for i, value := range uint32Slice { - int16Slice[i] = int16(value) - } - return int16Slice -} - -func uint32ToUint16(uint32Slice []uint32) []uint16 { - uint16Slice := make([]uint16, len(uint32Slice)) - - for i, value := range uint32Slice { - uint16Slice[i] = uint16(value) - } - return uint16Slice -} - func (c *client) Metadata(ctx context.Context) (MLMetadata, error) { resp, err := c.client.Metadata(ctx, &pb.MetadataRequest{ Name: c.name, diff --git a/services/mlmodel/server.go b/services/mlmodel/server.go index 30df1ebce43..c53f74813dc 100644 --- a/services/mlmodel/server.go +++ b/services/mlmodel/server.go @@ -29,7 +29,7 @@ func (server *serviceServer) Infer(ctx context.Context, req *pb.InferRequest) (* var it ml.Tensors if req.InputTensors != nil { - it, err = ProtoToTensors(req.InputTensors) + it, err = ml.ProtoToTensors(req.InputTensors) if err != nil { return nil, err }