Early-Bird GCNs: Graph-Network Co-Optimization Towards More Efficient GCN Training and Inference via Drawing Early-Bird Lottery Tickets
Haoran You, Zhihan Lu, Zijian Zhou, Yonggan Fu, Yingyan Lin
Accepted by AAAI 2022. More Info: [ Paper | Appendix | Slide | Poster | Video | GitHub ]
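Installation (the PyTorch Geometric extension wheels below are built for PyTorch 1.4.0 with CUDA 10.1):

conda env create -f env.yaml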
pip install torch_geometric
pip uninstall torch-scatter
pip install torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
pip uninstall torch-sparse
pip install torch-sparse==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
pip uninstall torch-cluster
pip install torch-cluster==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
pip uninstall torch-spline-conv
pip install torch-spline-conv==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
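After the reinstalls, a quick way to confirm that every package can be found is a short check like the one below. `missing_packages` is a hypothetical helper, not part of the repository. It only looks packages up with `importlib.util.find_spec`, so it does not verify CUDA builds or version compatibility.

```python
import importlib.util

# Import names of the packages installed above (pip names use dashes,
# import names use underscores).
REQUIRED = ["torch", "torch_geometric", "torch_scatter", "torch_sparse",
            "torch_cluster", "torch_spline_conv"]


def missing_packages(names=REQUIRED):
    """Return the subset of top-level module `names` not found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All PyTorch Geometric dependencies found.")
```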

To pretrain, prune, and retrain separately:

Pretrain the GCN:

python3 pytorch_train.py --epochs 10 --dataset Cora
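For orientation, a GCN layer in the standard formulation propagates node features with the symmetrically normalized adjacency D^-1/2 (A + I) D^-1/2. The pure-Python sketch below computes that matrix for a small undirected graph and applies one propagation step. It illustrates the standard formulation only; `normalized_adjacency` and `propagate` are hypothetical names, and `pytorch_train.py` may rely on PyTorch Geometric's built-in layers instead.

```python
import math


def normalized_adjacency(num_nodes, edges):
    """Return D^-1/2 (A + I) D^-1/2 as a dense list of lists.

    `edges` is a list of undirected (u, v) pairs; self-loops are added here.
    """
    # Dense adjacency with self-loops (A + I).
    a = [[0.0] * num_nodes for _ in range(num_nodes)]
    for u, v in edges:
        a[u][v] = 1.0
        a[v][u] = 1.0
    for i in range(num_nodes):
        a[i][i] = 1.0
    # Degree of each node in A + I, then 1/sqrt(degree).
    inv_sqrt = [1.0 / math.sqrt(sum(row)) for row in a]
    # Scale entry (i, j) by 1/sqrt(d_i) * 1/sqrt(d_j).
    return [[inv_sqrt[i] * a[i][j] * inv_sqrt[j] for j in range(num_nodes)]
            for i in range(num_nodes)]


def propagate(a_hat, features):
    """One propagation step A_hat @ X (layer weights and nonlinearity omitted)."""
    n, f = len(features), len(features[0])
    return [[sum(a_hat[i][k] * features[k][c] for k in range(n)) for c in range(f)]
            for i in range(n)]


# Path graph 0 - 1 - 2 with a single scalar feature per node.
a_hat = normalized_adjacency(3, [(0, 1), (1, 2)])
print(propagate(a_hat, [[1.0], [1.0], [1.0]]))
```

The graph pruning ratio (`--ratio_graph`) acts on the edges of A, so pruning edges changes which entries of this matrix are nonzero.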

Prune the pretrained GCN with one of the three pruning scripts (--ratio_graph and --ratio_weight set the graph and weight pruning ratios):

python3 pytorch_prune_weight_iterate.py --ratio_graph 60 --ratio_weight 60
# or
python3 pytorch_prune_weight_cotrain.py --ratio_graph 60 --ratio_weight 60
# or
python3 pytorch_prune_weight_first.py --ratio_graph 60 --ratio_weight 60
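One common way to turn a pruning ratio into a mask is magnitude pruning: remove the given percentage of entries with the smallest absolute value. The sketch below assumes that `60` means "prune 60% of entries" and that magnitude is the scoring rule; `magnitude_mask` is a hypothetical helper, and the three scripts above may score weights or edges differently. The same function could be applied to edge scores for the graph ratio.

```python
def magnitude_mask(scores, ratio_percent):
    """Return a 0/1 mask that zeroes the `ratio_percent`% smallest-|score| entries.

    Hypothetical helper: assumes the ratio is a percentage of entries to prune
    and that entries with the smallest magnitude are removed first.
    """
    if not 0 <= ratio_percent <= 100:
        raise ValueError("ratio_percent must be in [0, 100]")
    n = len(scores)
    num_pruned = int(n * ratio_percent / 100)
    # Indices sorted by ascending magnitude; the first `num_pruned` are removed.
    order = sorted(range(n), key=lambda i: abs(scores[i]))
    mask = [1] * n
    for i in order[:num_pruned]:
        mask[i] = 0
    return mask


weights = [0.9, -0.05, 0.3, -0.7, 0.01]
print(magnitude_mask(weights, 60))  # [1, 0, 0, 1, 0]
```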

Retrain the pruned GCN to recover accuracy (the example loads the model saved under prune_weight_iterate/):

python3 pytorch_retrain_with_graph.py --load_path prune_weight_iterate/model.pth.tar
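Retraining a pruned model usually keeps the pruned entries fixed at zero while the surviving ones are updated. A minimal sketch of that idea, assuming plain SGD and a 0/1 mask such as one produced by magnitude pruning; `masked_sgd_step` is a hypothetical name, and `pytorch_retrain_with_graph.py` may enforce the mask differently (for example by masking gradients instead of weights).

```python
def masked_sgd_step(weights, grads, mask, lr=0.01):
    """One SGD update that keeps pruned entries (mask == 0) at zero."""
    # Update every entry, then multiply by the mask so pruned ones stay zero.
    return [(w - lr * g) * m for w, g, m in zip(weights, grads, mask)]


weights = [0.9, 0.0, 0.0, -0.7, 0.0]  # already pruned with the mask below
mask = [1, 0, 0, 1, 0]
grads = [0.5, 0.2, -0.1, -0.5, 0.3]
print(masked_sgd_step(weights, grads, mask, lr=0.1))
```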

Using calls such as
os.system("python3 "+"pytorch_train.py"+" --epochs "+str(1)+" --dataset "+str(args.dataset))
in Python, run_threshold_jointEB.py chains the steps above in a single file and stops automatically once a joint early-bird (jointEB) ticket is found:
python run_threshold_jointEB.py --times 100 --epochs 1 --dataset Cora --ratio_graph 20 --ratio_weight 50
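The stopping rule is not spelled out here. Below is a minimal sketch in the spirit of the Early-Bird ticket criterion: measure the normalized Hamming distance between pruning masks from consecutive rounds and declare a ticket once every distance in a sliding window is below a threshold. The class name, window size, threshold, and the rule that both the graph mask and the weight mask must settle are assumptions, not taken from run_threshold_jointEB.py. One plausible reading of `--times 100 --epochs 1` is up to 100 rounds of one-epoch training, with the detector consulted after each round.

```python
from collections import deque


def mask_distance(mask_a, mask_b):
    """Normalized Hamming distance between two equal-length 0/1 masks."""
    if len(mask_a) != len(mask_b):
        raise ValueError("masks must have the same length")
    return sum(a != b for a, b in zip(mask_a, mask_b)) / len(mask_a)


class JointEBDetector:
    """Flags a joint early-bird ticket once recent mask changes are all small.

    Assumed criterion: keep the last `window` distances between consecutive
    graph masks and weight masks; a ticket is found when the largest distance
    in both windows is below `threshold`.
    """

    def __init__(self, window=5, threshold=0.1):
        self.threshold = threshold
        self.graph_dists = deque(maxlen=window)
        self.weight_dists = deque(maxlen=window)
        self.prev = None  # (graph_mask, weight_mask) from the previous round

    def update(self, graph_mask, weight_mask):
        """Record this round's masks; return True when a joint ticket is detected."""
        if self.prev is not None:
            prev_graph, prev_weight = self.prev
            self.graph_dists.append(mask_distance(prev_graph, graph_mask))
            self.weight_dists.append(mask_distance(prev_weight, weight_mask))
        self.prev = (list(graph_mask), list(weight_mask))
        # Only decide once the window is full, so early noise cannot trigger it.
        full = len(self.graph_dists) == self.graph_dists.maxlen
        return (full
                and max(self.graph_dists) < self.threshold
                and max(self.weight_dists) < self.threshold)
```

Requiring a full window means a single stable round is not enough; the masks have to stay stable for `window` consecutive comparisons.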

Furthermore, a script runs all experiment settings (e.g., different graph and weight pruning ratios) automatically:

python test_jointEB_dist_traj.py
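A sweep of this kind can also be driven with `subprocess.run` and argument lists instead of concatenated `os.system` strings. The sketch below reuses the flags of run_threshold_jointEB.py shown earlier; the ratio grid and the helper names are hypothetical, since the settings inside test_jointEB_dist_traj.py are not listed here.

```python
import itertools
import subprocess
import sys

# Hypothetical grid: the real settings in test_jointEB_dist_traj.py may differ.
GRAPH_RATIOS = [20, 40, 60]
WEIGHT_RATIOS = [50, 70]


def build_commands(dataset="Cora", times=100, epochs=1):
    """Return one argument list per (graph ratio, weight ratio) setting."""
    commands = []
    for rg, rw in itertools.product(GRAPH_RATIOS, WEIGHT_RATIOS):
        commands.append([
            sys.executable, "run_threshold_jointEB.py",
            "--times", str(times), "--epochs", str(epochs),
            "--dataset", dataset,
            "--ratio_graph", str(rg), "--ratio_weight", str(rw),
        ])
    return commands


def run_all(dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in build_commands():
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the sweep at the first failing setting.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```

Passing a list avoids shell quoting issues, and `sys.executable` reuses the interpreter of the active conda environment.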
More details are coming soon.