Skip to content

Commit 223a04e

Browse files
authored
Merge pull request #1 from ClarkCGA/dev_datasetfusion
task1_baseline_semseg
2 parents d9bc546 + aa581f1 commit 223a04e

25 files changed

+7973
-8
lines changed

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
.ipynb_checkpoints/
2+
.jar/
3+
src/**/*.ipynb_checkpoints/
4+
src/**/__pycache__/
5+
notebooks/**/*.ipynb_checkpoints/
6+
output*/
7+
**/.DS_Store

Dockerfile

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,17 @@ RUN pip install --no-cache-dir --upgrade pip pip-tools setuptools
55

66
# Install PyTorch with CUDA support and openCV
77
RUN pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu117
8-
RUN pip install opencv-python
9-
8+
RUN pip install opencv-python-headless
109
# Install pip packages from requirements.txt
1110
COPY requirements.txt .
1211
RUN pip install -r requirements.txt
1312

1413
RUN apt-get --allow-releaseinfo-change update
1514
RUN apt-get --allow-releaseinfo-change-suite update
16-
RUN apt-get update
17-
15+
RUN apt-get update
1816
RUN mkdir /home/workdir
1917
WORKDIR /home/workdir
2018

2119
EXPOSE 8888
2220

23-
ENTRYPOINT ["jupyter", "lab", "--ip='0.0.0.0'", "--port=8888", "--no-browser", "--allow-root"]
24-
#CMD ["/bin/bash"]
21+
ENTRYPOINT ["jupyter", "lab", "--ip='0.0.0.0'", "--port=8888", "--no-browser", "--allow-root"]

