Skip to content

Commit

Permalink
Move proto conversion to ML
Browse files Browse the repository at this point in the history
  • Loading branch information
tahiyasalam committed Jan 30, 2025
1 parent 3bead35 commit 134bae6
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 118 deletions.
117 changes: 117 additions & 0 deletions ml/ml.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@ import (
"strconv"
"strings"
"sync"
"unsafe"

"github.com/montanaflynn/stats"
"github.com/pkg/errors"
pb "go.viam.com/api/service/mlmodel/v1"
"go.viam.com/rdk/vision/classification"
"golang.org/x/exp/constraints"

"gorgonia.org/tensor"
)

const classifierProbabilityName = "probability"
Expand Down Expand Up @@ -65,6 +69,119 @@ func FormatClassificationOutputs(
return classifications, nil
}

// ProtoToTensors takes pb.FlatTensors and turns it into a Tensors map.
func ProtoToTensors(pbft *pb.FlatTensors) (Tensors, error) {
if pbft == nil {
return nil, errors.New("protobuf FlatTensors is nil")
}
tensors := Tensors{}
for name, ftproto := range pbft.Tensors {
t, err := createNewTensor(ftproto)
if err != nil {
return nil, err
}
tensors[name] = t
}
return tensors, nil
}

// createNewTensor turns a proto FlatTensor into a *tensor.Dense.
func createNewTensor(pft *pb.FlatTensor) (*tensor.Dense, error) {
shape := make([]int, 0, len(pft.Shape))
for _, s := range pft.Shape {
shape = append(shape, int(s))
}
pt := pft.Tensor
switch t := pt.(type) {
case *pb.FlatTensor_Int8Tensor:
data := t.Int8Tensor
if data == nil {
return nil, errors.New("tensor of type Int8Tensor is nil")
}
dataSlice := data.GetData()
unsafeInt8Slice := *(*[]int8)(unsafe.Pointer(&dataSlice)) //nolint:gosec
int8Slice := make([]int8, 0, len(dataSlice))
int8Slice = append(int8Slice, unsafeInt8Slice...)
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int8Slice)), nil
case *pb.FlatTensor_Uint8Tensor:
data := t.Uint8Tensor
if data == nil {
return nil, errors.New("tensor of type Uint8Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Int16Tensor:
data := t.Int16Tensor
if data == nil {
return nil, errors.New("tensor of type Int16Tensor is nil")
}
int16Data := uint32ToInt16(data.GetData())
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int16Data)), nil
case *pb.FlatTensor_Uint16Tensor:
data := t.Uint16Tensor
if data == nil {
return nil, errors.New("tensor of type Uint16Tensor is nil")
}
uint16Data := uint32ToUint16(data.GetData())
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(uint16Data)), nil
case *pb.FlatTensor_Int32Tensor:
data := t.Int32Tensor
if data == nil {
return nil, errors.New("tensor of type Int32Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Uint32Tensor:
data := t.Uint32Tensor
if data == nil {
return nil, errors.New("tensor of type Uint32Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Int64Tensor:
data := t.Int64Tensor
if data == nil {
return nil, errors.New("tensor of type Int64Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Uint64Tensor:
data := t.Uint64Tensor
if data == nil {
return nil, errors.New("tensor of type Uint64Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_FloatTensor:
data := t.FloatTensor
if data == nil {
return nil, errors.New("tensor of type FloatTensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_DoubleTensor:
data := t.DoubleTensor
if data == nil {
return nil, errors.New("tensor of type DoubleTensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
default:
return nil, errors.Errorf("don't know how to create tensor.Dense from proto type %T", pt)
}
}

func uint32ToInt16(uint32Slice []uint32) []int16 {
int16Slice := make([]int16, len(uint32Slice))

for i, value := range uint32Slice {
int16Slice[i] = int16(value)
}
return int16Slice
}

func uint32ToUint16(uint32Slice []uint32) []uint16 {
uint16Slice := make([]uint16, len(uint32Slice))

for i, value := range uint32Slice {
uint16Slice[i] = uint16(value)
}
return uint16Slice
}

// number interface for converting between numbers.
type number interface {
constraints.Integer | constraints.Float
Expand Down
118 changes: 1 addition & 117 deletions services/mlmodel/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@ package mlmodel

import (
"context"
"unsafe"

"github.com/pkg/errors"
pb "go.viam.com/api/service/mlmodel/v1"
"go.viam.com/utils/rpc"
"gorgonia.org/tensor"

"go.viam.com/rdk/logging"
"go.viam.com/rdk/ml"
Expand Down Expand Up @@ -59,126 +56,13 @@ func (c *client) Infer(ctx context.Context, tensors ml.Tensors) (ml.Tensors, err
if err != nil {
return nil, err
}
tensorResp, err := ProtoToTensors(resp.OutputTensors)
tensorResp, err := ml.ProtoToTensors(resp.OutputTensors)
if err != nil {
return nil, err
}
return tensorResp, nil
}

// ProtoToTensors takes pb.FlatTensors and turns it into a Tensors map.
func ProtoToTensors(pbft *pb.FlatTensors) (ml.Tensors, error) {
if pbft == nil {
return nil, errors.New("protobuf FlatTensors is nil")
}
tensors := ml.Tensors{}
for name, ftproto := range pbft.Tensors {
t, err := createNewTensor(ftproto)
if err != nil {
return nil, err
}
tensors[name] = t
}
return tensors, nil
}

// createNewTensor turns a proto FlatTensor into a *tensor.Dense.
func createNewTensor(pft *pb.FlatTensor) (*tensor.Dense, error) {
shape := make([]int, 0, len(pft.Shape))
for _, s := range pft.Shape {
shape = append(shape, int(s))
}
pt := pft.Tensor
switch t := pt.(type) {
case *pb.FlatTensor_Int8Tensor:
data := t.Int8Tensor
if data == nil {
return nil, errors.New("tensor of type Int8Tensor is nil")
}
dataSlice := data.GetData()
unsafeInt8Slice := *(*[]int8)(unsafe.Pointer(&dataSlice)) //nolint:gosec
int8Slice := make([]int8, 0, len(dataSlice))
int8Slice = append(int8Slice, unsafeInt8Slice...)
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int8Slice)), nil
case *pb.FlatTensor_Uint8Tensor:
data := t.Uint8Tensor
if data == nil {
return nil, errors.New("tensor of type Uint8Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Int16Tensor:
data := t.Int16Tensor
if data == nil {
return nil, errors.New("tensor of type Int16Tensor is nil")
}
int16Data := uint32ToInt16(data.GetData())
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int16Data)), nil
case *pb.FlatTensor_Uint16Tensor:
data := t.Uint16Tensor
if data == nil {
return nil, errors.New("tensor of type Uint16Tensor is nil")
}
uint16Data := uint32ToUint16(data.GetData())
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(uint16Data)), nil
case *pb.FlatTensor_Int32Tensor:
data := t.Int32Tensor
if data == nil {
return nil, errors.New("tensor of type Int32Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Uint32Tensor:
data := t.Uint32Tensor
if data == nil {
return nil, errors.New("tensor of type Uint32Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Int64Tensor:
data := t.Int64Tensor
if data == nil {
return nil, errors.New("tensor of type Int64Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_Uint64Tensor:
data := t.Uint64Tensor
if data == nil {
return nil, errors.New("tensor of type Uint64Tensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_FloatTensor:
data := t.FloatTensor
if data == nil {
return nil, errors.New("tensor of type FloatTensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
case *pb.FlatTensor_DoubleTensor:
data := t.DoubleTensor
if data == nil {
return nil, errors.New("tensor of type DoubleTensor is nil")
}
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil
default:
return nil, errors.Errorf("don't know how to create tensor.Dense from proto type %T", pt)
}
}

func uint32ToInt16(uint32Slice []uint32) []int16 {
int16Slice := make([]int16, len(uint32Slice))

for i, value := range uint32Slice {
int16Slice[i] = int16(value)
}
return int16Slice
}

func uint32ToUint16(uint32Slice []uint32) []uint16 {
uint16Slice := make([]uint16, len(uint32Slice))

for i, value := range uint32Slice {
uint16Slice[i] = uint16(value)
}
return uint16Slice
}

func (c *client) Metadata(ctx context.Context) (MLMetadata, error) {
resp, err := c.client.Metadata(ctx, &pb.MetadataRequest{
Name: c.name,
Expand Down
2 changes: 1 addition & 1 deletion services/mlmodel/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (server *serviceServer) Infer(ctx context.Context, req *pb.InferRequest) (*

var it ml.Tensors
if req.InputTensors != nil {
it, err = ProtoToTensors(req.InputTensors)
it, err = ml.ProtoToTensors(req.InputTensors)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 134bae6

Please sign in to comment.