Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ source .nlgw/bin/activate
Now we can install `poetry`

```bash
pip install poetry
pip install "poetry<2.0.0"
```

The following command installs all of the necessary dependencies for `nonlocal_gwfluxes`.
Expand Down
4 changes: 2 additions & 2 deletions era5_training/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ test-data/
### Ann

```bash
python inference.py -M ann -d global -v global -f uvthetaw -e 8 -m 1 -s 1 -t era5 -i inputs/ -c model-huggingface/ -o outputs/ --script
python inference.py -M ann -d global -v global -f uvthetaw -e 85 -m 1 -s 1 -t era5 -i inputs/ -c model-huggingface/ -o outputs/ --script
```

This will generate some test data and a torchscripted model, to be used by `infer.f90` and `infer.py` later on.
Expand Down Expand Up @@ -86,7 +86,7 @@ python infer.py -M ann -t test-data/ -s .
To test the newly generated torchscript models, use the following command:

```bash
bash compile-and-run.sh intel
bash compile-and-run.sh gcc
```

This will compile `infer.f90` into `infer.exe`. This requires having cuda installed on your system. It also requires `ftorch` to
Expand Down
41 changes: 29 additions & 12 deletions era5_training/batch_ann.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash -l
#PBS -N 1x1_uvthw
#PBS -N scripting
#PBS -A USTN0009
#PBS -l select=1:ncpus=4:ngpus=1:mem=80GB
#PBS -l walltime=01:00:00
Expand Down Expand Up @@ -33,19 +33,36 @@ source ~/nonlocal_gwfluxes/.nlgw/bin/activate
# -o /glade/derecho/scratch/agupta/torch_saved_models/


#python inference.py \
# -M attention \
# -d global \
# -v global \
# -f uvthetaw \
# -e 119 \
# -m 1 \
# -s 3 \
# -t era5 \
# -i /glade/derecho/scratch/agupta/era5_training_data/ \
# -c /glade/derecho/scratch/agupta/hugging_face_checkpoints/ \
# -o /glade/derecho/scratch/agupta/gw_inference_files/


python inference.py \
-M attention \
-d global \
-v global \
-f uvthetaw \
-e 119 \
-m 1 \
-s 3 \
-t era5 \
-i /glade/derecho/scratch/agupta/era5_training_data/ \
-c /glade/derecho/scratch/agupta/hugging_face_checkpoints/ \
-o /glade/derecho/scratch/agupta/gw_inference_files/
-M ann \
-d global \
-v global \
-f uvthetaw \
-e 70 \
-s 1 \
-t era5 \
-m 1 \
-i inputs/ \
-c model-huggingface/ \
-o outputs/ \
--script


#python inference.py -M ann -d global -v global -f uvthetaw -e 85 -m 1 -s 1 -t era5 -i /glade/derecho/scratch/agupta/new_training_data/ -c /glade/derecho/scratch/agupta/hugging_face_checkpoints/ -o /glade/derecho/scratch/agupta/gw_inference_files/ --script



Expand Down
44 changes: 25 additions & 19 deletions era5_training/batch_unet.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,31 @@ source ~/nonlocal_gwfluxes/.nlgw/bin/activate
#python training_attention_unet.py stratosphere_only uvthetawN2


python training.py \
-M attention \
-d global \
-v stratosphere_update \
-f uvw \
-i /glade/derecho/scratch/agupta/era5_training_data/ \
-o /glade/derecho/scratch/agupta/torch_saved_models/


#python inference.py \
# -M attention \
# -d global \
# -v stratosphere_update \
# -f uvw \
# -e 100 \
# -s 1 \
# -t era5 \
# -m 1 \
# -i /glade/derecho/scratch/agupta/era5_training_data/ \
#python training.py \
# -M attention \
# -d global \
# -v stratosphere_update \
# -f uvw \
# -i /glade/derecho/scratch/agupta/era5_training_data/ \
# -o /glade/derecho/scratch/agupta/torch_saved_models/


