[Hackathon 8th No.11] DrivAerNet++ paper reproduction #1062

Open · wants to merge 24 commits into base: develop (changes from 21 commits shown)
8c3438a
Update ReduceOnPlateau lr_scheduler.md
LilaKen Jan 10, 2025
4ca8511
Merge branch 'PaddlePaddle:develop' into develop
LilaKen Jan 12, 2025
a3768f7
support drivaernetplusplus
LilaKen Jan 12, 2025
2a6d9dd
amend older error arch.md
LilaKen Jan 12, 2025
b9caca5
amend some question
LilaKen Jan 14, 2025
2d2d13f
update regpintnet.py
LilaKen Jan 14, 2025
e798771
amend regpointnet and drivaernetplusplus.py
LilaKen Jan 15, 2025
7c98038
update code and format markdown
LilaKen Jan 25, 2025
3b872d2
update code and format markdown
LilaKen Jan 25, 2025
0e0acc0
update code and format markdown
LilaKen Jan 25, 2025
657b0ab
update md
LilaKen Jan 25, 2025
79facfa
update code and format markdown
LilaKen Jan 25, 2025
7125eaa
Merge branch 'develop' into DrivAerNet++
LilaKen Jan 25, 2025
4714ac8
Update drivaernetplusplus_dataset.py
LilaKen Jan 26, 2025
beb8e38
Update examples/drivaernetplusplus/drivaernetplusplus.py
LilaKen Jan 27, 2025
bb45ad0
Update examples/drivaernetplusplus/drivaernetplusplus.py
LilaKen Jan 27, 2025
917ea4e
Update examples/drivaernetplusplus/drivaernetplusplus.py
LilaKen Jan 27, 2025
9c0cc68
move data_augmentation to load dataset file
LilaKen Jan 27, 2025
a47e529
sysc the apply_augmentation in markdown file
LilaKen Jan 27, 2025
a3f6a67
sysc the apply_augmentation in markdown file
LilaKen Jan 27, 2025
18c0be5
pre-commit file drivaernetplusplus.py
LilaKen Jan 27, 2025
33c73bd
Update docs/zh/examples/drivaernetplusplus.md
LilaKen Feb 10, 2025
7a2a852
Merge branch 'develop' into DrivAerNet++
LilaKen Feb 10, 2025
40c61d9
update drivaernetplusplus.md
LilaKen Feb 10, 2025
1 change: 1 addition & 0 deletions docs/zh/api/arch.md
@@ -38,6 +38,7 @@
- LNO
- TGCN
- RegDGCNN
- RegPointNet
- IFMMLP
show_root_heading: true
heading_level: 3
1 change: 1 addition & 0 deletions docs/zh/api/data/dataset.md
@@ -33,5 +33,6 @@
- CGCNNDataset
- PEMSDataset
- DrivAerNetDataset
- DrivAerNetPlusPlusDataset
- IFMMoeDataset
show_root_heading: true
832 changes: 832 additions & 0 deletions docs/zh/examples/drivaernetplusplus.md
Collaborator

Could the whole document be formatted with the VS Code Markdown plugin?

Contributor Author

Formatted.


77 changes: 77 additions & 0 deletions examples/drivaernetplusplus/conf/drivaernetplusplus.yaml
@@ -0,0 +1,77 @@
defaults:
- ppsci_default
- TRAIN: train_default
- TRAIN/ema: ema_default
- TRAIN/swa: swa_default
- EVAL: eval_default
- INFER: infer_default
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
- _self_

hydra:
run:
dir: outputs_drivaernetplusplus/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
job:
name: ${mode}
chdir: false
callbacks:
init_callback:
_target_: ppsci.utils.callbacks.InitCallback
sweep:
dir: ${hydra.run.dir}
subdir: ./

# general settings
mode: eval
seed: 1
output_dir: ${hydra:run.dir}
log_freq: 100

# model settings
MODEL:
input_keys: ["vertices"]
output_keys: ["cd_value"]
weight_keys: ["weight_keys"]
dropout: 0.0
emb_dims: 1024
channels: [6, 64, 128, 256, 512, 1024]
linear_sizes: [128, 64, 32, 16]
k: 40
output_channels: 1

# training settings
TRAIN:
iters_per_epoch: 5399
epochs: 200
num_points: 100000
num_workers: 32
eval_during_train: True
train_ids_file: "train_design_ids.txt"
eval_ids_file: "val_design_ids.txt"
batch_size: 32
scheduler:
mode: "min"
patience: 20
factor: 0.1
verbose: True

# evaluation settings
EVAL:
num_points: 100000
batch_size: 1
pretrained_model_path: "https://dataset.bj.bcebos.com/PaddleScience/DNNFluid-Car/DrivAer%2B%2B/DragPrediction_DrivAerNet_PointNet_r2_batchsize16_200epochs_100kpoints_tsne_NeurIPS_best_model.pdparams"
eval_with_no_grad: True
ids_file: "test_design_ids.txt"
num_workers: 8

# optimizer settings
optimizer:
weight_decay: 0.0001
lr: 0.001
optimizer: 'adam'

ARGS:
# dataset settings
dataset_path: 'data/DrivAerNetPlusPlus_Processed_Point_Clouds_100k_paddle'
aero_coeff: 'data/DrivAerNetPlusPlus_Drag_8k.csv'
subset_dir: 'data/subset_dir'
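
The TRAIN.scheduler block above configures a ReduceOnPlateau schedule: a minimized metric is monitored, and the learning rate is multiplied by factor 0.1 once the metric has not improved for 20 epochs. Below is a minimal sketch of that behaviour using plain Paddle's paddle.optimizer.lr.ReduceOnPlateau, shown only to illustrate the settings; the dummy metric values and the use of the raw Paddle scheduler (rather than the ppsci wrapper used in the example script) are assumptions for illustration.

import paddle

# Mirrors TRAIN.scheduler: the lr is multiplied by `factor` once the monitored
# value has not improved for `patience` consecutive calls to step().
scheduler = paddle.optimizer.lr.ReduceOnPlateau(
    learning_rate=0.001, mode="min", factor=0.1, patience=20, verbose=True
)
for epoch, val_mse in enumerate([0.30, 0.25, 0.25, 0.25]):  # dummy metric values
    # In the example script this value comes from solver.cur_metric.
    scheduler.step(val_mse)
    print(epoch, scheduler.get_lr())
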
200 changes: 200 additions & 0 deletions examples/drivaernetplusplus/drivaernetplusplus.py
@@ -0,0 +1,200 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import warnings
from functools import partial

import hydra
import paddle
from omegaconf import DictConfig

import ppsci


def train(cfg: DictConfig):
# set model
model = ppsci.arch.RegPointNet(
input_keys=cfg.MODEL.input_keys,
label_keys=cfg.MODEL.output_keys,
weight_keys=cfg.MODEL.weight_keys,
args=cfg.MODEL,
)

train_dataloader_cfg = {
"dataset": {
"name": "DrivAerNetPlusPlusDataset",
"root_dir": cfg.ARGS.dataset_path,
"input_keys": cfg.MODEL.input_keys,
"label_keys": cfg.MODEL.output_keys,
"weight_keys": cfg.MODEL.weight_keys,
"subset_dir": cfg.ARGS.subset_dir,
"ids_file": cfg.TRAIN.train_ids_file,
"csv_file": cfg.ARGS.aero_coeff,
"num_points": cfg.TRAIN.num_points,
},
"batch_size": cfg.TRAIN.batch_size,
"num_workers": cfg.TRAIN.num_workers,
}

drivaernetplusplus_constraint = ppsci.constraint.SupervisedConstraint(
train_dataloader_cfg,
ppsci.loss.MSELoss("mean"),
name="DrivAerNetplusplus_constraint",
)

constraint = {drivaernetplusplus_constraint.name: drivaernetplusplus_constraint}

valid_dataloader_cfg = {
"dataset": {
"name": "DrivAerNetPlusPlusDataset",
"root_dir": cfg.ARGS.dataset_path,
"input_keys": cfg.MODEL.input_keys,
"label_keys": cfg.MODEL.output_keys,
"weight_keys": cfg.MODEL.weight_keys,
"subset_dir": cfg.ARGS.subset_dir,
"ids_file": cfg.TRAIN.eval_ids_file,
"csv_file": cfg.ARGS.aero_coeff,
"num_points": cfg.TRAIN.num_points,
},
"batch_size": cfg.TRAIN.batch_size,
"num_workers": cfg.TRAIN.num_workers,
}

drivaernetplusplus_valid = ppsci.validate.SupervisedValidator(
valid_dataloader_cfg,
loss=ppsci.loss.MSELoss("mean"),
metric={"MSE": ppsci.metric.MSE()},
name="DrivAerNetplusplus_valid",
)

validator = {drivaernetplusplus_valid.name: drivaernetplusplus_valid}

# set optimizer
lr_scheduler = ppsci.optimizer.lr_scheduler.ReduceOnPlateau(
epochs=cfg.TRAIN.epochs,
iters_per_epoch=(
cfg.TRAIN.iters_per_epoch
// (paddle.distributed.get_world_size() * cfg.TRAIN.batch_size)
+ 1
),
learning_rate=cfg.optimizer.lr,
mode=cfg.TRAIN.scheduler.mode,
patience=cfg.TRAIN.scheduler.patience,
factor=cfg.TRAIN.scheduler.factor,
verbose=cfg.TRAIN.scheduler.verbose,
)()

optimizer = (
ppsci.optimizer.Adam(lr_scheduler, weight_decay=cfg.optimizer.weight_decay)(
model
)
if cfg.optimizer.optimizer == "adam"
else ppsci.optimizer.SGD(lr_scheduler, weight_decay=cfg.optimizer.weight_decay)(
model
)
)

# initialize solver
solver = ppsci.solver.Solver(
model=model,
iters_per_epoch=(
cfg.TRAIN.iters_per_epoch
// (paddle.distributed.get_world_size() * cfg.TRAIN.batch_size)
+ 1
),
constraint=constraint,
output_dir=cfg.output_dir,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epochs=cfg.TRAIN.epochs,
validator=validator,
eval_during_train=cfg.TRAIN.eval_during_train,
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)

lr_scheduler.step = partial(lr_scheduler.step, metrics=solver.cur_metric)
solver.lr_scheduler = lr_scheduler

# train model
solver.train()

solver.eval()


def evaluate(cfg: DictConfig):
# set model
model = ppsci.arch.RegPointNet(
input_keys=cfg.MODEL.input_keys,
label_keys=cfg.MODEL.output_keys,
weight_keys=cfg.MODEL.weight_keys,
args=cfg.MODEL,
)

valid_dataloader_cfg = {
"dataset": {
"name": "DrivAerNetPlusPlusDataset",
"root_dir": cfg.ARGS.dataset_path,
"input_keys": cfg.MODEL.input_keys,
"label_keys": cfg.MODEL.output_keys,
"weight_keys": cfg.MODEL.weight_keys,
"subset_dir": cfg.ARGS.subset_dir,
"ids_file": cfg.EVAL.ids_file,
"csv_file": cfg.ARGS.aero_coeff,
"num_points": cfg.EVAL.num_points,
},
"batch_size": cfg.EVAL.batch_size,
"num_workers": cfg.EVAL.num_workers,
}

drivaernetplusplus_valid = ppsci.validate.SupervisedValidator(
valid_dataloader_cfg,
loss=ppsci.loss.MSELoss("mean"),
metric={
"MSE": ppsci.metric.MSE(),
"MAE": ppsci.metric.MAE(),
"Max AE": ppsci.metric.MaxAE(),
"R²": ppsci.metric.R2Score(),
},
name="DrivAerNetPlusPlus_valid",
)

validator = {drivaernetplusplus_valid.name: drivaernetplusplus_valid}

solver = ppsci.solver.Solver(
model=model,
validator=validator,
pretrained_model_path=cfg.EVAL.pretrained_model_path,
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)

# evaluate model
solver.eval()


@hydra.main(
version_base=None, config_path="./conf", config_name="drivaernetplusplus.yaml"
)
def main(cfg: DictConfig):
warnings.filterwarnings("ignore")
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
else:
raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
main()
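
For reference, the Hydra entry point above can also be driven programmatically by injecting command-line style overrides before calling main(). This is only a usage sketch; the chosen override values and the presence of the dataset files referenced under ARGS are assumptions.

import sys

# hydra.main() parses sys.argv, so overrides can be supplied the same way
# as on the command line (e.g. `python drivaernetplusplus.py mode=eval`).
sys.argv = [
    "drivaernetplusplus.py",
    "mode=eval",          # selects the evaluate() branch in main()
    "EVAL.batch_size=1",  # any key from the YAML config can be overridden
]
main()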