README.md

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,70 @@
1-
# gfm-segmentation-baseline
2-
Baseline model for crop type segmentation as part of the GFM downstream task evaluations
1+
# Baseline Model for Segmentation Fine-Tuning of the HLS Foundation Model
2+
This repo contains the code, performance metrics and trained model weights for a supervised CNN model as the baseline for multi-temporal crop type segmentation fine-tuning of the HLS Foundation Model (FM). The FM is released by NASA and IBM [here](https://huggingface.co/ibm-nasa-geospatial), and the fine-tuned FM model for this task is presented [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification). You can also access the training dataset for this task [here](https://huggingface.co/datasets/ibm-nasa-geospatial/multi-temporal-crop-classification).
3+
4+
This project is funded by an award from NASA to the Center for Geospatial Analytics at Clark University.
5+
6+
## Instructions to run the code using Docker:
7+
8+
**Step 1-** Change directory to an empty folder in your machine and clone the repo.
9+
```
10+
$ cd /to_empty/dir/on_host/
11+
$ git clone git@github.com:ClarkCGA/gfm-segmentation-baseline.git
12+
```
13+
14+
**Step 2-** Make sure the Docker daemon is running and build the Docker image as following:
15+
```
16+
docker build -t <image_name>:<tag> .
17+
```
18+
19+
**step 3-** Run the Docker image as a container:
20+
```
21+
docker run --gpus all -it -p 8888:8888 -v <path/to/the/cloned-repo/on-host>:/home/workdir -v <path/to/the/dataset/on-host>:/home/data <image_name>:<tag>
22+
```
23+
24+
This command will start a container based on the specified Docker image and starts a JupyterLab session. Type `localhost:8888` in your browser and copy the provided token from the terminal to open the JupyterLab.
25+
26+
**step 4-** Run the pipeline:
27+
28+
Open the jupyter notebook located at `notebooks/main.ipynb`.
29+
30+
Modify the "default_config.yaml" or create your own config file and run the cells as explained in the notebook.
31+
32+
## Model Weights
33+
The model weights trained on the dataset for 100 epochs with the parameters specified in the "default_config.yaml", is stored in the `model_weights/multi_temporal_crop_classification.pth`. Instructions to load and use the pre-trained model for zero-shot inference or warm-up training is explained in the notebook.
34+
35+
# Evaluation metrics:
36+
![Confusion Matrix](_media/confusion_matrix.png)
37+
38+
## Overall Metrics:
39+
40+
|Metric |Value |
41+
|----------------|--------|
42+
|Overall Accuracy|0.63056 |
43+
|Mean Accuracy |0.61915 |
44+
|Mean IoU |0.42086 |
45+
|mean Precision |0.57392 |
46+
|mean Recall |0.57492 |
47+
|Mean F1 Score |0.57251 |
48+
49+
## Class-wise Metrics:
50+
51+
|Class | Accuracy |IoU |Precision |Recall |F1 Score |
52+
|--------------------|------------|------------|-----------|-------------|------------|
53+
|Natural Vegetation |0.6366 |0.4577 |0.6196 |0.6366 |0.6280 |
54+
|Forest |0.7171 |0.4772 |0.5878 |0.7171 |0.6461 |
55+
|Corn |0.6332 |0.5226 |0.7494 |0.6332 |0.6864 |
56+
|Soybeans |0.6676 |0.51675 |0.6957 |0.6676 |0.6814 |
57+
|Wetlands |0.6035 |0.4109 |0.5628 |0.6035 |0.5825 |
58+
|Developed/Barren |0.6022 |0.4637 |0.6684 |0.6022 |0.6336 |
59+
|Open Water |0.8775 |0.7596 |0.8496 |0.8775 |0.8633 |
60+
|Winter Wheat |0.6639 |0.4950 |0.6606 |0.6639 |0.6622 |
61+
|Alfalfa |0.5902 |0.3847 |0.5250 |0.5902 |0.5557 |
62+
|Fallow/Idle Cropland|0.5293 |0.3599 |0.5292 |0.5293 |0.5293 |
63+
|Cotton |0.4529 |0.3258 |0.5371 |0.4529 |0.4914 |
64+
|Sorghum |0.6152 |0.3909 |0.5174 |0.6152 |0.5621 |
65+
|Other |0.4589 |0.3268 |0.5316 |0.4589 |0.4926 |
66+
67+
68+
69+
70+

_media/confusion_matrix.png

670 KB
Loading

config/default_config.yaml

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
2+
# Custom dataset params
3+
src_dir: /home/data
4+
train_dataset_name: chips_filtered_13_classes_complete
5+
train_csv_path: /home/workdir/train_ids.csv
6+
val_csv_path: /home/workdir/val_ids.csv
7+
test_csv_path: /home/workdir/test_ids.csv
8+
apply_normalization: true
9+
normal_strategy: z_value
10+
stat_procedure: gpb
11+
global_stats:
12+
min: [124.0, 308.0, 191.0, 598.0, 423.0, 271.0]
13+
max: [1207.0, 1765.0, 2366.0, 4945.0, 4646.0, 3897.0]
14+
mean: [494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962, 1739.579917]
15+
std: [284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808]
16+
transformations:
17+
- v_flip
18+
- h_flip
19+
- d_flip
20+
- rotate
21+
aug_params:
22+
rotation_degree: [-180, -90, 90, 180]
23+
24+
# DataLoader
25+
train_BatchSize: 10
26+
val_test_BatchSize: 3
27+
28+
# Model initialization params
29+
n_classes: 14
30+
input_channels: 18
31+
filter_config: [64, 128, 256, 512, 1024, 1024]
32+
use_skipAtt: false
33+
train_dropout_rate: 0.15
34+
35+
# Model compiler params
36+
working_dir: /home/workdir
37+
out_dir: output6
38+
class_mapping:
39+
0: Unknown
40+
1: Natural Vegetation
41+
2: Forest
42+
3: Corn
43+
4: Soybeans
44+
5: Wetlands
45+
6: Developed/Barren
46+
7: Open Water
47+
8: Winter Wheat
48+
9: Alfalfa
49+
10: Fallow/Idle Cropland
50+
11: Cotton
51+
12: Sorghum
52+
13: Other
53+
gpuDevices:
54+
- 0
55+
init_type: kaiming
56+
params_init: None
57+
freeze_params: None
58+
59+
# Model fitting
60+
epochs: 100
61+
optimizer: sam
62+
LR: 0.011
63+
LR_policy: PolynomialLR
64+
criterion:
65+
name: TverskyFocalLoss
66+
weight:
67+
- 0.0182553
68+
- 0.03123664
69+
- 0.02590038
70+
- 0.03026126
71+
- 0.04142966
72+
- 0.04371284
73+
- 0.15352935
74+
- 0.07286951
75+
- 0.10277024
76+
- 0.10736637
77+
- 0.1447082
78+
- 0.17132445
79+
- 0.0566358
80+
ignore_index: 0
81+
gamma: 0.9
82+
83+
momentum: 0.95
84+
checkpoint_interval: 20
85+
resume: false
86+
resume_epoch: None
87+
lr_prams:
88+
# StepLR & MultiStepLR
89+
step_size: 3
90+
milestones:
91+
- 5
92+
- 10
93+
- 20
94+
- 35
95+
- 50
96+
- 70
97+
- 90
98+
gamma: 0.98
99+
# ReduceLROnPlateau
100+
mode: triangular
101+
factor: 0.8
102+
patience: 3
103+
threshold: 0.0001
104+
threshold_mode: rel
105+
min_lr: 3.0e-06
106+
# PolynomialLR
107+
max_decay_steps: 80
108+
min_learning_rate: 1.0e-04
109+
power: 0.85
110+
# CyclicLR
111+
base_lr: 3.0e-05
112+
max_lr: 0.01
113+
step_size_up: 1100
114+
115+
# Accuracy assessment
116+
val_metric_fname: validate_metrics_global_z_gpb.csv
117+
118+
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:08b0b8c2d91b22cd3de8e79b95ea76e0a4c5e4ec66b606e71bfda98a39095f6d
3+
size 413633539

0 commit comments

Comments
 (0)