python inference.py \
-M attention \
-d global \
-v global \
-f uvthetaw \
-e 100 \
-s 1 \
-t era5 \
-m 1 \
-i inputs/ \
-c model-huggingface/ \
-o outputs/ \
--script


# -i /glade/derecho/scratch/agupta/era5_training_data/ \
# -c /glade/derecho/scratch/agupta/torch_saved_models/ \
# -o /glade/derecho/scratch/agupta/gw_inference_files/

Expand Down
36 changes: 8 additions & 28 deletions era5_training/compile-and-run.sh
Original file line number Diff line number Diff line change
@@ -1,34 +1,15 @@
COMP=$1
FC=ifort
FFLAGS=""

if [[ ${COMP} == "intel" ]]; then
FC=ifort
FFLAGS=""

# source /glade/u/home/tmeltzer/cam-test/debug_env.sh

module purge
module load cesmdev/1.0 ncarenv/23.06 craype/2.7.20 linaro-forge/23.0 intel/2023.0.0 mkl/2023.0.0
module load ncarcompilers/1.0.0 cmake/3.26.3 cray-mpich/8.1.25 hdf5-mpi/1.12.2
module load netcdf-mpi/4.9.2 parallel-netcdf/1.12.3 parallelio/2.6.2-debug esmf/8.6.0b04-debug
elif [[ ${COMP} == "gcc" ]]; then

FC=gfortran
FFLAGS="-ffree-line-length-none"

