Skip to content

Commit 806abf8

Browse files
authored
Added Hydra support as an option in ArgParse UI (#302)
* Introducing Hydra in Code-generator * Made some bug fixes and CI changes * Add changes to colab functions for Hydra * fix bug in colab.js * fix bugs in utils.py * Removed unnecessay code in setup_config for hydra * changed colab function behaviour when argparser = hydra * fixed colab output * changed the argparser due to undefined output * final change, removing debug output * Fix bugs for Hydra output-dir and overiding the config files
1 parent df76bf1 commit 806abf8

File tree

20 files changed

+199
-17
lines changed

20 files changed

+199
-17
lines changed

.github/workflows/ci.yml

+4
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,17 @@ jobs:
8484
- run: sh ./scripts/run_tests.sh unzip
8585
- run: pnpm dist_lint ${{ matrix.template }} argparse
8686
- run: pnpm dist_lint ${{ matrix.template }} fire
87+
- run: pnpm dist_lint ${{ matrix.template }} hydra
8788

8889
- name: 'Run ${{ matrix.template }} ${{ matrix.test }}'
8990
run: sh ./scripts/run_tests.sh ${{ matrix.test }} ${{ matrix.template }} argparse
9091

9192
- name: 'Run ${{ matrix.template }} ${{ matrix.test }} - Python Fire'
9293
run: sh ./scripts/run_tests.sh ${{ matrix.test }} ${{ matrix.template }} fire
9394

95+
- name: 'Run ${{ matrix.template }} ${{ matrix.test }} - Hydra'
96+
run: sh ./scripts/run_tests.sh ${{ matrix.test }} ${{ matrix.template }} hydra
97+
9498
lint:
9599
runs-on: ubuntu-latest
96100
steps:

__tests__/text-classification.spec.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ afterEach(async () => {
2424
await context.close()
2525
})
2626

27-
const parser = ['argparse', 'fire']
27+
const parser = ['argparse', 'fire', 'hydra']
2828
for (const name of parser) {
2929
test(`text-classification simple ${name}`, async () => {
3030
await page.selectOption('select', 'template-text-classification')

__tests__/vision-classification.spec.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ afterEach(async () => {
2424
await context.close()
2525
})
2626

27-
const parser = ['argparse', 'fire']
27+
const parser = ['argparse', 'fire', 'hydra']
2828
for (const name of parser) {
2929
test(`vision-classification simple ${name}`, async () => {
3030
await page.selectOption('select', 'template-vision-classification')

__tests__/vision-dcgan.spec.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ afterEach(async () => {
2424
await context.close()
2525
})
2626

27-
const parser = ['argparse', 'fire']
27+
const parser = ['argparse', 'fire', 'hydra']
2828
for (const name of parser) {
2929
test(`vision-dcgan simple ${name}`, async () => {
3030
await page.selectOption('select', 'template-vision-dcgan')

__tests__/vision-segmentation.spec.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ afterEach(async () => {
2424
await context.close()
2525
})
2626

27-
const parser = ['argparse', 'fire']
27+
const parser = ['argparse', 'fire', 'hydra']
2828
for (const name of parser) {
2929
test(`vision-segmentation simple ${name}`, async () => {
3030
await page.selectOption('select', 'template-vision-segmentation')

functions/colab.js

+8-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,14 @@ exports.handler = async function (event, _) {
4848
'!pip install -r requirements.txt'
4949
]
5050

51-
const execution_nb_commands = ['!python main.py config.yaml']
51+
const argparser = data.argparser
52+
const execution_nb_commands = [
53+
`!python main.py ${
54+
argparser === 'hydra'
55+
? '#--config-dir=/content/ --config-name=config.yaml'
56+
: 'config.yaml'
57+
}`
58+
]
5259

5360
let nb_cells = [
5461
create_nb_cell(md_cell, 'markdown'),

scripts/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ transformers
99
datasets
1010
tensorboard
1111
fire
12+
hydra-core
1213
omegaconf

scripts/run_code_style.sh

+5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ if [ $1 == "dist_lint" ]; then
1717
ls ./dist-tests/$TEMP-all-fire/main.py
1818
fi
1919

20+
# for hydra
21+
if [ "$ARGPARSE" == "hydra" ]; then
22+
ls ./dist-tests/$TEMP-all-hydra/main.py
23+
fi
24+
2025
# Comment dist-tests in .gitignore to make black running on ./dist-tests folder
2126
sed -i "s/dist-tests/# dist-tests/g" .gitignore
2227

scripts/run_tests.sh

+49-1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,46 @@ run_spawn_fire() {
9292
done
9393
}
9494

95+
# Hydra test functions
96+
97+
run_simple_hydra() {
98+
for dir in $(find ./dist-tests/$1-simple-hydra -type d)
99+
do
100+
cd $dir
101+
python main.py --config-dir=../../src/tests/ci-configs --config-name=$1-simple.yaml
102+
cd $CWD
103+
done
104+
}
105+
106+
run_all_hydra() {
107+
for dir in $(find ./dist-tests/$1-all-hydra -type d)
108+
do
109+
cd $dir
110+
pytest -vra --color=yes --tb=short test_*.py
111+
python main.py --config-dir=../../src/tests/ci-configs --config-name=$1-all.yaml
112+
cd $CWD
113+
done
114+
}
115+
116+
run_launch_hydra() {
117+
for dir in $(find ./dist-tests/$1-launch-hydra -type d)
118+
do
119+
cd $dir
120+
torchrun --nproc_per_node 2 main.py --config-dir=../../src/tests/ci-configs --config-name=$1-launch.yaml ++backend='gloo'
121+
cd $CWD
122+
done
123+
}
124+
125+
run_spawn_hydra() {
126+
for dir in $(find ./dist-tests/$1-spawn-hydra -type d)
127+
do
128+
cd $dir
129+
python main.py --config-dir=../../src/tests/ci-configs --config-name=$1-spawn.yaml ++backend='gloo'
130+
cd $CWD
131+
done
132+
}
133+
134+
95135

96136
if [ $1 = "unzip" ]; then
97137
unzip_all
@@ -100,23 +140,31 @@ elif [ $1 = "simple" ]; then
100140
run_simple $2
101141
elif [ $3 = "fire" ]; then
102142
run_simple_fire $2
103-
fi
143+
elif [ $3 = "hydra" ]; then
144+
run_simple_hydra $2
145+
fi
104146
elif [ $1 = "all" ]; then
105147
if [ $3 = "argparse" ]; then
106148
run_all $2
107149
elif [ $3 = "fire" ]; then
108150
run_all_fire $2
151+
elif [ $3 = "hydra" ]; then
152+
run_all_hydra $2
109153
fi
110154
elif [ $1 = "launch" ]; then
111155
if [ $3 = "argparse" ]; then
112156
run_launch $2
113157
elif [ $3 = "fire" ]; then
114158
run_launch_fire $2
159+
elif [ $3 = "hydra" ]; then
160+
run_launch_hydra $2
115161
fi
116162
elif [ $1 = "spawn" ]; then
117163
if [ $3 = "argparse" ]; then
118164
run_spawn $2
119165
elif [ $3 = "fire" ]; then
120166
run_spawn_fire $2
167+
elif [ $3 = "hydra" ]; then
168+
run_spawn_hydra $2
121169
fi
122170
fi

src/components/NavColab.vue

+2-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ export default {
9393
body: JSON.stringify({
9494
code: store.code,
9595
template: store.config.template,
96-
config: store.config
96+
config: store.config,
97+
argparser: store.config.argparser
9798
})
9899
})
99100
// response body is plain text

src/components/TabTraining.vue

+14-2
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,25 @@
1010

1111
<ul>
1212
<li>
13-
<a href="https://docs.python.org/3/library/argparse.html">Argparse</a> -
14-
is a python built-in tool to handle command-line arguments
13+
<a href="https://docs.python.org/3/library/argparse.html" id="arg"
14+
>Argparse</a
15+
>
16+
- is a python built-in tool to handle command-line arguments
1517
</li>
1618
<li>
1719
<a
1820
href="https://github.com/google/python-fire/blob/master/docs/guide.md"
21+
id="arg"
1922
>Python Fire</a
2023
>
2124
- transforms Python functions into user-friendly command-line tools,
2225
ideal for DL experimentation.
2326
</li>
27+
<li>
28+
<a href="https://hydra.cc" id="arg">Hydra</a>
29+
- Simplifying deep learning experiments through flexible configuration
30+
management
31+
</li>
2432
</ul>
2533
<FormSelect
2634
:label="argparser.description"
@@ -134,4 +142,8 @@ function saveDistributed(key, value) {
134142
.training {
135143
margin-bottom: 0;
136144
}
145+
146+
#arg {
147+
font-weight: bold;
148+
}
137149
</style>

src/metadata/metadata.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"name": "argparser",
55
"type": "array",
66
"description": "Select the argument parser for training",
7-
"options": ["argparse", "fire"],
7+
"options": ["argparse", "fire", "hydra"],
88
"default": "argparse"
99
},
1010
"deterministic": {

src/templates/template-common/README.md

+39-7
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,16 @@ torchrun \
6868
--node_rank 0 \
6969
--master_addr #:::= it.master_addr :::# \
7070
--master_port #:::= it.master_port :::# \
71-
main.py config.yaml --backend #:::= it.backend :::# \
7271
#::: if ((it.argparser == 'fire')) { :::#
72+
main.py config.yaml --backend #:::= it.backend :::# \
7373
[--override_arg=value]
74-
74+
#::: } else if ((it.argparser == 'hydra')){ :::#
75+
main.py --config-dir=[dir-path] \
76+
--config-name=[config-name] ++backend= #:::= it.backend :::# override_arg=[value]
77+
#::: } else { :::#
78+
main.py config.yaml --backend #:::= it.backend :::#
7579
#::: } :::#
80+
7681
```
7782

7883
- Execute on worker nodes
@@ -84,10 +89,15 @@ torchrun \
8489
--node_rank <node_rank> \
8590
--master_addr #:::= it.master_addr :::# \
8691
--master_port #:::= it.master_port :::# \
87-
main.py config.yaml --backend #:::= it.backend :::# \
8892
#::: if ((it.argparser == 'fire')) { :::#
93+
main.py config.yaml --backend #:::= it.backend :::# \
8994
[--override_arg=value]
90-
95+
#::: } else if ((it.argparser == 'hydra')){ :::#
96+
main.py --config-dir=[dir-path] \
97+
--config-name=[config-name] ++backend= #:::= it.backend :::# \
98+
override_arg=[value]
99+
#::: } else { :::#
100+
main.py config.yaml --backend #:::= it.backend :::#
91101
#::: } :::#
92102
```
93103

@@ -98,10 +108,15 @@ torchrun \
98108
```sh
99109
torchrun \
100110
--nproc_per_node #:::= it.nproc_per_node :::# \
101-
main.py config.yaml --backend #:::= it.backend :::# \
102111
#::: if ((it.argparser == 'fire')) { :::#
112+
main.py config.yaml --backend #:::= it.backend :::# \
103113
[--override_arg=value]
104-
114+
#::: } else if ((it.argparser == 'hydra')){ :::#
115+
main.py --config-dir=[dir-path] \
116+
--config-name=[config-name] ++backend= #:::= it.backend :::# \
117+
override_arg=[value]
118+
#::: } else { :::#
119+
main.py config.yaml --backend #:::= it.backend :::#
105120
#::: } :::#
106121
```
107122

@@ -128,7 +143,9 @@ master_port: #:::= it.master_port :::#
128143
129144
```sh
130145
#::: if ((it.argparser == 'fire')) { :::#
131-
python main.py config.yaml --backend #:::= it.backend :::# [--override_arg=value]
146+
python main.py config.yaml --backend #:::= it.backend :::# --override_arg=[value]
147+
#::: } else if ((it.argparser == 'hydra')){ :::#
148+
python main.py --config-dir=[dir-path] --config-name=[config-name] ++backend= #:::= it.backend :::# override_arg=[value]
132149
#::: } else { :::#
133150
python main.py config.yaml --backend #:::= it.backend :::#
134151
#::: } :::#
@@ -149,6 +166,8 @@ master_port: #:::= it.master_port :::#
149166

150167
#::: if ((it.argparser == 'fire')) { :::#
151168
python main.py config.yaml --backend #:::= it.backend :::# [--override_arg=value]
169+
#::: } else if ((it.argparser == 'hydra')){ :::#
170+
python main.py --config-dir=[dir-path] --config-name=[config-name] ++backend=#:::= it.backend :::# override_arg=[value]
152171
#::: } else { :::#
153172
python main.py config.yaml --backend #:::= it.backend :::#
154173
#::: } :::#
@@ -166,6 +185,10 @@ nproc_per_node: #:::= it.nproc_per_node :::#
166185
```sh
167186
#::: if ((it.argparser == 'fire')) { :::#
168187
python main.py config.yaml --backend #:::= it.backend :::# [--override_arg=value]
188+
189+
#::: } else if ((it.argparser == 'hydra')) { :::#
190+
python main.py --config-dir=[dir-path] --config-name=[config-name] override_arg=[value]
191+
169192
#::: } else { :::#
170193
python main.py config.yaml --backend #:::= it.backend :::#
171194
#::: } :::#
@@ -182,8 +205,13 @@ python main.py config.yaml --backend #:::= it.backend :::#
182205
```sh
183206
#::: if ((it.argparser == 'fire')) { :::#
184207
python main.py config.yaml [--override_arg=value]
208+
209+
#::: } else if ((it.argparser == 'hydra')) { :::#
210+
python main.py --config-dir=[dir-path] --config-name=[config-name] override_arg=[value]
211+
185212
#::: } else { :::#
186213
python main.py config.yaml
214+
187215
#::: } :::#
188216
```
189217

@@ -193,4 +221,8 @@ python main.py config.yaml
193221

194222
Note: We use Python-Fire as the default argument parser here. For more information refer the [docs](https://github.com/google/python-fire/blob/master/docs/guide.md)
195223

224+
#::: } else if ((it.argparser == 'hydra')) { :::#
225+
226+
Note: We use Hydra with [OmegaConfig](https://omegaconf.readthedocs.io/en/2.3_branch/) as the default argument parser here. For more information check the [Hydra docs](https://hydra.cc)
227+
196228
#::: } :::#

src/templates/template-common/main.py

+15
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,17 @@ def main(config_path, backend=None, **kwargs):
66
config = setup_config(config_path, backend, **kwargs)
77

88

9+
#::: } else if ((it.argparser == 'hydra')) { :::#
10+
@hydra.main(version_base=None, config_path=".", config_name="config")
11+
def main(cfg: DictConfig):
12+
config = setup_config(cfg)
13+
14+
915
#::: } else { :::#
1016
def main():
1117
config = setup_config()
1218
#::: } :::#
19+
1320
#::: if (it.dist === 'spawn') { :::#
1421
#::: if (it.nproc_per_node && it.nnodes > 1 && it.master_addr && it.master_port) { :::#
1522
spawn_kwargs = {
@@ -33,6 +40,14 @@ def main():
3340
if __name__ == "__main__":
3441
#::: if ((it.argparser == 'fire')) { :::#
3542
fire.Fire(main)
43+
44+
#::: } else if ((it.argparser == 'hydra')){ :::#
45+
sys.argv.append("hydra.run.dir=.")
46+
sys.argv.append("hydra.output_subdir=null")
47+
sys.argv.append("hydra/job_logging=stdout")
48+
main()
49+
3650
#::: } else { :::#
3751
main()
52+
3853
#::: } :::#

src/templates/template-common/requirements.txt

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ omegaconf
66
#::: if ((it.argparser == 'fire')) { :::#
77
#:::= it.argparser :::#
88

9+
#::: } else if ((it.argparser == 'hydra')) { :::#
10+
#:::= it.argparser + '-core' :::#
11+
912
#::: } :::#
1013

1114
#::: if (['neptune', 'polyaxon'].includes(it.logger)) { :::#

0 commit comments

Comments
 (0)