forked from viamrobotics/rdk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcommon.go
87 lines (76 loc) · 2.9 KB
/
common.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
package app
import (
mltrainingpb "go.viam.com/api/app/mltraining/v1"
)
// Package-wide constants for app.
const (
	// UploadChunkSize is 64 KiB (65536 bytes, i.e. 1<<16). Judging by its
	// name it is the size of each piece an upload is split into; confirm at
	// the call sites.
	UploadChunkSize = 1 << 16
)
// Package-wide types for app.

// ModelType says which kind of task a model performs: single-label
// classification, multi-label classification, or object detection.
// Its zero value is ModelTypeUnspecified.
type ModelType int

// Known ModelType values. The numbers are fixed explicitly so that
// reordering these lines cannot silently renumber them.
const (
	// ModelTypeUnspecified means no model type was given.
	ModelTypeUnspecified ModelType = 0
	// ModelTypeSingleLabelClassification is a classifier that assigns exactly one label.
	ModelTypeSingleLabelClassification ModelType = 1
	// ModelTypeMultiLabelClassification is a classifier that may assign several labels.
	ModelTypeMultiLabelClassification ModelType = 2
	// ModelTypeObjectDetection is an object detection model.
	ModelTypeObjectDetection ModelType = 3
)
// ModelFramework names the ML framework a model was built with.
// Its zero value is ModelFrameworkUnspecified.
type ModelFramework int

// Known ModelFramework values, numbered explicitly.
const (
	// ModelFrameworkUnspecified means no framework was given.
	ModelFrameworkUnspecified ModelFramework = 0
	// ModelFrameworkTFLite denotes a TFLite model.
	ModelFrameworkTFLite ModelFramework = 1
	// ModelFrameworkTensorFlow denotes a TensorFlow model.
	ModelFrameworkTensorFlow ModelFramework = 2
	// ModelFrameworkPyTorch denotes a PyTorch model.
	ModelFrameworkPyTorch ModelFramework = 3
	// ModelFrameworkONNX denotes an ONNX model.
	ModelFrameworkONNX ModelFramework = 4
)
// modelTypeFromProto converts an mltrainingpb.ModelType into the app-level
// ModelType. The proto's UNSPECIFIED value, and any value without a case
// below, both yield ModelTypeUnspecified.
func modelTypeFromProto(modelType mltrainingpb.ModelType) ModelType {
	switch modelType {
	case mltrainingpb.ModelType_MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION:
		return ModelTypeSingleLabelClassification
	case mltrainingpb.ModelType_MODEL_TYPE_MULTI_LABEL_CLASSIFICATION:
		return ModelTypeMultiLabelClassification
	case mltrainingpb.ModelType_MODEL_TYPE_OBJECT_DETECTION:
		return ModelTypeObjectDetection
	case mltrainingpb.ModelType_MODEL_TYPE_UNSPECIFIED:
		return ModelTypeUnspecified
	default:
		// Unrecognized proto values deliberately fall back to unspecified.
		return ModelTypeUnspecified
	}
}
// modelTypeToProto converts an app-level ModelType into its
// mltrainingpb.ModelType counterpart. ModelTypeUnspecified, and any value
// outside the known constants, both yield MODEL_TYPE_UNSPECIFIED.
func modelTypeToProto(modelType ModelType) mltrainingpb.ModelType {
	switch modelType {
	case ModelTypeSingleLabelClassification:
		return mltrainingpb.ModelType_MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION
	case ModelTypeMultiLabelClassification:
		return mltrainingpb.ModelType_MODEL_TYPE_MULTI_LABEL_CLASSIFICATION
	case ModelTypeObjectDetection:
		return mltrainingpb.ModelType_MODEL_TYPE_OBJECT_DETECTION
	case ModelTypeUnspecified:
		return mltrainingpb.ModelType_MODEL_TYPE_UNSPECIFIED
	default:
		// Out-of-range ModelType values are sent as unspecified.
		return mltrainingpb.ModelType_MODEL_TYPE_UNSPECIFIED
	}
}
// modelFrameworkFromProto converts an mltrainingpb.ModelFramework into the
// app-level ModelFramework. The proto's UNSPECIFIED value, and any value
// without a case below, both yield ModelFrameworkUnspecified.
func modelFrameworkFromProto(framework mltrainingpb.ModelFramework) ModelFramework {
	switch framework {
	case mltrainingpb.ModelFramework_MODEL_FRAMEWORK_TFLITE:
		return ModelFrameworkTFLite
	case mltrainingpb.ModelFramework_MODEL_FRAMEWORK_TENSORFLOW:
		return ModelFrameworkTensorFlow
	case mltrainingpb.ModelFramework_MODEL_FRAMEWORK_PYTORCH:
		return ModelFrameworkPyTorch
	case mltrainingpb.ModelFramework_MODEL_FRAMEWORK_ONNX:
		return ModelFrameworkONNX
	case mltrainingpb.ModelFramework_MODEL_FRAMEWORK_UNSPECIFIED:
		return ModelFrameworkUnspecified
	default:
		// Unrecognized proto values deliberately fall back to unspecified.
		return ModelFrameworkUnspecified
	}
}