module purge
module load ncarenv/24.12 gcc/12.4.0 cmake cuda/12.3.2 netcdf/4.9.3
else
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[0;33m'
NC='\033[0m' # No Color
echo -e "${RED}ERROR:${YELLOW} required option missing. Please specify [${GREEN}gcc${YELLOW}] or [${GREEN}intel${YELLOW}] as compiler.${NC}"
exit 1
fi
module --force purge
# these come from the environment listed in software_environment.txt in the CESM Case directory
module load cesmdev/1.0 ncarenv/23.06 craype/2.7.20 intel/2023.0.0 mkl/2023.0.0 ncarcompilers/1.0.0
module load cmake/3.26.3 cray-mpich/8.1.25 hdf5-mpi/1.12.2 netcdf-mpi/4.9.2 parallel-netcdf/1.12.3
module load parallelio/2.6.2 esmf/8.6.0b04

source ../.nlgw/bin/activate

FTORCH_ROOT="/glade/u/home/tmeltzer/FTorch/bin/ftorch_${COMP}"
FTORCH_ROOT="${HOME}/fresh/ftorch-install"
NETCDF_LIB="${NETCDF}/lib"
export LD_LIBRARY_PATH="${NETCDF_LIB}:${FTORCH_ROOT}/lib64:${LD_LIBRARY_PATH}"

Expand All @@ -45,7 +26,6 @@ echo $COMMAND

${COMMAND}

# gdb -q --args ./infer.exe attention test-data/ .
./infer.exe attention test-data/ .
echo
echo "========================================="
Expand Down
6 changes: 4 additions & 2 deletions era5_training/get-model-and-data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ mkdir -p inputs

echo "retrieving model weights..."
cd model-huggingface
wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch8.pt
wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/retrained_ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch85.pt
wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch94.pt
wget https://huggingface.co/amangupta2/iccs_coupling_checkpoints/resolve/main/attnunet_era5_global_global_uvthetaw_mseloss_train_epoch119.pt
cd ..

mv model-huggingface/retrained_ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch85.pt model-huggingface/ann_cnn_1x1_global_global_era5_uvthetaw__train_epoch85.pt

echo "retrieving test input..."
(cd inputs && wget https://g-b56e81.7a577b.6fbd.data.globus.org/1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2010_constant_mu_sigma_scaling01.nc)
(cd inputs && wget https://g-b56e81.7a577b.6fbd.data.globus.org/1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling01.nc)
3 changes: 2 additions & 1 deletion era5_training/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def main():
model = torch.jit.load(model_path)

# run model inference
pred = model(torch.tensor(input_data).to(device))
with torch.no_grad():
pred = model(torch.tensor(input_data).to(device))

pred = pred.cpu().detach().numpy()
print("pred.shape = ", pred.shape)
Expand Down
18 changes: 12 additions & 6 deletions era5_training/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
print(f"output_dir={args.output_dir}")
print(f"script={args.script}")

bs_train = 20 # 80 (80 works for most). (does not work for global uvthetaw)
bs_train = 5 # 20 # 80 (80 works for most). (does not work for global uvthetaw)
bs_test = bs_train

# --------------------------------------------------
Expand Down Expand Up @@ -136,11 +136,13 @@
odir = str(args.output_dir) + "/"
pref = str(args.ckpt_dir) + "/" # "/scratch/users/ag4680/torch_saved_models/attention_unet/"
if model == "ann":
ckpt = f"ann_cnn_{stencil}x{stencil}_{domain}_{vertical}_era5_{features}__train_epoch{epoch}.pt"
# ckpt = f"retrained_ann_cnn_{stencil}x{stencil}_{domain}_{vertical}_era5_{features}__train_epoch{epoch}.pt"
ckpt = f"retrained_L93_ann_cnn_{stencil}x{stencil}_{domain}_{vertical}_era5_{features}__train_epoch{epoch}.pt"
log_filename = f"./{teston}_inference_ann_cnn_{stencil}x{stencil}_{domain}_{vertical}_{features}_ckpt_epoch_{epoch}.txt"
elif model == "attention":
ckpt = (
f"attnunet_era5_{domain}_{vertical}_{features}_mseloss_train_epoch{str(epoch).zfill(2)}.pt"
# f"attnunet_era5_{domain}_{vertical}_{features}_mseloss_train_epoch{str(epoch).zfill(2)}.pt"
f"retrained_L93_attnunet_era5_{domain}_{vertical}_{features}_mseloss_train_epoch{epoch}.pt"
)
log_filename = (
f"./{teston}_inference_attnunet_{domain}_{vertical}_{features}_ckpt_epoch_{epoch}.txt"
Expand All @@ -157,7 +159,7 @@
# Define test files
# ------- To test on one year of ERA5 data
test_files = []
test_years = np.array([2010])
test_years = np.array([2015])
test_month = args.month # int(sys.argv[4]) # np.arange(1,13)
logger.info(f"Inference for month {test_month}")
if teston == "era5":
Expand All @@ -174,7 +176,7 @@
)
elif vertical == "global" or vertical == "stratosphere_update":
if stencil == 1:
pre = idir + f"1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_"
pre = idir + f"1x1_inputfeatures_u_v_theta_w_uw_vw_gcp_era5_training_data_hourly_"
else:
pre = (
idir
Expand All @@ -183,7 +185,10 @@

for year in test_years:
for months in np.arange(test_month, test_month + 1):
test_files.append(f"{pre}{year}_constant_mu_sigma_scaling{str(months).zfill(2)}.nc")
# test_files.append(f"{pre}{year}_constant_mu_sigma_scaling{str(months).zfill(2)}.nc") # usual
test_files.append(
f"{pre}{year}_L93_constant_mu_sigma_scaling{str(months).zfill(2)}.nc"
) # L93

elif teston == "ifs":
if vertical == "stratosphere_only":
Expand Down Expand Up @@ -219,6 +224,7 @@
)

idim = testset.idim

odim = testset.odim
hdim = 4 * idim

Expand Down
32 changes: 19 additions & 13 deletions utils/dataloader_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,16 @@ def __init__(self, files, domain, vertical, stencil, manual_shuffle, features, r
if self.vertical == "global":
# 122 channels for each feature
if self.features == "uvtheta":
self.v = np.arange(0, 369) # for u,v,theta
# self.v = np.arange(0, 369) # for u,v,theta
self.v = np.arange(0, 282) # for L93
elif self.features == "uvthetaw":
self.v = np.arange(0, 491) # for u,v,theta,w
# self.v = np.arange(0, 551) # for u,v,theta,w
self.v = np.arange(0, 375) # for L93
elif self.features == "uvw":
self.v = np.concatenate(
(np.arange(0, 247), np.arange(369, 491)), axis=0
) # for u,v,w
# self.v = np.concatenate(
# (np.arange(0, 247), np.arange(369, 551)), axis=0
# ) # for u,v,w
self.v = np.concatenate((np.arange(0, 189), np.arange(282, 375)), axis=0) # for L93
self.w = np.arange(0, self.odim) # all vertical channels

elif self.vertical == "stratosphere_only":
Expand Down Expand Up @@ -86,7 +89,7 @@ def __init__(self, files, domain, vertical, stencil, manual_shuffle, features, r
self.v = np.arange(0, 491) # for u,v,theta,w
elif self.features == "uvw":
self.v = np.concatenate(
(np.arange(0, 247), np.arange(369, 491)), axis=0
(np.arange(0, 247), np.arange(369, 551)), axis=0
) # for u,v,w
self.w = np.concatenate(
(np.arange(0, 60), np.arange(122, 182)), axis=0
Expand Down Expand Up @@ -296,13 +299,16 @@ def __init__(self, files, domain, vertical, manual_shuffle, features, region="1a
if self.vertical == "global":
# 122 channels for each feature
if self.features == "uvtheta":
self.v = np.arange(3, 369) # for u,v,theta
self.v = np.arange(3, 282) # for L93
# self.v = np.arange(3, 369) # for u,v,theta
elif self.features == "uvthetaw":
self.v = np.arange(3, 491) # for u,v,theta,w
self.v = np.arange(3, 375) # for L93
# self.v = np.arange(3, 551) # for u,v,theta,w
elif self.features == "uvw":
self.v = np.concatenate(
(np.arange(3, 247), np.arange(369, 491)), axis=0
) # for u,v,w
self.v = np.concatenate((np.arange(3, 189), np.arange(282, 375)), axis=0) # for L93
# self.v = np.concatenate(
# (np.arange(3, 247), np.arange(369, 551)), axis=0
# ) # for u,v,w
self.w = np.arange(0, self.odim) # all vertical channels

elif self.vertical == "stratosphere_only":
Expand All @@ -328,10 +334,10 @@ def __init__(self, files, domain, vertical, manual_shuffle, features, region="1a
if self.features == "uvtheta":
self.v = np.arange(3, 369) # for u,v,theta
elif self.features == "uvthetaw":
self.v = np.arange(3, 491) # for u,v,theta,w
self.v = np.arange(3, 551) # for u,v,theta,w
elif self.features == "uvw":
self.v = np.concatenate(
(np.arange(3, 247), np.arange(369, 491)), axis=0
(np.arange(3, 247), np.arange(369, 551)), axis=0
) # for u,v,w
self.w = np.concatenate(
(np.arange(0, 60), np.arange(122, 182)), axis=0
Expand Down
8 changes: 5 additions & 3 deletions utils/function_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ def Inference_and_Save_ANN_CNN(
INP = INP.reshape(T[0] * T[1], T[2], T[3], T[4])
T = OUT.shape
OUT = OUT.reshape(T[0] * T[1], -1)
PRED = model(INP)

with torch.no_grad():
PRED = model(INP)

if is_script:
print("saving data...")
Expand All @@ -205,7 +207,7 @@ def Inference_and_Save_ANN_CNN(
xdata.to_netcdf(f"test-data/ann-cnn-{k}.nc")

print("scripting...")
script_to_torchscript(model, filename="nlgw_ann-cnn_gpu_scripted.pt")
script_to_torchscript(model, filename=f"nlgw_ann-cnn_{device}_scripted.pt")
print("complete")

S = PRED.shape
Expand Down Expand Up @@ -386,7 +388,7 @@ def Inference_and_Save_AttentionUNet(
model.eval()
count = 0
for i, (INP, OUT) in enumerate(testloader):
# print([i,count])
# print([i, count])
INP = INP.to(device)
S = OUT.shape
o_output[count : count + S[0], :, :, :] = OUT[
Expand Down
Loading