DOCS-3423: Update training script to parse labels based on feedback from etai/tahiya #3941
Open

sguequierre wants to merge 15 commits into viamrobotics:main from sguequierre:DOCS-3423/feedback
Commits (15, all by sguequierre):

- cd40828 Update training script
- 8f316e2 more training script updates
- 14a8aea more updates
- ee3016b Apply suggestions from code review
- 20b19dd Apply suggestions from code review
- 2cefc01 correct line
- 2a13f0e fix flake8
- d163ff8 fixup
- 47280a9 add template without labels parsed
- 4afee29 markdown parser fixup
- fd84cd3 data line fixup
- 9e10be8 Apply suggestions from code review
- d619328 Apply suggestions from code review
- dc5f434 Apply suggestions from code review
- 4c57132 Apply suggestions from code review
@@ -63,7 +63,7 @@ my-training/

Add the following code to `setup.py` and add additional required packages on line 11:

-```python {class="line-numbers linkable-line-numbers" data-line="11"}
+```python {class="line-numbers linkable-line-numbers" data-line="9"}
 from setuptools import find_packages, setup

 setup(

@@ -72,8 +72,6 @@ setup(
     packages=find_packages(),
     include_package_data=True,
     install_requires=[
-        "google-cloud-aiplatform",
-        "google-cloud-storage",
         # TODO: Add additional required packages
     ],
 )
@@ -90,15 +88,18 @@ If you haven't already, create a folder called <file>model</file> and create an

<p><strong>4. Add <code>training.py</code> code</strong></p>

-<p>Copy this template into <file>training.py</file>:</p>
+<p>You can set up your training script to use a hard-coded set of labels or allow users to pass in a set of labels when using the training script. Allowing users to pass in labels when using training scripts makes your training script more flexible for reuse.</p>
+<p>Copy one of the following templates into <file>training.py</file>, depending on how you want to handle labels:</p>

-{{% expand "Click to see the template" %}}
+{{% expand "Click to see the template without parsing labels (recommended for use with UI)" %}}

-```python {class="line-numbers linkable-line-numbers" data-line="126,170" }
```python {class="line-numbers linkable-line-numbers" data-line="134" }
import argparse
import json
import os
import typing as ty
+from tensorflow.keras import Model  # Add proper import
+import tensorflow as tf  # Add proper import

single_label = "MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION"
multi_label = "MODEL_TYPE_MULTI_LABEL_CLASSIFICATION"

@@ -108,23 +109,29 @@ unknown_label = "UNKNOWN"
API_KEY = os.environ['API_KEY']
API_KEY_ID = os.environ['API_KEY_ID']

DEFAULT_EPOCHS = 200

# This parses the required args for the training script.
# The model_dir variable will contain the output directory where
# the ML model that this script creates should be stored.
# The data_json variable will contain the metadata for the dataset
# that you should use to train the model.


def parse_args():
-    """Returns dataset file, model output directory, and num_epochs if present.
-    These must be parsed as command line arguments and then used as the model
-    input and output, respectively. The number of epochs can be used to
-    optionally override the default.
+    """Returns dataset file, model output directory, and num_epochs
+    if present. These must be parsed as command line arguments and then used
+    as the model input and output, respectively. The number of epochs can be
+    used to optionally override the default.
    """
    parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset_file", dest="data_json", type=str)
-    parser.add_argument("--model_output_directory", dest="model_dir", type=str)
+    parser.add_argument("--dataset_file", dest="data_json",
+                        type=str, required=True)
+    parser.add_argument("--model_output_directory", dest="model_dir",
+                        type=str, required=True)
    parser.add_argument("--num_epochs", dest="num_epochs", type=int)
    args = parser.parse_args()

    return args.data_json, args.model_dir, args.num_epochs


@@ -250,12 +257,17 @@ def save_model(
        model_dir: output directory for model artifacts
        model_name: name of saved model
    """
-    file_type = ""

-    # Save the model to the output directory.
+    # Save the model to the output directory
+    file_type = "tflite"  # Add proper file type
    filename = os.path.join(model_dir, f"{model_name}.{file_type}")

+    # Example: Convert to TFLite
+    converter = tf.lite.TFLiteConverter.from_keras_model(model)
+    tflite_model = converter.convert()

    # Save the model
    with open(filename, "wb") as f:
-        f.write(model)
+        f.write(tflite_model)


if __name__ == "__main__":

@@ -273,14 +285,244 @@ if __name__ == "__main__":
    image_filenames, image_labels = parse_filenames_and_labels_from_json(
        DATA_JSON, LABELS, model_type)

    # Validate epochs
    epochs = (
        DEFAULT_EPOCHS if NUM_EPOCHS is None
        or NUM_EPOCHS <= 0 else int(NUM_EPOCHS)
    )

    # Build and compile model on data
-    model = build_and_compile_model()
+    model = build_and_compile_model(image_labels, model_type, IMG_SIZE + (3,))

    # Save labels.txt file
    save_labels(LABELS + [unknown_label], MODEL_DIR)
+    # Convert the model to tflite
    save_model(
-        model, MODEL_DIR, "classification_model", IMG_SIZE + (3,)
+        model, MODEL_DIR, "classification_model"
    )
```

{{% /expand %}}
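The `parse_filenames_and_labels_from_json` helper called in the template's main block reads a JSON Lines dataset file. As a rough illustration of the input format it expects, the snippet below applies the template's single-label selection logic to two made-up records (the file paths and label names are hypothetical; only the field names come from the template):

```python
import json

# Hypothetical sample of the JSON Lines dataset format: one object per
# line, with "image_path" and "classification_annotations" keys.
sample_lines = [
    '{"image_path": "/data/img_001.jpg",'
    ' "classification_annotations": [{"annotation_label": "cat"}]}',
    '{"image_path": "/data/img_002.jpg",'
    ' "classification_annotations": [{"annotation_label": "dog"}]}',
]

all_labels = ["cat", "dog"]
image_filenames = []
image_labels = []

for line in sample_lines:
    record = json.loads(line)
    image_filenames.append(record["image_path"])
    # Single-label case: start at UNKNOWN and keep the last valid
    # annotation, mirroring the template's loop.
    labels = ["UNKNOWN"]
    for annotation in record["classification_annotations"]:
        if annotation["annotation_label"] in all_labels:
            labels = [annotation["annotation_label"]]
    image_labels.append(labels)

print(image_filenames)  # → ['/data/img_001.jpg', '/data/img_002.jpg']
print(image_labels)     # → [['cat'], ['dog']]
```

An image with no valid annotation would keep the `UNKNOWN` placeholder label, which is why the template also appends `unknown_label` when writing <file>labels.txt</file>.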
{{% expand "Click to see the template with parsed labels" %}}

```python {class="line-numbers linkable-line-numbers" data-line="148" }
import argparse
import json
import os
import typing as ty
from tensorflow.keras import Model  # Add proper import
import tensorflow as tf  # Add proper import

single_label = "MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION"
multi_label = "MODEL_TYPE_MULTI_LABEL_CLASSIFICATION"
labels_filename = "labels.txt"
unknown_label = "UNKNOWN"

API_KEY = os.environ['API_KEY']
API_KEY_ID = os.environ['API_KEY_ID']

DEFAULT_EPOCHS = 200

# This parses the required args for the training script.
# The model_dir variable will contain the output directory where
# the ML model that this script creates should be stored.
# The data_json variable will contain the metadata for the dataset
# that you should use to train the model.


def parse_args():
    """Returns dataset file, model output directory, labels, and num_epochs
    if present. These must be parsed as command line arguments and then used
    as the model input and output, respectively. The number of epochs can be
    used to optionally override the default.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_file", dest="data_json",
                        type=str, required=True)
    parser.add_argument("--model_output_directory", dest="model_dir",
                        type=str, required=True)
    parser.add_argument("--num_epochs", dest="num_epochs", type=int)
    parser.add_argument(
        "--labels",
        dest="labels",
        type=str,
        required=True,
        help="Space-separated list of labels, \
        enclosed in single quotes (e.g., 'label1 label2').",
    )
    args = parser.parse_args()

    if not args.labels:
        raise ValueError("Labels must be provided")

    labels = [label.strip() for label in args.labels.strip("'").split()]
    return args.data_json, args.model_dir, args.num_epochs, labels


# This is used for parsing the dataset file (produced and stored in Viam),
# parse it to get the label annotations
# Used for training classification models


def parse_filenames_and_labels_from_json(
    filename: str, all_labels: ty.List[str], model_type: str
) -> ty.Tuple[ty.List[str], ty.List[str]]:
    """Load and parse JSON file to return image filenames and corresponding
    labels. The JSON file contains lines, where each line has the key
    "image_path" and "classification_annotations".
    Args:
        filename: JSONLines file containing filenames and labels
        all_labels: list of all N_LABELS
        model_type: string single_label or multi_label
    """
    image_filenames = []
    image_labels = []

    with open(filename, "rb") as f:
        for line in f:
            json_line = json.loads(line)
            image_filenames.append(json_line["image_path"])

            annotations = json_line["classification_annotations"]
            labels = [unknown_label]
            for annotation in annotations:
                if model_type == multi_label:
                    if annotation["annotation_label"] in all_labels:
                        labels.append(annotation["annotation_label"])
                # For single label model, we want at most one label.
                # If multiple valid labels are present, we arbitrarily select
                # the last one.
                if model_type == single_label:
                    if annotation["annotation_label"] in all_labels:
                        labels = [annotation["annotation_label"]]
            image_labels.append(labels)
    return image_filenames, image_labels


# Parse the dataset file (produced and stored in Viam) to get
# bounding box annotations
# Used for training object detection models
def parse_filenames_and_bboxes_from_json(
    filename: str,
    all_labels: ty.List[str],
) -> ty.Tuple[ty.List[str], ty.List[str], ty.List[ty.List[float]]]:
    """Load and parse JSON file to return image filenames
    and corresponding labels with bboxes.
    Args:
        filename: JSONLines file containing filenames and bboxes
        all_labels: list of all N_LABELS
    """
    image_filenames = []
    bbox_labels = []
    bbox_coords = []

    with open(filename, "rb") as f:
        for line in f:
            json_line = json.loads(line)
            image_filenames.append(json_line["image_path"])
            annotations = json_line["bounding_box_annotations"]
            labels = []
            coords = []
            for annotation in annotations:
                if annotation["annotation_label"] in all_labels:
                    labels.append(annotation["annotation_label"])
                    # Store coordinates in rel_yxyx format so that
                    # we can use the keras_cv function
                    coords.append(
                        [
                            annotation["y_min_normalized"],
                            annotation["x_min_normalized"],
                            annotation["y_max_normalized"],
                            annotation["x_max_normalized"],
                        ]
                    )
            bbox_labels.append(labels)
            bbox_coords.append(coords)
    return image_filenames, bbox_labels, bbox_coords


# Build the model
def build_and_compile_model(
    labels: ty.List[str], model_type: str, input_shape: ty.Tuple[int, int, int]
) -> Model:
    """Builds and compiles a model
    Args:
        labels: list of string lists, where each string list contains up to
            N_LABEL labels associated with an image
        model_type: string single_label or multi_label
        input_shape: 3D shape of input
    """

    # TODO: Add logic to build and compile model

    return model


def save_labels(labels: ty.List[str], model_dir: str) -> None:
    """Saves a label.txt of output labels to the specified model directory.
    Args:
        labels: list of string lists, where each string list contains up to
            N_LABEL labels associated with an image
        model_dir: output directory for model artifacts
    """
    filename = os.path.join(model_dir, labels_filename)
    with open(filename, "w") as f:
        for label in labels[:-1]:
            f.write(label + "\n")
        f.write(labels[-1])


def save_model(
    model: Model,
    model_dir: str,
    model_name: str,
) -> None:
    """Save model as a TFLite model.
    Args:
        model: trained model
        model_dir: output directory for model artifacts
        model_name: name of saved model
    """
    # Save the model to the output directory
    file_type = "tflite"  # Add proper file type

    filename = os.path.join(model_dir, f"{model_name}.{file_type}")

    # Example: Convert to TFLite
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()

    # Save the model
    with open(filename, "wb") as f:
        f.write(tflite_model)


if __name__ == "__main__":
    DATA_JSON, MODEL_DIR, NUM_EPOCHS, LABELS = parse_args()

    IMG_SIZE = (256, 256)

    # Read dataset file.
    # The model type can be changed based on whether you want the model to
    # output one label per image or multiple labels per image
    model_type = multi_label
    image_filenames, image_labels = parse_filenames_and_labels_from_json(
        DATA_JSON, LABELS, model_type)

    # Validate epochs
    epochs = (
        DEFAULT_EPOCHS if NUM_EPOCHS is None
        or NUM_EPOCHS <= 0 else int(NUM_EPOCHS)
    )

    # Build and compile model on data
    model = build_and_compile_model(image_labels, model_type, IMG_SIZE + (3,))

    # Save labels.txt file
    save_labels(LABELS + [unknown_label], MODEL_DIR)
    # Convert the model to tflite
    save_model(
        model, MODEL_DIR, "classification_model"
    )
```

{{% /expand %}}
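The `--labels` flag in this template expects a space-separated list wrapped in single quotes. The template's parsing expression behaves as sketched below (the label names are made up for illustration):

```python
# Simulate the string argparse would receive for:
#   --labels "'red_panda racoon'"
raw = "'red_panda racoon'"

# Same expression as the template: strip the enclosing single quotes,
# split on whitespace, and trim each resulting label.
labels = [label.strip() for label in raw.strip("'").split()]

print(labels)  # → ['red_panda', 'racoon']
```

Because `str.split()` with no arguments splits on any run of whitespace, extra spaces between labels are tolerated; an empty quoted string, however, would produce an empty list rather than raise, since the `required=True` check only guarantees the flag was passed.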
@@ -300,6 +542,10 @@ The script you are creating must take the following command line inputs:

- `dataset_file`: a file containing the data and metadata for the training job
- `model_output_directory`: the location where the produced model artifacts are saved to

If you used the training script template that allows users to pass in labels, it will also take the following command line input:

- `labels`: a space-separated list of labels, enclosed in single quotes

The `parse_args()` function in the template parses your arguments.

You can add additional custom command line inputs by adding them to the `parse_args()` function.
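For instance, a custom flag can be registered alongside the required ones. This is a minimal sketch; the `--learning_rate` input and its default are hypothetical additions, not part of the template:

```python
import argparse

parser = argparse.ArgumentParser()
# Required inputs from the template
parser.add_argument("--dataset_file", dest="data_json",
                    type=str, required=True)
parser.add_argument("--model_output_directory", dest="model_dir",
                    type=str, required=True)
parser.add_argument("--num_epochs", dest="num_epochs", type=int)
# Hypothetical custom input with a default value
parser.add_argument("--learning_rate", dest="learning_rate",
                    type=float, default=0.001)

# Parse an example argument list instead of sys.argv for demonstration
args = parser.parse_args([
    "--dataset_file", "data.jsonl",
    "--model_output_directory", "out/",
    "--learning_rate", "0.01",
])
print(args.learning_rate)  # → 0.01
```

Any custom value parsed this way would also need to be returned from `parse_args()` and threaded through to the code that uses it.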
@@ -547,6 +793,11 @@ In the Viam app, navigate to your list of [**DATASETS**](https://app.viam.com/da

Click **Train model** and select **Train on a custom training script**, then follow the prompts.

{{% alert title="Tip" color="tip" %}}
If you used the version of <file>training.py</file> that allows users to pass in labels, your training job will fail with the error `ERROR training.py: error: the following arguments are required: --labels`.
To use labels, you must use the CLI.
{{% /alert %}}

{{% /tab %}}
{{% tab name="CLI" %}}
> Review comment: know this is awk but python markdown parser was complaining about this specifically