Control of validation during training is added through yaml file; Bug of centroid and texts in depth mask removed
usmanzahidi committed Jan 22, 2025
1 parent 6974241 commit 3811e83
Showing 5 changed files with 39 additions and 11 deletions.
4 changes: 4 additions & 0 deletions config/non_ros_params.yaml
@@ -5,6 +5,7 @@
datasets:
  train_dataset_name: 'aoc_train_dataset'
  test_dataset_name: 'aoc_test_dataset'
  validation_dataset_name: 'aoc_validation_dataset'
  dataset_train_annotation_url: 'https://lncn.ac/aocanntrain'
  dataset_train_images_url: 'https://lncn.ac/aocdatatrain'
  dataset_test_annotation_url: 'https://lncn.ac/aocanntest'
@@ -19,12 +20,14 @@ files:
  train_dataset_catalog_file: './data/dataset_catalogs/tom_train_dataset_catalog.pkl'
  train_annotation_file: './data/tomato_dataset/train/annotations/ripeness_class_annotations.json'
  test_annotation_file: './data/tomato_dataset/test/annotations/ripeness_class_annotations.json'
  validation_annotation_file: './data/tomato_dataset/val/annotations/ripeness_class_annotations.json'
  model_url: 'https://lncn.ac/aocmodel'
  meta_catalog_url: 'https://lncn.ac/aocmeta'
  train_catalog_url: 'https://lncn.ac/aoccat'
directories:
  train_image_dir: './data/strawberry_dataset/train/'
  test_image_dir: './data/bag/rgbd/' #'./data/strawberry_dataset/test/'
  validation_image_dir: './data/tomato_dataset/val/'
  training_output_dir: './data/training_output/'
  prediction_output_dir: './data/prediction_output/test_images/'
  prediction_json_dir: './data/annotations/predicted/'
@@ -40,6 +43,7 @@ settings:
  bbox: false
  show_orientation: true
  fruit_type: 'strawberry' # currently supported: "strawberry" or "tomato"
  validation_period: 500 # A smaller validation period increases training time; this value is set so that roughly 100 validations run during training
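
As a rough illustration of how the new keys are consumed on the Python side (a sketch assuming PyYAML's safe_load and the config path shown above, not code taken from this commit):

# Minimal sketch of reading the new validation settings added in this commit.
import yaml

with open('config/non_ros_params.yaml') as f:
    config_data = yaml.safe_load(f)

validation_dataset_name = config_data['datasets']['validation_dataset_name']
validation_annotation_file = config_data['files']['validation_annotation_file']
validation_image_dir = config_data['directories']['validation_image_dir']
validation_period = config_data['settings']['validation_period']  # 500 by default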



7 changes: 2 additions & 5 deletions scripts/detectron_predictor/detectron_predictor.py
@@ -5,11 +5,8 @@
import os, pickle, logging,traceback

import numpy
# detectron imports
from detectron2.config import get_cfg
#from detectron2.data import Metadata
from detectron2.engine.defaults import DefaultPredictor
#from detectron2.data.catalog import MetadataCatalog
from detectron2 import model_zoo
from detectron_predictor.json_writer.pycococreator.pycococreatortools.fruit_orientation import FruitTypes

@@ -174,8 +171,8 @@ def get_predictions_message(self, rgbd_image, image_id=0,fruit_type=FruitTypes.S
colours=self.colours,
category_ids=self.list_category_ids,
masks=self.masks,
show_orientation=self.show_orientation,
fruit_type=fruit_type
show_orientation=False, # UZ: Set to false
fruit_type=fruit_type,
)
drawn_predictions = vis_aoc.draw_instance_predictions(outputs["instances"].to("cpu"))
predicted_image = drawn_predictions.get_image()[:, :, ::-1].copy()
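
For context, the detectron2 imports kept in this hunk are normally wired together along the following lines. This is a generic sketch, not the repo's predictor setup; the model-zoo config name, weights path, and threshold are illustrative placeholders.

# Generic detectron2 inference sketch (placeholders, not values from this repository).
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine.defaults import DefaultPredictor

def build_predictor(weights_path, score_thresh=0.5):
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    return DefaultPredictor(cfg)

# outputs = build_predictor('./data/model.pth')(bgr_image)  # predictions in outputs["instances"]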
9 changes: 5 additions & 4 deletions scripts/detectron_predictor/visualizer/aoc_visualizer.py
@@ -112,10 +112,11 @@ def overlay_instances(
mask=segment.reshape(-1, 2)
theta, centroid,vector,vector2 = FruitOrientation.get_angle_pca(masks[i].mask,self.fruit_type)
height, width = masks[i].mask.shape # Get mask dimensions
scale_factor = min(width, height) / 1500
x,y=centroid
radius = int(10 * scale_factor)
self.draw_polygon(mask, color, alpha=self.alpha,x=x,y=y,radius=radius,theta=theta,scale_factor=scale_factor,vector=vector)
if (self.show_orientation):
    scale_factor = min(width, height) / 1500
    x,y=centroid
    radius = int(10 * scale_factor)
    self.draw_polygon(mask, color, alpha=self.alpha,x=x,y=y,radius=radius,theta=theta,scale_factor=scale_factor,vector=vector)

if labels is not None:
# first get a box
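
FruitOrientation.get_angle_pca itself is not part of this diff; purely as an illustration of the kind of PCA-based estimate the guarded block consumes, a mask's angle and centroid can be derived as below (hypothetical helper, not the repo's implementation):

# Illustration only: PCA-style orientation of a binary mask.
import numpy as np

def mask_orientation(mask):
    ys, xs = np.nonzero(mask)                       # pixel coordinates inside the mask
    pts = np.column_stack([xs, ys]).astype(np.float64)
    centroid = pts.mean(axis=0)                     # (x, y) centroid used for the drawn marker
    cov = np.cov((pts - centroid).T)                # 2x2 covariance of the mask points
    eigvals, eigvecs = np.linalg.eigh(cov)
    principal = eigvecs[:, np.argmax(eigvals)]      # axis of largest variance
    theta = np.degrees(np.arctan2(principal[1], principal[0]))
    return theta, centroid, principal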
10 changes: 9 additions & 1 deletion scripts/detectron_trainer/aoc_trainer.py
@@ -5,12 +5,13 @@
import torch
import detectron2.data.transforms as T
from detectron2.data import DatasetMapper, build_detection_train_loader
from detectron2.evaluation import COCOEvaluator
from detectron2.engine import DefaultTrainer
from detectron2.projects.deeplab import build_lr_scheduler
from detectron2.data import detection_utils as utils
from detectron2.config import CfgNode
from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping
import copy,yaml
import copy,yaml,os

#UZ: extended Default trainer to have methods for augmentation

@@ -126,3 +127,10 @@ def hsv_convert(cls,dataset_dict):
image = cv2.cvtColor(image, cv2.COLOR_BGR2YUV)
d["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
return dataset_dict

# UZ: Evaluator for validation during training
@classmethod
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
    if output_folder is None:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
    return COCOEvaluator(dataset_name, cfg, True, output_folder)
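
With build_evaluator defined, DefaultTrainer's built-in evaluation hook can run COCO evaluation on cfg.DATASETS.TEST every cfg.TEST.EVAL_PERIOD iterations. A minimal usage sketch follows; "AOCTrainer" is an assumed name for the extended trainer defined in this file.

# Usage sketch (AOCTrainer is an assumed class name, not confirmed by this diff).
def run_training(cfg):
    trainer = AOCTrainer(cfg)              # DefaultTrainer subclass providing build_evaluator above
    trainer.resume_or_load(resume=False)   # start from cfg.MODEL.WEIGHTS
    trainer.train()                        # validation runs every cfg.TEST.EVAL_PERIOD iterations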
20 changes: 19 additions & 1 deletion scripts/detectron_trainer/detectron_trainer.py
@@ -27,21 +27,25 @@ def __init__(self, config_data):
# UZ:dataset
self.name_train = config_data['datasets']['train_dataset_name']
self.name_test = config_data['datasets']['test_dataset_name']
self.name_validation = config_data['datasets']['validation_dataset_name']
# UZ:files
self.model_file = config_data['files']['model_file']
self.config_file = config_data['files']['config_file']
self.test_annotation_file = config_data['files']['test_annotation_file']
self.train_annotation_file = config_data['files']['train_annotation_file']
self.validation_annotation_file = config_data['files']['validation_annotation_file']
self.train_dataset_catalog_file = config_data['files']['train_dataset_catalog_file']
self.test_metadata_catalog_file = config_data['files']['test_metadata_catalog_file']
self.pretrained_model = config_data['files']['pretrained_model_file']
# UZ:training
self.num_classes = config_data['training']['number_of_classes']
self.epochs = config_data['training']['epochs']
self.learning_rate = config_data['training']['learning_rate']
self.validation_period = config_data['settings']['validation_period']
# UZ:directories
self.test_image_dir = config_data['directories']['test_image_dir']
self.train_image_dir = config_data['directories']['train_image_dir']
self.validation_image_dir = config_data['directories']['validation_image_dir']
self.download_assets = config_data['settings']['download_assets']

if (self.download_assets):
@@ -51,6 +55,7 @@
self.cfg = self._configure(self.epochs,self.learning_rate)
self._register_train_dataset()
self._register_test_dataset()
self._register_validation_dataset()

def _configure(self, iterations=10000,
learning_rate=0.0025,num_workers=8,batch_size=8,batch_per_image=512,test_threshold=0.5):
@@ -70,7 +75,10 @@

cfg.MODEL.ROI_HEADS.NUM_CLASSES = self.num_classes
cfg.DATASETS.TRAIN = (self.name_train,)
cfg.DATASETS.TEST = (self.name_test,)
# UZ: During training, the TEST setting actually serves as validation, so we point it at the validation dataset;
# the test dataset is used for evaluation after training is completed
cfg.DATASETS.TEST = (self.name_validation,)
cfg.TEST.EVAL_PERIOD = self.validation_period # number of iterations between validation runs
cfg.DATALOADER.NUM_WORKERS = num_workers
cfg.SOLVER.IMS_PER_BATCH = batch_size
cfg.SOLVER.BASE_LR = learning_rate
@@ -105,6 +113,16 @@ def _register_train_dataset(self):
if(__debug__): print(traceback.format_exc())
raise Exception(e)

def _register_validation_dataset(self):
    try:
        validation_dataset_catalog = MetadataCatalog.get(self.name_validation)
        validation_dataset_catalog.thing_colors = [(0, 255, 0), (255, 0, 0)]
        register_coco_instances(self.name_validation, {}, self.validation_annotation_file, self.validation_image_dir)
    except Exception as e:
        logging.error(e)
        if (__debug__): print(traceback.format_exc())
        raise Exception(e)

def train_model(self, resumeType=False,skipTraining=False)->DefaultTrainer:
try:
os.makedirs(self.cfg.OUTPUT_DIR, exist_ok=True)
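
The yaml comment ties validation_period to roughly 100 validation passes over the whole run. As a back-of-the-envelope sketch of that relationship (the 50000-iteration figure is hypothetical, not taken from the config):

# Hypothetical numbers: with 50000 training iterations and ~100 desired validations,
# the period works out to the 500 used as the yaml default.
max_iter = 50000
desired_validations = 100
validation_period = max(1, max_iter // desired_validations)   # -> 500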
