Skip to content

Commit a7fd544

Browse files
committed
De-dupe softmax
1 parent ad56d03 commit a7fd544

File tree

3 files changed

+7
-91
lines changed

3 files changed

+7
-91
lines changed

ml/ml.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func FormatClassificationOutputs(
4747
if !ok {
4848
return nil, errors.Errorf("no tensor named 'probability' among output tensors [%s]", strings.Join(tensorNames(outMap), ", "))
4949
}
50-
probs, err := convertToFloat64Slice(data.Data())
50+
probs, err := ConvertToFloat64Slice(data.Data())
5151
if err != nil {
5252
return nil, err
5353
}
@@ -196,7 +196,7 @@ func convertNumberSlice[T1, T2 number](t1 []T1) []T2 {
196196
return t2
197197
}
198198

199-
func convertToFloat64Slice(slice interface{}) ([]float64, error) {
199+
func ConvertToFloat64Slice(slice interface{}) ([]float64, error) {
200200
switch v := slice.(type) {
201201
case []float64:
202202
return v, nil

services/vision/mlvision/detector.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -121,19 +121,19 @@ func attemptToBuildDetector(mlm mlmodel.Service,
121121
if err != nil {
122122
return nil, err
123123
}
124-
locations, err := convertToFloat64Slice(outMap[locationName].Data())
124+
locations, err := ml.ConvertToFloat64Slice(outMap[locationName].Data())
125125
if err != nil {
126126
return nil, err
127127
}
128-
scores, err := convertToFloat64Slice(outMap[scoreName].Data())
128+
scores, err := ml.ConvertToFloat64Slice(outMap[scoreName].Data())
129129
if err != nil {
130130
return nil, err
131131
}
132132
hasCategoryTensor := false
133133
categories := make([]float64, len(scores)) // default 0 category if no category output
134134
if categoryName != "" {
135135
hasCategoryTensor = true
136-
categories, err = convertToFloat64Slice(outMap[categoryName].Data())
136+
categories, err = ml.ConvertToFloat64Slice(outMap[categoryName].Data())
137137
if err != nil {
138138
return nil, err
139139
}
@@ -332,7 +332,7 @@ func guessDetectionTensorNames(outMap ml.Tensors) (string, string, string, error
332332
if err != nil {
333333
return "", "", "", err
334334
}
335-
val64, err := convertToFloat64Slice(val)
335+
val64, err := ml.ConvertToFloat64Slice(val)
336336
if err != nil {
337337
return "", "", "", err
338338
}
@@ -390,7 +390,7 @@ func guessDetectionTensorNames(outMap ml.Tensors) (string, string, string, error
390390
return "", "", "", err
391391
}
392392
}
393-
val, err := convertToFloat64Slice(whole.Data())
393+
val, err := ml.ConvertToFloat64Slice(whole.Data())
394394
if err != nil {
395395
return "", "", "", err
396396
}

services/vision/mlvision/ml_model.go

-84
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@ import (
66
"bufio"
77
"context"
88
"fmt"
9-
"math"
109
"os"
1110
"path/filepath"
1211
"strings"
1312
"sync"
1413

1514
"github.com/pkg/errors"
1615
"go.opencensus.io/trace"
17-
"golang.org/x/exp/constraints"
1816

1917
"go.viam.com/rdk/logging"
2018
"go.viam.com/rdk/ml"
@@ -299,88 +297,6 @@ func getIndex(s []int, num int) int {
299297
return -1
300298
}
301299

302-
// softmax takes the input slice and applies the softmax function.
303-
func softmax(in []float64) []float64 {
304-
out := make([]float64, 0, len(in))
305-
bigSum := 0.0
306-
for _, x := range in {
307-
bigSum += math.Exp(x)
308-
}
309-
for _, x := range in {
310-
out = append(out, math.Exp(x)/bigSum)
311-
}
312-
return out
313-
}
314-
315-
// Number interface for converting between numbers.
316-
type number interface {
317-
constraints.Integer | constraints.Float
318-
}
319-
320-
// convertNumberSlice converts any number slice into another number slice.
321-
func convertNumberSlice[T1, T2 number](t1 []T1) []T2 {
322-
t2 := make([]T2, len(t1))
323-
for i := range t1 {
324-
t2[i] = T2(t1[i])
325-
}
326-
return t2
327-
}
328-
329-
func convertToFloat64Slice(slice interface{}) ([]float64, error) {
330-
switch v := slice.(type) {
331-
case []float64:
332-
return v, nil
333-
case float64:
334-
return []float64{v}, nil
335-
case []float32:
336-
return convertNumberSlice[float32, float64](v), nil
337-
case float32:
338-
return convertNumberSlice[float32, float64]([]float32{v}), nil
339-
case []int:
340-
return convertNumberSlice[int, float64](v), nil
341-
case int:
342-
return convertNumberSlice[int, float64]([]int{v}), nil
343-
case []uint:
344-
return convertNumberSlice[uint, float64](v), nil
345-
case uint:
346-
return convertNumberSlice[uint, float64]([]uint{v}), nil
347-
case []int8:
348-
return convertNumberSlice[int8, float64](v), nil
349-
case int8:
350-
return convertNumberSlice[int8, float64]([]int8{v}), nil
351-
case []int16:
352-
return convertNumberSlice[int16, float64](v), nil
353-
case int16:
354-
return convertNumberSlice[int16, float64]([]int16{v}), nil
355-
case []int32:
356-
return convertNumberSlice[int32, float64](v), nil
357-
case int32:
358-
return convertNumberSlice[int32, float64]([]int32{v}), nil
359-
case []int64:
360-
return convertNumberSlice[int64, float64](v), nil
361-
case int64:
362-
return convertNumberSlice[int64, float64]([]int64{v}), nil
363-
case []uint8:
364-
return convertNumberSlice[uint8, float64](v), nil
365-
case uint8:
366-
return convertNumberSlice[uint8, float64]([]uint8{v}), nil
367-
case []uint16:
368-
return convertNumberSlice[uint16, float64](v), nil
369-
case uint16:
370-
return convertNumberSlice[uint16, float64]([]uint16{v}), nil
371-
case []uint32:
372-
return convertNumberSlice[uint32, float64](v), nil
373-
case uint32:
374-
return convertNumberSlice[uint32, float64]([]uint32{v}), nil
375-
case []uint64:
376-
return convertNumberSlice[uint64, float64](v), nil
377-
case uint64:
378-
return convertNumberSlice[uint64, float64]([]uint64{v}), nil
379-
default:
380-
return nil, errors.Errorf("dont know how to convert slice of %T into a []float64", slice)
381-
}
382-
}
383-
384300
// tensorNames returns all the names of the tensors.
385301
func tensorNames(t ml.Tensors) []string {
386302
names := []string{}

0 commit comments

Comments
 (0)