-
Notifications
You must be signed in to change notification settings - Fork 111
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DATA-3461 - Refactoring code for classifications #4764
Open
tahiyasalam
wants to merge
7
commits into
viamrobotics:main
Choose a base branch
from
tahiyasalam:export-classifications
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
3bead35
Refactor vision classifier to import common helper from ML package
tahiyasalam 134bae6
Move proto conversion to ML
tahiyasalam ad56d03
Export more stuff
tahiyasalam a7fd544
De-dupe softmax
tahiyasalam af59617
More de-duping
tahiyasalam b186135
Split into two files - generic tensor operations and classifications
tahiyasalam 2f8292c
Remove inmap information from FormatClassificationOutputs
tahiyasalam File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
// Package ml provides some fundamental machine learning primitives. | ||
package ml | ||
|
||
import "gorgonia.org/tensor" | ||
import ( | ||
"gorgonia.org/tensor" | ||
) | ||
|
||
// Tensors are a data structure to hold the input and output map of tensors that will fed into a | ||
// model, or come from the result of a model. | ||
type Tensors map[string]*tensor.Dense | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
package ml | ||
|
||
import ( | ||
"strconv" | ||
"strings" | ||
"sync" | ||
|
||
"github.com/montanaflynn/stats" | ||
"github.com/pkg/errors" | ||
"go.viam.com/rdk/vision/classification" | ||
) | ||
|
||
const classifierProbabilityName = "probability" | ||
|
||
func FormatClassificationOutputs( | ||
outNameMap *sync.Map, outMap Tensors, labels []string, | ||
) (classification.Classifications, error) { | ||
// check if output tensor name that classifier is looking for is already present | ||
// in the nameMap. If not, find the probability name, and cache it in the nameMap | ||
pName, ok := outNameMap.Load(classifierProbabilityName) | ||
if !ok { | ||
_, ok := outMap[classifierProbabilityName] | ||
if !ok { | ||
if len(outMap) == 1 { | ||
for name := range outMap { // only 1 element in map, assume its probabilities | ||
outNameMap.Store(classifierProbabilityName, name) | ||
pName = name | ||
} | ||
} | ||
} else { | ||
outNameMap.Store(classifierProbabilityName, classifierProbabilityName) | ||
pName = classifierProbabilityName | ||
} | ||
} | ||
probabilityName, ok := pName.(string) | ||
if !ok { | ||
return nil, errors.Errorf("name map did not store a string of the tensor name, but an object of type %T instead", pName) | ||
} | ||
data, ok := outMap[probabilityName] | ||
if !ok { | ||
return nil, errors.Errorf("no tensor named 'probability' among output tensors [%s]", strings.Join(TensorNames(outMap), ", ")) | ||
} | ||
probs, err := ConvertToFloat64Slice(data.Data()) | ||
if err != nil { | ||
return nil, err | ||
} | ||
confs := checkClassificationScores(probs) | ||
if labels != nil && len(labels) != len(confs) { | ||
return nil, errors.Errorf("length of output (%d) expected to be length of label list (%d)", len(confs), len(labels)) | ||
} | ||
classifications := make(classification.Classifications, 0, len(confs)) | ||
for i := 0; i < len(confs); i++ { | ||
if labels == nil { | ||
classifications = append(classifications, classification.NewClassification(confs[i], strconv.Itoa(i))) | ||
} else { | ||
if i >= len(labels) { | ||
return nil, errors.Errorf("cannot access label number %v from label file with %v labels", i, len(labels)) | ||
} | ||
classifications = append(classifications, classification.NewClassification(confs[i], labels[i])) | ||
} | ||
} | ||
return classifications, nil | ||
} | ||
|
||
// checkClassification scores ensures that the input scores (output of classifier) | ||
// will represent confidence values (from 0-1). | ||
func checkClassificationScores(in []float64) []float64 { | ||
if len(in) > 1 { | ||
for _, p := range in { | ||
if p < 0 || p > 1 { // is logit, needs softmax | ||
confs := softmax(in) | ||
return confs | ||
} | ||
} | ||
return in // no need to softmax | ||
} | ||
// otherwise, this is a binary classifier | ||
if in[0] < -1 || in[0] > 1 { // needs sigmoid | ||
out, err := stats.Sigmoid(in) | ||
if err != nil { | ||
return in | ||
} | ||
return out | ||
} | ||
return in // no need to sigmoid | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,217 @@ | ||
// Package ml provides some fundamental machine learning primitives. | ||
package ml | ||
|
||
import ( | ||
"math" | ||
"unsafe" | ||
|
||
"github.com/pkg/errors" | ||
pb "go.viam.com/api/service/mlmodel/v1" | ||
"golang.org/x/exp/constraints" | ||
|
||
"gorgonia.org/tensor" | ||
) | ||
|
||
// ProtoToTensors takes pb.FlatTensors and turns it into a Tensors map. | ||
func ProtoToTensors(pbft *pb.FlatTensors) (Tensors, error) { | ||
if pbft == nil { | ||
return nil, errors.New("protobuf FlatTensors is nil") | ||
} | ||
tensors := Tensors{} | ||
for name, ftproto := range pbft.Tensors { | ||
t, err := CreateNewTensor(ftproto) | ||
if err != nil { | ||
return nil, err | ||
} | ||
tensors[name] = t | ||
} | ||
return tensors, nil | ||
} | ||
|
||
// CreateNewTensor turns a proto FlatTensor into a *tensor.Dense. | ||
func CreateNewTensor(pft *pb.FlatTensor) (*tensor.Dense, error) { | ||
shape := make([]int, 0, len(pft.Shape)) | ||
for _, s := range pft.Shape { | ||
shape = append(shape, int(s)) | ||
} | ||
pt := pft.Tensor | ||
switch t := pt.(type) { | ||
case *pb.FlatTensor_Int8Tensor: | ||
data := t.Int8Tensor | ||
if data == nil { | ||
return nil, errors.New("tensor of type Int8Tensor is nil") | ||
} | ||
dataSlice := data.GetData() | ||
unsafeInt8Slice := *(*[]int8)(unsafe.Pointer(&dataSlice)) //nolint:gosec | ||
int8Slice := make([]int8, 0, len(dataSlice)) | ||
int8Slice = append(int8Slice, unsafeInt8Slice...) | ||
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int8Slice)), nil | ||
case *pb.FlatTensor_Uint8Tensor: | ||
data := t.Uint8Tensor | ||
if data == nil { | ||
return nil, errors.New("tensor of type Uint8Tensor is nil") | ||
} | ||
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil | ||
case *pb.FlatTensor_Int16Tensor: | ||
data := t.Int16Tensor | ||
if data == nil { | ||
return nil, errors.New("tensor of type Int16Tensor is nil") | ||
} | ||
int16Data := uint32ToInt16(data.GetData()) | ||
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(int16Data)), nil | ||
case *pb.FlatTensor_Uint16Tensor: | ||
data := t.Uint16Tensor | ||
if data == nil { | ||
return nil, errors.New("tensor of type Uint16Tensor is nil") | ||
} | ||
uint16Data := uint32ToUint16(data.GetData()) | ||
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(uint16Data)), nil | ||
case *pb.FlatTensor_Int32Tensor: | ||
data := t.Int32Tensor | ||
if data == nil { | ||
return nil, errors.New("tensor of type Int32Tensor is nil") | ||
} | ||
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil | ||
case *pb.FlatTensor_Uint32Tensor: | ||
data := t.Uint32Tensor | ||
if data == nil { | ||
return nil, errors.New("tensor of type Uint32Tensor is nil") | ||
} | ||
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil | ||
case *pb.FlatTensor_Int64Tensor: | ||
data := t.Int64Tensor | ||
if data == nil { | ||
return nil, errors.New("tensor of type Int64Tensor is nil") | ||
} | ||
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil | ||
case *pb.FlatTensor_Uint64Tensor: | ||
data := t.Uint64Tensor | ||
if data == nil { | ||
return nil, errors.New("tensor of type Uint64Tensor is nil") | ||
} | ||
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil | ||
case *pb.FlatTensor_FloatTensor: | ||
data := t.FloatTensor | ||
if data == nil { | ||
return nil, errors.New("tensor of type FloatTensor is nil") | ||
} | ||
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil | ||
case *pb.FlatTensor_DoubleTensor: | ||
data := t.DoubleTensor | ||
if data == nil { | ||
return nil, errors.New("tensor of type DoubleTensor is nil") | ||
} | ||
return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data.GetData())), nil | ||
default: | ||
return nil, errors.Errorf("don't know how to create tensor.Dense from proto type %T", pt) | ||
} | ||
} | ||
|
||
func uint32ToInt16(uint32Slice []uint32) []int16 { | ||
int16Slice := make([]int16, len(uint32Slice)) | ||
|
||
for i, value := range uint32Slice { | ||
int16Slice[i] = int16(value) | ||
} | ||
return int16Slice | ||
} | ||
|
||
func uint32ToUint16(uint32Slice []uint32) []uint16 { | ||
uint16Slice := make([]uint16, len(uint32Slice)) | ||
|
||
for i, value := range uint32Slice { | ||
uint16Slice[i] = uint16(value) | ||
} | ||
return uint16Slice | ||
} | ||
|
||
// number interface for converting between numbers. | ||
type number interface { | ||
constraints.Integer | constraints.Float | ||
} | ||
|
||
// convertNumberSlice converts any number slice into another number slice. | ||
func convertNumberSlice[T1, T2 number](t1 []T1) []T2 { | ||
t2 := make([]T2, len(t1)) | ||
for i := range t1 { | ||
t2[i] = T2(t1[i]) | ||
} | ||
return t2 | ||
} | ||
|
||
func ConvertToFloat64Slice(slice interface{}) ([]float64, error) { | ||
switch v := slice.(type) { | ||
case []float64: | ||
return v, nil | ||
case float64: | ||
return []float64{v}, nil | ||
case []float32: | ||
return convertNumberSlice[float32, float64](v), nil | ||
case float32: | ||
return convertNumberSlice[float32, float64]([]float32{v}), nil | ||
case []int: | ||
return convertNumberSlice[int, float64](v), nil | ||
case int: | ||
return convertNumberSlice[int, float64]([]int{v}), nil | ||
case []uint: | ||
return convertNumberSlice[uint, float64](v), nil | ||
case uint: | ||
return convertNumberSlice[uint, float64]([]uint{v}), nil | ||
case []int8: | ||
return convertNumberSlice[int8, float64](v), nil | ||
case int8: | ||
return convertNumberSlice[int8, float64]([]int8{v}), nil | ||
case []int16: | ||
return convertNumberSlice[int16, float64](v), nil | ||
case int16: | ||
return convertNumberSlice[int16, float64]([]int16{v}), nil | ||
case []int32: | ||
return convertNumberSlice[int32, float64](v), nil | ||
case int32: | ||
return convertNumberSlice[int32, float64]([]int32{v}), nil | ||
case []int64: | ||
return convertNumberSlice[int64, float64](v), nil | ||
case int64: | ||
return convertNumberSlice[int64, float64]([]int64{v}), nil | ||
case []uint8: | ||
return convertNumberSlice[uint8, float64](v), nil | ||
case uint8: | ||
return convertNumberSlice[uint8, float64]([]uint8{v}), nil | ||
case []uint16: | ||
return convertNumberSlice[uint16, float64](v), nil | ||
case uint16: | ||
return convertNumberSlice[uint16, float64]([]uint16{v}), nil | ||
case []uint32: | ||
return convertNumberSlice[uint32, float64](v), nil | ||
case uint32: | ||
return convertNumberSlice[uint32, float64]([]uint32{v}), nil | ||
case []uint64: | ||
return convertNumberSlice[uint64, float64](v), nil | ||
case uint64: | ||
return convertNumberSlice[uint64, float64]([]uint64{v}), nil | ||
default: | ||
return nil, errors.Errorf("dont know how to convert slice of %T into a []float64", slice) | ||
} | ||
} | ||
|
||
// softmax takes the input slice and applies the softmax function. | ||
func softmax(in []float64) []float64 { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you'll probably have to make this public, as I think it's used in the vision service |
||
out := make([]float64, 0, len(in)) | ||
bigSum := 0.0 | ||
for _, x := range in { | ||
bigSum += math.Exp(x) | ||
} | ||
for _, x := range in { | ||
out = append(out, math.Exp(x)/bigSum) | ||
} | ||
return out | ||
} | ||
|
||
// TensorNames returns all the names of the tensors. | ||
func TensorNames(t Tensors) []string { | ||
names := []string{} | ||
for name := range t { | ||
names = append(names, name) | ||
} | ||
return names | ||
} |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you split this file into two, one with the functions that just have to do with classification, and the other the generic Tensor to protobuf conversion functions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
of course! thank you for the first pass. I will update the PR description with all the manual testing