diff --git a/ml/api.go b/ml/api.go
index 031f0129588..a34d34b10a9 100644
--- a/ml/api.go
+++ b/ml/api.go
@@ -1,9 +1,10 @@
 // Package ml provides some fundamental machine learning primitives.
 package ml
 
-import "gorgonia.org/tensor"
+import (
+	"gorgonia.org/tensor"
+)
 
 // Tensors are a data structure to hold the input and output map of tensors that will fed into a
 // model, or come from the result of a model.
 type Tensors map[string]*tensor.Dense
-
diff --git a/ml/classifications.go b/ml/classifications.go
new file mode 100644
index 00000000000..7210a14a02d
--- /dev/null
+++ b/ml/classifications.go
@@ -0,0 +1,88 @@
+package ml
+
+import (
+	"strconv"
+	"strings"
+	"sync"
+
+	"github.com/montanaflynn/stats"
+	"github.com/pkg/errors"
+
+	"go.viam.com/rdk/vision/classification"
+)
+
+const classifierProbabilityName = "probability"
+
+// FormatClassificationOutputs formats the output tensors from a model into classifications.
+func FormatClassificationOutputs(
+	outNameMap *sync.Map, outMap Tensors, labels []string,
+) (classification.Classifications, error) {
+	// check if the output tensor name that the classifier is looking for is already present
+	// in the nameMap. If not, find the probability name, and cache it in the nameMap.
+	pName, ok := outNameMap.Load(classifierProbabilityName)
+	if !ok {
+		_, ok := outMap[classifierProbabilityName]
+		if !ok {
+			if len(outMap) == 1 {
+				for name := range outMap { // only 1 element in map, assume it's the probabilities
+					outNameMap.Store(classifierProbabilityName, name)
+					pName = name
+				}
+			}
+		} else {
+			outNameMap.Store(classifierProbabilityName, classifierProbabilityName)
+			pName = classifierProbabilityName
+		}
+	}
+	probabilityName, ok := pName.(string)
+	if !ok {
+		return nil, errors.Errorf("name map did not store a string of the tensor name, but an object of type %T instead", pName)
+	}
+	data, ok := outMap[probabilityName]
+	if !ok {
+		return nil, errors.Errorf("no tensor named 'probability' among output tensors [%s]", strings.Join(TensorNames(outMap), ", "))
+	}
+	probs, err := ConvertToFloat64Slice(data.Data())
+	if err != nil {
+		return nil, err
+	}
+	confs := checkClassificationScores(probs)
+	if labels != nil && len(labels) != len(confs) {
+		return nil, errors.Errorf("length of output (%d) expected to be length of label list (%d)", len(confs), len(labels))
+	}
+	classifications := make(classification.Classifications, 0, len(confs))
+	for i := 0; i < len(confs); i++ {
+		if labels == nil {
+			classifications = append(classifications, classification.NewClassification(confs[i], strconv.Itoa(i)))
+		} else {
+			if i >= len(labels) {
+				return nil, errors.Errorf("cannot access label number %v from label file with %v labels", i, len(labels))
+			}
+			classifications = append(classifications, classification.NewClassification(confs[i], labels[i]))
+		}
+	}
+	return classifications, nil
+}
+
+// checkClassificationScores ensures that the input scores (the output of the classifier)
+// represent confidence values (from 0 to 1).
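+// Multi-class outputs containing values outside [0, 1] are treated as logits and passed through
+// softmax; a single binary score outside [-1, 1] is passed through a sigmoid; otherwise the
+// scores are returned unchanged.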
+func checkClassificationScores(in []float64) []float64 { + if len(in) > 1 { + for _, p := range in { + if p < 0 || p > 1 { // is logit, needs softmax + confs := softmax(in) + return confs + } + } + return in // no need to softmax + } + // otherwise, this is a binary classifier + if in[0] < -1 || in[0] > 1 { // needs sigmoid + out, err := stats.Sigmoid(in) + if err != nil { + return in + } + return out + } + return in // no need to sigmoid +} diff --git a/ml/ml.go b/ml/ml.go new file mode 100644 index 00000000000..094176e518f --- /dev/null +++ b/ml/ml.go @@ -0,0 +1,217 @@ +// Package ml provides some fundamental machine learning primitives. +package ml + +import ( + "math" + "unsafe" + + "github.com/pkg/errors" + pb "go.viam.com/api/service/mlmodel/v1" + "golang.org/x/exp/constraints" + "gorgonia.org/tensor" +) + +// ProtoToTensors takes pb.FlatTensors and turns it into a Tensors map. +func ProtoToTensors(pbft *pb.FlatTensors) (Tensors, error) { + if pbft == nil { + return nil, errors.New("protobuf FlatTensors is nil") + } + tensors := Tensors{} + for name, ftproto := range pbft.Tensors { + t, err := CreateNewTensor(ftproto) + if err != nil { + return nil, err + } + tensors[name] = t + } + return tensors, nil +} + +// CreateNewTensor turns a proto FlatTensor into a *tensor.Dense. +func CreateNewTensor(pft *pb.FlatTensor) (*tensor.Dense, error) { + shape := make([]int, 0, len(pft.Shape)) + for _, s := range pft.Shape { + shape = append(shape, int(s)) + } + pt := pft.Tensor + switch t := pt.(type) { + case *pb.FlatTensor_Int8Tensor: + data := t.Int8Tensor + if data == nil { + return nil, errors.New("tensor of type Int8Tensor is nil") + } + dataSlice := data.GetData() + unsafeInt8Slice := *(*[]int8)(unsafe.Pointer(&dataSlice)) //nolint:gosec + int8Slice := make([]int8, 0, len(dataSlice)) + int8Slice = append(int8Slice, unsafeInt8Slice...) 
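+		// The unsafe cast above only reinterprets the proto's byte slice as []int8 without copying;
+		// the append then copies the values into a freshly allocated slice so the returned tensor
+		// does not alias the proto message's memory.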
+ return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int8Slice)), nil + case *pb.FlatTensor_Uint8Tensor: + data := t.Uint8Tensor + if data == nil { + return nil, errors.New("tensor of type Uint8Tensor is nil") + } + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil + case *pb.FlatTensor_Int16Tensor: + data := t.Int16Tensor + if data == nil { + return nil, errors.New("tensor of type Int16Tensor is nil") + } + int16Data := uint32ToInt16(data.GetData()) + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int16Data)), nil + case *pb.FlatTensor_Uint16Tensor: + data := t.Uint16Tensor + if data == nil { + return nil, errors.New("tensor of type Uint16Tensor is nil") + } + uint16Data := uint32ToUint16(data.GetData()) + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(uint16Data)), nil + case *pb.FlatTensor_Int32Tensor: + data := t.Int32Tensor + if data == nil { + return nil, errors.New("tensor of type Int32Tensor is nil") + } + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil + case *pb.FlatTensor_Uint32Tensor: + data := t.Uint32Tensor + if data == nil { + return nil, errors.New("tensor of type Uint32Tensor is nil") + } + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil + case *pb.FlatTensor_Int64Tensor: + data := t.Int64Tensor + if data == nil { + return nil, errors.New("tensor of type Int64Tensor is nil") + } + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil + case *pb.FlatTensor_Uint64Tensor: + data := t.Uint64Tensor + if data == nil { + return nil, errors.New("tensor of type Uint64Tensor is nil") + } + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil + case *pb.FlatTensor_FloatTensor: + data := t.FloatTensor + if data == nil { + return nil, errors.New("tensor of type FloatTensor is nil") + } + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil + case *pb.FlatTensor_DoubleTensor: + data := t.DoubleTensor + if data == nil { + return nil, errors.New("tensor of type DoubleTensor is nil") + } + return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil + default: + return nil, errors.Errorf("don't know how to create tensor.Dense from proto type %T", pt) + } +} + +func uint32ToInt16(uint32Slice []uint32) []int16 { + int16Slice := make([]int16, len(uint32Slice)) + + for i, value := range uint32Slice { + int16Slice[i] = int16(value) + } + return int16Slice +} + +func uint32ToUint16(uint32Slice []uint32) []uint16 { + uint16Slice := make([]uint16, len(uint32Slice)) + + for i, value := range uint32Slice { + uint16Slice[i] = uint16(value) + } + return uint16Slice +} + +// number interface for converting between numbers. +type number interface { + constraints.Integer | constraints.Float +} + +// convertNumberSlice converts any number slice into another number slice. +func convertNumberSlice[T1, T2 number](t1 []T1) []T2 { + t2 := make([]T2, len(t1)) + for i := range t1 { + t2[i] = T2(t1[i]) + } + return t2 +} + +// ConvertToFloat64Slice converts any numbers or slice of numbers into a float64 slice. 
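+// Scalar numeric inputs are wrapped in a single-element slice; unsupported types return an error.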
+func ConvertToFloat64Slice(slice interface{}) ([]float64, error) { + switch v := slice.(type) { + case []float64: + return v, nil + case float64: + return []float64{v}, nil + case []float32: + return convertNumberSlice[float32, float64](v), nil + case float32: + return convertNumberSlice[float32, float64]([]float32{v}), nil + case []int: + return convertNumberSlice[int, float64](v), nil + case int: + return convertNumberSlice[int, float64]([]int{v}), nil + case []uint: + return convertNumberSlice[uint, float64](v), nil + case uint: + return convertNumberSlice[uint, float64]([]uint{v}), nil + case []int8: + return convertNumberSlice[int8, float64](v), nil + case int8: + return convertNumberSlice[int8, float64]([]int8{v}), nil + case []int16: + return convertNumberSlice[int16, float64](v), nil + case int16: + return convertNumberSlice[int16, float64]([]int16{v}), nil + case []int32: + return convertNumberSlice[int32, float64](v), nil + case int32: + return convertNumberSlice[int32, float64]([]int32{v}), nil + case []int64: + return convertNumberSlice[int64, float64](v), nil + case int64: + return convertNumberSlice[int64, float64]([]int64{v}), nil + case []uint8: + return convertNumberSlice[uint8, float64](v), nil + case uint8: + return convertNumberSlice[uint8, float64]([]uint8{v}), nil + case []uint16: + return convertNumberSlice[uint16, float64](v), nil + case uint16: + return convertNumberSlice[uint16, float64]([]uint16{v}), nil + case []uint32: + return convertNumberSlice[uint32, float64](v), nil + case uint32: + return convertNumberSlice[uint32, float64]([]uint32{v}), nil + case []uint64: + return convertNumberSlice[uint64, float64](v), nil + case uint64: + return convertNumberSlice[uint64, float64]([]uint64{v}), nil + default: + return nil, errors.Errorf("dont know how to convert slice of %T into a []float64", slice) + } +} + +// softmax takes the input slice and applies the softmax function. +func softmax(in []float64) []float64 { + out := make([]float64, 0, len(in)) + bigSum := 0.0 + for _, x := range in { + bigSum += math.Exp(x) + } + for _, x := range in { + out = append(out, math.Exp(x)/bigSum) + } + return out +} + +// TensorNames returns all the names of the tensors. +func TensorNames(t Tensors) []string { + names := []string{} + for name := range t { + names = append(names, name) + } + return names +} diff --git a/services/mlmodel/client.go b/services/mlmodel/client.go index 6b112a38658..e5576addca8 100644 --- a/services/mlmodel/client.go +++ b/services/mlmodel/client.go @@ -2,12 +2,9 @@ package mlmodel import ( "context" - "unsafe" - "github.com/pkg/errors" pb "go.viam.com/api/service/mlmodel/v1" "go.viam.com/utils/rpc" - "gorgonia.org/tensor" "go.viam.com/rdk/logging" "go.viam.com/rdk/ml" @@ -59,126 +56,13 @@ func (c *client) Infer(ctx context.Context, tensors ml.Tensors) (ml.Tensors, err if err != nil { return nil, err } - tensorResp, err := ProtoToTensors(resp.OutputTensors) + tensorResp, err := ml.ProtoToTensors(resp.OutputTensors) if err != nil { return nil, err } return tensorResp, nil } -// ProtoToTensors takes pb.FlatTensors and turns it into a Tensors map. -func ProtoToTensors(pbft *pb.FlatTensors) (ml.Tensors, error) { - if pbft == nil { - return nil, errors.New("protobuf FlatTensors is nil") - } - tensors := ml.Tensors{} - for name, ftproto := range pbft.Tensors { - t, err := createNewTensor(ftproto) - if err != nil { - return nil, err - } - tensors[name] = t - } - return tensors, nil -} - -// createNewTensor turns a proto FlatTensor into a *tensor.Dense. 
-func createNewTensor(pft *pb.FlatTensor) (*tensor.Dense, error) { - shape := make([]int, 0, len(pft.Shape)) - for _, s := range pft.Shape { - shape = append(shape, int(s)) - } - pt := pft.Tensor - switch t := pt.(type) { - case *pb.FlatTensor_Int8Tensor: - data := t.Int8Tensor - if data == nil { - return nil, errors.New("tensor of type Int8Tensor is nil") - } - dataSlice := data.GetData() - unsafeInt8Slice := *(*[]int8)(unsafe.Pointer(&dataSlice)) //nolint:gosec - int8Slice := make([]int8, 0, len(dataSlice)) - int8Slice = append(int8Slice, unsafeInt8Slice...) - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int8Slice)), nil - case *pb.FlatTensor_Uint8Tensor: - data := t.Uint8Tensor - if data == nil { - return nil, errors.New("tensor of type Uint8Tensor is nil") - } - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil - case *pb.FlatTensor_Int16Tensor: - data := t.Int16Tensor - if data == nil { - return nil, errors.New("tensor of type Int16Tensor is nil") - } - int16Data := uint32ToInt16(data.GetData()) - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int16Data)), nil - case *pb.FlatTensor_Uint16Tensor: - data := t.Uint16Tensor - if data == nil { - return nil, errors.New("tensor of type Uint16Tensor is nil") - } - uint16Data := uint32ToUint16(data.GetData()) - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(uint16Data)), nil - case *pb.FlatTensor_Int32Tensor: - data := t.Int32Tensor - if data == nil { - return nil, errors.New("tensor of type Int32Tensor is nil") - } - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil - case *pb.FlatTensor_Uint32Tensor: - data := t.Uint32Tensor - if data == nil { - return nil, errors.New("tensor of type Uint32Tensor is nil") - } - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil - case *pb.FlatTensor_Int64Tensor: - data := t.Int64Tensor - if data == nil { - return nil, errors.New("tensor of type Int64Tensor is nil") - } - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil - case *pb.FlatTensor_Uint64Tensor: - data := t.Uint64Tensor - if data == nil { - return nil, errors.New("tensor of type Uint64Tensor is nil") - } - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil - case *pb.FlatTensor_FloatTensor: - data := t.FloatTensor - if data == nil { - return nil, errors.New("tensor of type FloatTensor is nil") - } - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil - case *pb.FlatTensor_DoubleTensor: - data := t.DoubleTensor - if data == nil { - return nil, errors.New("tensor of type DoubleTensor is nil") - } - return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil - default: - return nil, errors.Errorf("don't know how to create tensor.Dense from proto type %T", pt) - } -} - -func uint32ToInt16(uint32Slice []uint32) []int16 { - int16Slice := make([]int16, len(uint32Slice)) - - for i, value := range uint32Slice { - int16Slice[i] = int16(value) - } - return int16Slice -} - -func uint32ToUint16(uint32Slice []uint32) []uint16 { - uint16Slice := make([]uint16, len(uint32Slice)) - - for i, value := range uint32Slice { - uint16Slice[i] = uint16(value) - } - return uint16Slice -} - func (c *client) Metadata(ctx context.Context) (MLMetadata, error) { resp, err := c.client.Metadata(ctx, &pb.MetadataRequest{ Name: c.name, diff --git a/services/mlmodel/mlmodel_test.go 
b/services/mlmodel/mlmodel_test.go index fb60f304b94..b1c60cf0e0a 100644 --- a/services/mlmodel/mlmodel_test.go +++ b/services/mlmodel/mlmodel_test.go @@ -5,6 +5,8 @@ import ( "go.viam.com/test" "gorgonia.org/tensor" + + "go.viam.com/rdk/ml" ) func TestTensorRoundTrip(t *testing.T) { @@ -31,7 +33,7 @@ func TestTensorRoundTrip(t *testing.T) { test.That(t, resp.Shape, test.ShouldHaveLength, 2) test.That(t, resp.Shape[0], test.ShouldEqual, 2) test.That(t, resp.Shape[1], test.ShouldEqual, 3) - back, err := createNewTensor(resp) + back, err := ml.CreateNewTensor(resp) test.That(t, err, test.ShouldBeNil) test.That(t, back.Shape(), test.ShouldResemble, tensor.tensor.Shape()) test.That(t, back.Data(), test.ShouldResemble, tensor.tensor.Data()) diff --git a/services/mlmodel/server.go b/services/mlmodel/server.go index 30df1ebce43..c53f74813dc 100644 --- a/services/mlmodel/server.go +++ b/services/mlmodel/server.go @@ -29,7 +29,7 @@ func (server *serviceServer) Infer(ctx context.Context, req *pb.InferRequest) (* var it ml.Tensors if req.InputTensors != nil { - it, err = ProtoToTensors(req.InputTensors) + it, err = ml.ProtoToTensors(req.InputTensors) if err != nil { return nil, err } diff --git a/services/vision/mlvision/classifier.go b/services/vision/mlvision/classifier.go index b7cdf6a8603..1f532800751 100644 --- a/services/vision/mlvision/classifier.go +++ b/services/vision/mlvision/classifier.go @@ -3,8 +3,6 @@ package mlvision import ( "context" "image" - "strconv" - "strings" "sync" "github.com/nfnt/resize" @@ -101,50 +99,11 @@ func attemptToBuildClassifier(mlm mlmodel.Service, return nil, err } - // check if output tensor name that classifier is looking for is already present - // in the nameMap. If not, find the probability name, and cache it in the nameMap - pName, ok := outNameMap.Load(classifierProbabilityName) - if !ok { - _, ok := outMap[classifierProbabilityName] - if !ok { - if len(outMap) == 1 { - for name := range outMap { // only 1 element in map, assume its probabilities - outNameMap.Store(classifierProbabilityName, name) - pName = name - } - } - } else { - outNameMap.Store(classifierProbabilityName, classifierProbabilityName) - pName = classifierProbabilityName - } - } - probabilityName, ok := pName.(string) - if !ok { - return nil, errors.Errorf("name map did not store a string of the tensor name, but an object of type %T instead", pName) - } - data, ok := outMap[probabilityName] - if !ok { - return nil, errors.Errorf("no tensor named 'probability' among output tensors [%s]", strings.Join(tensorNames(outMap), ", ")) - } - probs, err := convertToFloat64Slice(data.Data()) + classifications, err := ml.FormatClassificationOutputs(outNameMap, outMap, labels) if err != nil { return nil, err } - confs := checkClassificationScores(probs) - if labels != nil && len(labels) != len(confs) { - return nil, errors.Errorf("length of output (%d) expected to be length of label list (%d)", len(confs), len(labels)) - } - classifications := make(classification.Classifications, 0, len(confs)) - for i := 0; i < len(confs); i++ { - if labels == nil { - classifications = append(classifications, classification.NewClassification(confs[i], strconv.Itoa(i))) - } else { - if i >= len(labels) { - return nil, errors.Errorf("cannot access label number %v from label file with %v labels", i, len(labels)) - } - classifications = append(classifications, classification.NewClassification(confs[i], labels[i])) - } - } + if postprocessor != nil { classifications = postprocessor(classifications) } diff --git 
a/services/vision/mlvision/detector.go b/services/vision/mlvision/detector.go index f74b6ba4776..9deadfe18d1 100644 --- a/services/vision/mlvision/detector.go +++ b/services/vision/mlvision/detector.go @@ -121,11 +121,11 @@ func attemptToBuildDetector(mlm mlmodel.Service, if err != nil { return nil, err } - locations, err := convertToFloat64Slice(outMap[locationName].Data()) + locations, err := ml.ConvertToFloat64Slice(outMap[locationName].Data()) if err != nil { return nil, err } - scores, err := convertToFloat64Slice(outMap[scoreName].Data()) + scores, err := ml.ConvertToFloat64Slice(outMap[scoreName].Data()) if err != nil { return nil, err } @@ -133,7 +133,7 @@ func attemptToBuildDetector(mlm mlmodel.Service, categories := make([]float64, len(scores)) // default 0 category if no category output if categoryName != "" { hasCategoryTensor = true - categories, err = convertToFloat64Slice(outMap[categoryName].Data()) + categories, err = ml.ConvertToFloat64Slice(outMap[categoryName].Data()) if err != nil { return nil, err } @@ -304,7 +304,7 @@ func findDetectionTensorNames(outMap ml.Tensors, nameMap *sync.Map) (string, str func guessDetectionTensorNames(outMap ml.Tensors) (string, string, string, error) { foundTensor := map[string]bool{} mappedNames := map[string]string{} - outNames := tensorNames(outMap) + outNames := ml.TensorNames(outMap) _, okLoc := outMap[detectorLocationName] if okLoc { foundTensor[detectorLocationName] = true @@ -332,7 +332,7 @@ func guessDetectionTensorNames(outMap ml.Tensors) (string, string, string, error if err != nil { return "", "", "", err } - val64, err := convertToFloat64Slice(val) + val64, err := ml.ConvertToFloat64Slice(val) if err != nil { return "", "", "", err } @@ -390,7 +390,7 @@ func guessDetectionTensorNames(outMap ml.Tensors) (string, string, string, error return "", "", "", err } } - val, err := convertToFloat64Slice(whole.Data()) + val, err := ml.ConvertToFloat64Slice(whole.Data()) if err != nil { return "", "", "", err } diff --git a/services/vision/mlvision/ml_model.go b/services/vision/mlvision/ml_model.go index 9def91d40c6..364f5fb074f 100644 --- a/services/vision/mlvision/ml_model.go +++ b/services/vision/mlvision/ml_model.go @@ -6,19 +6,15 @@ import ( "bufio" "context" "fmt" - "math" "os" "path/filepath" "strings" "sync" - "github.com/montanaflynn/stats" "github.com/pkg/errors" "go.opencensus.io/trace" - "golang.org/x/exp/constraints" "go.viam.com/rdk/logging" - "go.viam.com/rdk/ml" "go.viam.com/rdk/resource" "go.viam.com/rdk/robot" "go.viam.com/rdk/services/mlmodel" @@ -299,117 +295,3 @@ func getIndex(s []int, num int) int { } return -1 } - -// softmax takes the input slice and applies the softmax function. -func softmax(in []float64) []float64 { - out := make([]float64, 0, len(in)) - bigSum := 0.0 - for _, x := range in { - bigSum += math.Exp(x) - } - for _, x := range in { - out = append(out, math.Exp(x)/bigSum) - } - return out -} - -// checkClassification scores ensures that the input scores (output of classifier) -// will represent confidence values (from 0-1). 
-func checkClassificationScores(in []float64) []float64 { - if len(in) > 1 { - for _, p := range in { - if p < 0 || p > 1 { // is logit, needs softmax - confs := softmax(in) - return confs - } - } - return in // no need to softmax - } - // otherwise, this is a binary classifier - if in[0] < -1 || in[0] > 1 { // needs sigmoid - out, err := stats.Sigmoid(in) - if err != nil { - return in - } - return out - } - return in // no need to sigmoid -} - -// Number interface for converting between numbers. -type number interface { - constraints.Integer | constraints.Float -} - -// convertNumberSlice converts any number slice into another number slice. -func convertNumberSlice[T1, T2 number](t1 []T1) []T2 { - t2 := make([]T2, len(t1)) - for i := range t1 { - t2[i] = T2(t1[i]) - } - return t2 -} - -func convertToFloat64Slice(slice interface{}) ([]float64, error) { - switch v := slice.(type) { - case []float64: - return v, nil - case float64: - return []float64{v}, nil - case []float32: - return convertNumberSlice[float32, float64](v), nil - case float32: - return convertNumberSlice[float32, float64]([]float32{v}), nil - case []int: - return convertNumberSlice[int, float64](v), nil - case int: - return convertNumberSlice[int, float64]([]int{v}), nil - case []uint: - return convertNumberSlice[uint, float64](v), nil - case uint: - return convertNumberSlice[uint, float64]([]uint{v}), nil - case []int8: - return convertNumberSlice[int8, float64](v), nil - case int8: - return convertNumberSlice[int8, float64]([]int8{v}), nil - case []int16: - return convertNumberSlice[int16, float64](v), nil - case int16: - return convertNumberSlice[int16, float64]([]int16{v}), nil - case []int32: - return convertNumberSlice[int32, float64](v), nil - case int32: - return convertNumberSlice[int32, float64]([]int32{v}), nil - case []int64: - return convertNumberSlice[int64, float64](v), nil - case int64: - return convertNumberSlice[int64, float64]([]int64{v}), nil - case []uint8: - return convertNumberSlice[uint8, float64](v), nil - case uint8: - return convertNumberSlice[uint8, float64]([]uint8{v}), nil - case []uint16: - return convertNumberSlice[uint16, float64](v), nil - case uint16: - return convertNumberSlice[uint16, float64]([]uint16{v}), nil - case []uint32: - return convertNumberSlice[uint32, float64](v), nil - case uint32: - return convertNumberSlice[uint32, float64]([]uint32{v}), nil - case []uint64: - return convertNumberSlice[uint64, float64](v), nil - case uint64: - return convertNumberSlice[uint64, float64]([]uint64{v}), nil - default: - return nil, errors.Errorf("dont know how to convert slice of %T into a []float64", slice) - } -} - -// tensorNames returns all the names of the tensors. -func tensorNames(t ml.Tensors) []string { - names := []string{} - for name := range t { - names = append(names, name) - } - return names -}
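
Usage sketch (editor's note, not part of the patch): the snippet below illustrates how a consumer outside services/vision/mlvision might call the newly exported helpers in package ml. The output-tensor name "probability", the 1x3 score tensor, and the label list are hypothetical values chosen for illustration; when no tensor is named "probability", FormatClassificationOutputs falls back to a lone output tensor and caches its name in the provided sync.Map.

package main

import (
	"fmt"
	"sync"

	"gorgonia.org/tensor"

	"go.viam.com/rdk/ml"
)

func main() {
	// Hypothetical classifier output: three class scores in a 1x3 float32 tensor.
	outMap := ml.Tensors{
		"probability": tensor.New(
			tensor.WithShape(1, 3),
			tensor.WithBacking([]float32{0.1, 0.7, 0.2}),
		),
	}
	labels := []string{"cat", "dog", "bird"} // stand-in for a parsed label file

	// The mlvision classifier shares this name cache across calls; a fresh map also works.
	var outNameMap sync.Map
	classifications, err := ml.FormatClassificationOutputs(&outNameMap, outMap, labels)
	if err != nil {
		panic(err)
	}
	// Scores already lie in [0, 1], so no softmax/sigmoid is applied before labeling.
	fmt.Println(classifications)
}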