Commit 7ef78f0

Add isort precommit

Added isort to the pre-commit hooks as well.

1 parent 6af118b commit 7ef78f0

13 files changed: +111 -43 lines

.pre-commit-config.yaml (+5)

@@ -4,3 +4,8 @@ repos:
     hooks:
       - id: black
         language_version: python3
+  - repo: https://github.com/PyCQA/isort
+    rev: 5.6.4
+    hooks:
+      - id: isort
+        language_version: python3
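With this hook in place, isort runs on the staged Python files at every commit, alongside black. As a minimal sketch of the rule being enforced (assuming the hooks are activated with `pre-commit install`; `pre-commit run --all-files` applies them to the whole tree, which is presumably what produced the import reshuffling in the files below), isort 5.x also exposes a Python API. The sample imports here are hypothetical:

import isort

# Deliberately unsorted: stdlib and third-party imports mixed together.
messy = "import sys\nimport os\n\nimport numpy as np\nimport nni\n"

# isort.code() returns the sorted source: an alphabetized stdlib block
# (os, sys) followed by an alphabetized third-party block (nni, numpy).
print(isort.code(messy))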

app/api/router/predict.py (+3 -3)

@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 from typing import List
 
-
 import numpy as np
 from fastapi import APIRouter
 from starlette.concurrency import run_in_threadpool
@@ -12,12 +11,13 @@
 from app.utils import ScikitLearnModel, my_model
 from logger import L
 
-
 models.Base.metadata.create_all(bind=engine)
 
 
 router = APIRouter(
-    prefix="/predict", tags=["predict"], responses={404: {"description": "Not Found"}}
+    prefix="/predict",
+    tags=["predict"],
+    responses={404: {"description": "Not Found"}},
 )

app/api/router/train.py (+20 -6)

@@ -3,14 +3,21 @@
 import re
 import subprocess
 
-
 from fastapi import APIRouter
 
-from app.utils import NniWatcher, ExperimentOwl, base_dir, get_free_port, write_yml
+from app.utils import (
+    ExperimentOwl,
+    NniWatcher,
+    base_dir,
+    get_free_port,
+    write_yml,
+)
 from logger import L
 
 router = APIRouter(
-    prefix="/train", tags=["train"], responses={404: {"description": "Not Found"}}
+    prefix="/train",
+    tags=["train"],
+    responses={404: {"description": "Not Found"}},
 )
@@ -44,14 +51,18 @@ def train_insurance(
     try:
         write_yml(path, experiment_name, experimenter, model_name, version)
         nni_create_result = subprocess.getoutput(
-            "nnictl create --port {} --config {}/{}.yml".format(PORT, path, model_name)
+            "nnictl create --port {} --config {}/{}.yml".format(
+                PORT, path, model_name
+            )
         )
         sucs_msg = "Successfully started experiment!"
 
         if sucs_msg in nni_create_result:
             p = re.compile(r"The experiment id is ([a-zA-Z0-9]+)\n")
             expr_id = p.findall(nni_create_result)[0]
-            nni_watcher = NniWatcher(expr_id, experiment_name, experimenter, version)
+            nni_watcher = NniWatcher(
+                expr_id, experiment_name, experimenter, version
+            )
             m_process = multiprocessing.Process(target=nni_watcher.excute)
             m_process.start()
@@ -68,6 +79,7 @@ def train_atmos(expr_name: str):
     """
     API for running training related to the temperature time series.
 
+
     Args:
         expr_name(str): Name of the experiment for NNI to run. Based on this parameter, the API looks up project_dir/experiments/[expr_name] and runs NNI with its config.yml.
@@ -83,7 +95,9 @@ def train_atmos(expr_name: str):
 
     try:
         nni_create_result = subprocess.getoutput(
-            "nnictl create --port {} --config {}/config.yml".format(nni_port, expr_path)
+            "nnictl create --port {} --config {}/config.yml".format(
+                nni_port, expr_path
+            )
         )
         sucs_msg = "Successfully started experiment!"

app/database.py (+1 -2)

@@ -1,10 +1,9 @@
 import os
 
-
 from dotenv import load_dotenv
 from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
 from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker
 
 load_dotenv(verbose=True)

app/models.py (+10 -3)

@@ -1,10 +1,17 @@
 # -*- coding: utf-8 -*-
 import datetime
 
-
-from sqlalchemy import Column, Integer, String, FLOAT, DateTime, ForeignKey, LargeBinary
-from sqlalchemy.sql.functions import now
+from sqlalchemy import (
+    FLOAT,
+    Column,
+    DateTime,
+    ForeignKey,
+    Integer,
+    LargeBinary,
+    String,
+)
 from sqlalchemy.orm import relationship
+from sqlalchemy.sql.functions import now
 
 from app.database import Base

app/query.py (+1 -1)

@@ -99,4 +99,4 @@
 WHERE NOT EXISTS (SELECT 1
                   FROM atmos_model_metadata as amm
                   WHERE amm.model_name = '{mn}');
-"""
+"""

app/utils.py (+34 -11)

@@ -11,7 +11,6 @@
 import time
 import zipfile
 
-
 import tensorflow as tf
 import yaml
@@ -239,7 +238,9 @@ def watch_process(self):
         if self.is_kill:
             while True:
                 self.get_running_experiment()
-                if self._running_experiment and ("DONE" in self._running_experiment[0]):
+                if self._running_experiment and (
+                    "DONE" in self._running_experiment[0]
+                ):
                     _stop_expr = subprocess.getoutput(
                         "nnictl stop {}".format(self.experiment_id)
                     )
@@ -284,7 +285,9 @@ def model_final_update(self):
 
         if saved_result is None:
             engine.execute(
-                INSERT_MODEL_CORE.format(final_result.model_name, pickled_model)
+                INSERT_MODEL_CORE.format(
+                    final_result.model_name, pickled_model
+                )
             )
             engine.execute(
                 INSERT_MODEL_METADATA.format(
@@ -303,7 +306,9 @@ def model_final_update(self):
             > final_result[self.evaluation_criteria]
         ):
             engine.execute(
-                UPDATE_MODEL_CORE.format(pickled_model, saved_result.model_name)
+                UPDATE_MODEL_CORE.format(
+                    pickled_model, saved_result.model_name
+                )
             )
             engine.execute(
                 UPDATE_MODEL_METADATA.format(
@@ -315,7 +320,9 @@ def model_final_update(self):
                 )
             )
 
-        engine.execute(DELETE_ALL_EXPERIMENTS_BY_EXPR_NAME.format(self.experiment_name))
+        engine.execute(
+            DELETE_ALL_EXPERIMENTS_BY_EXPR_NAME.format(self.experiment_name)
+        )
 
 
 def zip_model(model_path):
@@ -401,7 +408,12 @@ class ExperimentOwl:
     """
 
     def __init__(
-        self, experiment_id, experiment_name, experiment_path, mfile_manage=True, time=5
+        self,
+        experiment_id,
+        experiment_name,
+        experiment_path,
+        mfile_manage=True,
+        time=5,
     ):
         self.__minute = 60
         self.time = time * self.__minute
@@ -434,7 +446,9 @@ def main(self):
         expr_list = subprocess.getoutput("nnictl experiment list")
 
         running_expr = [
-            expr for expr in expr_list.split("\n") if self.experiment_id in expr
+            expr
+            for expr in expr_list.split("\n")
+            if self.experiment_id in expr
         ]
         print(running_expr)
         if running_expr and ("DONE" in running_expr[0]):
@@ -486,17 +500,23 @@ def update_tfmodeldb(self):
 
             if not saved_score or (metrics[0] < saved_score[0]):
                 winner_model = os.path.join(
-                    os.path.join(self.experiment_path, "temp", self.experiment_name)
+                    os.path.join(
+                        self.experiment_path, "temp", self.experiment_name
+                    )
                 )
                 if os.path.exists:
                     shutil.rmtree(winner_model)
                 os.rename(exprs, winner_model)
 
                 m_buffer = zip_model(winner_model)
-                encode_model = codecs.encode(pickle.dumps(m_buffer), "base64").decode()
+                encode_model = codecs.encode(
+                    pickle.dumps(m_buffer), "base64"
+                ).decode()
 
                 engine.execute(
-                    INSERT_OR_UPDATE_MODEL.format(mn=self.experiment_name, mf=encode_model)
+                    INSERT_OR_UPDATE_MODEL.format(
+                        mn=self.experiment_name, mf=encode_model
+                    )
                 )
                 engine.execute(
                     INSERT_OR_UPDATE_SCORE.format(
@@ -506,7 +526,10 @@ def update_tfmodeldb(self):
                         score2=metrics[1],
                     )
                 )
-            L.info("saved model %s %s" % (self.experiment_id, self.experiment_name))
+            L.info(
+                "saved model %s %s"
+                % (self.experiment_id, self.experiment_name)
+            )
 
     def modelfile_cleaner(self):
         """

experiments/atmos_tmp_01/train.py (+9 -9)

@@ -1,23 +1,21 @@
 import os
 import sys
 import time
+
 from preprocessing import preprocess
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
-import numpy as np
 import nni
+import numpy as np
 import pandas as pd
 import tensorflow as tf
+from expr_db import connect
+from sklearn.metrics import mean_absolute_error, mean_squared_error
 from tensorflow import keras
-from tensorflow.keras.models import Sequential
-from tensorflow.keras.layers import Dense
 from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
-from tensorflow.keras.layers import GRU
-from sklearn.metrics import mean_absolute_error, mean_squared_error
-
-
-from expr_db import connect
+from tensorflow.keras.layers import GRU, Dense
+from tensorflow.keras.models import Sequential
 
 physical_devices = tf.config.list_physical_devices("GPU")
 if physical_devices:
@@ -32,7 +30,9 @@ def make_dataset(data, label, window_size=365, predsize=None):
        for i in range(len(data) - (window_size + predsize)):
            feature_list.append(np.array(data.iloc[i : i + window_size]))
            label_list.append(
-                np.array(label.iloc[i + window_size : i + window_size + predsize])
+                np.array(
+                    label.iloc[i + window_size : i + window_size + predsize]
+                )
            )
    else:
        for i in range(len(data) - window_size):

experiments/expr_db.py (+2 -1)

@@ -1,6 +1,7 @@
 import os
-from dotenv import load_dotenv
+
 import sqlalchemy
+from dotenv import load_dotenv
 
 
 def connect(db="postgres"):

experiments/insurance/trial.py (+9 -3)

@@ -4,7 +4,6 @@
 import pickle
 import sys
 
-
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -41,7 +40,9 @@ def preprocess(x_train, x_valid, col_list):
     encoder = LabelEncoder()
 
     for col in col_list:
-        tmp_x_train.loc[:, col] = encoder.fit_transform(tmp_x_train.loc[:, col])
+        tmp_x_train.loc[:, col] = encoder.fit_transform(
+            tmp_x_train.loc[:, col]
+        )
         tmp_x_valid.loc[:, col] = encoder.transform(tmp_x_valid.loc[:, col])
 
     return tmp_x_train.values, tmp_x_valid.values
@@ -87,7 +88,12 @@ def main(params, engine, experiment_info, connection):
     model = XGBRegressor(**params)
 
     # Train the model and apply early stopping
-    model.fit(x_tra, y_train, eval_set=[(x_val, y_valid)], early_stopping_rounds=10)
+    model.fit(
+        x_tra,
+        y_train,
+        eval_set=[(x_val, y_valid)],
+        early_stopping_rounds=10,
+    )
 
     y_train_pred = model.predict(x_tra)
     y_valid_pred = model.predict(x_val)

logger.py (+1 -1)

@@ -1,7 +1,7 @@
 import logging
 import logging.handlers
-from colorlog import ColoredFormatter
 
+from colorlog import ColoredFormatter
 
 L = logging.getLogger("snowdeer_log")
 L.setLevel(logging.DEBUG)

main.py (+6 -3)

@@ -1,7 +1,6 @@
+import uvicorn
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-import uvicorn
-
 
 from app.api.router import predict, train
@@ -28,5 +27,9 @@ def hello_world():
 
 if __name__ == "__main__":
     uvicorn.run(
-        "main:app", host="0.0.0.0", port=8000, reload=True, reload_dirs=["app/"]
+        "main:app",
+        host="0.0.0.0",
+        port=8000,
+        reload=True,
+        reload_dirs=["app/"],
     )

pyproject.toml (+10)

@@ -0,0 +1,10 @@
+[tool.isort]
+multi_line_output=3
+include_trailing_comma=true
+force_grid_wrap=0
+use_parentheses=true
+line_length=79
+
+[tool.black]
+line-length = 79
+target-version = ['py38']
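For reference, multi_line_output=3 selects isort's "vertical hanging indent" style; together with use_parentheses and include_trailing_comma it yields exactly the wrapped imports seen above in app/models.py and app/api/router/train.py, and line_length=79 keeps isort's wrap point consistent with black's. A before/after sketch of this setting, taken from the app/models.py diff:

# Before: a single import line longer than 79 characters.
from sqlalchemy import Column, Integer, String, FLOAT, DateTime, ForeignKey, LargeBinary

# After: vertical hanging indent (multi_line_output=3) with a trailing comma.
from sqlalchemy import (
    FLOAT,
    Column,
    DateTime,
    ForeignKey,
    Integer,
    LargeBinary,
    String,
)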
