|
| 1 | +import os |
| 2 | +import time |
| 3 | +import argparse |
| 4 | +import importlib |
| 5 | +import multiprocessing |
| 6 | +import logging |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +from definitions import WEIGHT_DIR |
| 10 | + |
| 11 | +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' |
| 12 | + |
| 13 | +from utils import dataset # noqa: E402 |
| 14 | +from utils.voting import voting # noqa: E402 |
| 15 | + |
| 16 | +from train import initLog, get_optimizer, run as run_training # noqa: E402 |
| 17 | + |
| 18 | +LOG = logging.getLogger(__name__) |
| 19 | + |
| 20 | + |
| 21 | +def run( |
| 22 | + model: str, |
| 23 | + voting_tag: str, |
| 24 | + voting_times: int, |
| 25 | + train_ds_path: str, |
| 26 | + val_ds_path: str, |
| 27 | + test_ds_paths: list, |
| 28 | + test_add_retrain_sizes: list, |
| 29 | + test_retrain_times: int = 1, |
| 30 | + test_retrain_has_random: bool = True, |
| 31 | + classes=2, # 分類類別 |
| 32 | + sample_size=[32000, 1], # 訓練音訊頻率 |
| 33 | + epochs=160, |
| 34 | + batch_size=150, |
| 35 | + lr=1.0, # learning rate |
| 36 | + optimizer='adadelta', |
| 37 | + loss='categorical_crossentropy', |
| 38 | + metrics=['accuracy'], |
| 39 | + num_gpus=2, # number of gpus |
| 40 | + debug: bool = False, |
| 41 | + explainable=False, |
| 42 | + filter_x=45, |
| 43 | + filter_y=120, |
| 44 | + magnification=4, |
| 45 | + seed=None, |
| 46 | + use_saved_inital_weight=False, |
| 47 | + enabled_transfer_learning=False, |
| 48 | + verbose=1, |
| 49 | + skip_origin=False |
| 50 | +): |
| 51 | + initLog(debug) |
| 52 | + |
| 53 | + import tensorflow as tf |
| 54 | + |
| 55 | + Model = importlib.import_module(f'models.{model}').__getattribute__(model) |
| 56 | + start = time.time() |
| 57 | + os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(i) for i in range(num_gpus)]) |
| 58 | + input_shape = tuple(sample_size) |
| 59 | + |
| 60 | + voting_tags = [f'{voting_tag}-{i}' for i in range(voting_times)] |
| 61 | + |
| 62 | + tag = voting_tag.replace('.', '_').replace(' ', '_').replace('/', '_').replace('\\', '_') |
| 63 | + LOG.info(f'Model: {tag}') |
| 64 | + mgr = multiprocessing.Manager() |
| 65 | + # Testing ------------------------------------------------------------------------------------------ |
| 66 | + cls_results = {s: mgr.list() for s in test_ds_paths} |
| 67 | + cls_results['ground_truth'] = {} |
| 68 | + total_acc = mgr.list([0 for _ in test_ds_paths]) |
| 69 | + acc_list = mgr.list([mgr.list() for _ in test_ds_paths]) |
| 70 | + |
| 71 | + LOG.info('Run test') |
| 72 | + for index, _tag in enumerate(voting_tags): |
| 73 | + |
| 74 | + def test(): |
| 75 | + strategy = tf.distribute.MirroredStrategy(devices=[f'/gpu:{i}' for i in range(num_gpus)]) |
| 76 | + with strategy.scope(): |
| 77 | + _model = Model(input_shape, classes).model() |
| 78 | + _model.compile(loss=loss, optimizer=get_optimizer(optimizer, lr), metrics=metrics) |
| 79 | + _model.load_weights(os.path.join(WEIGHT_DIR, _tag + '.h5')) |
| 80 | + |
| 81 | + # Evaluation |
| 82 | + for i, test_ds_path in enumerate(test_ds_paths): |
| 83 | + test_ds = dataset.load(test_ds_path).batch(batch_size) |
| 84 | + score, acc = _model.evaluate(test_ds, verbose=0) |
| 85 | + result = _model.predict(test_ds) |
| 86 | + cls_results[test_ds_path].append(np.where(result >= 0.5, 1, 0)) |
| 87 | + acc_list[i].append(acc) |
| 88 | + total_acc[i] += acc |
| 89 | + LOG.debug(f'no.{index + 1}, score={score}, acc={acc}') |
| 90 | + del test_ds |
| 91 | + del _model |
| 92 | + |
| 93 | + p = multiprocessing.Process(target=test) |
| 94 | + p.start() |
| 95 | + p.join() |
| 96 | + |
| 97 | + if not skip_origin: |
| 98 | + training_kwargs = { |
| 99 | + 'test_ds_paths': [ds.replace('train', 'test') for ds in test_ds_paths], |
| 100 | + 'times': test_retrain_times, |
| 101 | + 'tag': train_ds_path, |
| 102 | + 'classes': classes, |
| 103 | + 'sample_size': sample_size, |
| 104 | + 'epochs': epochs, |
| 105 | + 'batch_size': batch_size, |
| 106 | + 'lr': lr, |
| 107 | + 'optimizer': optimizer, |
| 108 | + 'loss': loss, |
| 109 | + 'metrics': metrics, |
| 110 | + 'num_gpus': num_gpus, |
| 111 | + 'training': True, |
| 112 | + 'seed': seed, |
| 113 | + 'use_saved_inital_weight': use_saved_inital_weight, |
| 114 | + 'verbose': verbose |
| 115 | + } |
| 116 | + |
| 117 | + p = multiprocessing.Process(target=run_training, args=(model, train_ds_path, val_ds_path), kwargs=training_kwargs) |
| 118 | + p.start() |
| 119 | + p.join() |
| 120 | + |
| 121 | + for i, test_ds_path in enumerate(test_ds_paths): |
| 122 | + LOG.info(f"Dataset {test_ds_path}") |
| 123 | + for index in range(voting_times): |
| 124 | + LOG.info(f"第{index+1}次正確率:{acc_list[i][index]:.4f}") |
| 125 | + average_acc = total_acc[i] / len(acc_list[i]) |
| 126 | + LOG.info(f"Average_acc: {average_acc*100:.6f}%") |
| 127 | + |
| 128 | + # Dataset must have train and test set |
| 129 | + |
| 130 | + real_test_ds_path = test_ds_path.replace('train', 'test') |
| 131 | + ground_truth = np.array(dataset.get_ground_truth(test_ds_path)) |
| 132 | + cls_results['ground_truth'][test_ds_path] = ground_truth |
| 133 | + voting_acc, _, voting_rate_list = voting(cls_results[test_ds_path], ground_truth, f'{tag}_{test_ds_path}') |
| 134 | + LOG.info(f"Voting_acc: {voting_acc*100:.6f}%") |
| 135 | + voting_rate_list = np.array(sum(voting_rate_list, [])[::-1]) |
| 136 | + LOG.info(f"Voting_rate_list_size: {len(voting_rate_list)}") |
| 137 | + LOG.info(f"Voting_unconfident_top_10: {voting_rate_list[:10]}") |
| 138 | + |
| 139 | + for length in test_add_retrain_sizes: |
| 140 | + if length > voting_rate_list.shape[0]: |
| 141 | + break |
| 142 | + if test_retrain_has_random: |
| 143 | + # base line training |
| 144 | + training_kwargs = { |
| 145 | + 'test_ds_paths': [real_test_ds_path], |
| 146 | + 'train_ds_size': length, |
| 147 | + 'times': test_retrain_times, |
| 148 | + 'tag': f'{test_ds_path}_RNG_{length}', |
| 149 | + 'classes': classes, |
| 150 | + 'sample_size': sample_size, |
| 151 | + 'epochs': epochs, |
| 152 | + 'batch_size': batch_size, |
| 153 | + 'lr': lr, |
| 154 | + 'optimizer': optimizer, |
| 155 | + 'loss': loss, |
| 156 | + 'metrics': metrics, |
| 157 | + 'num_gpus': num_gpus, |
| 158 | + 'training': True, |
| 159 | + 'seed': seed, |
| 160 | + 'use_saved_inital_weight': use_saved_inital_weight, |
| 161 | + 'verbose': verbose |
| 162 | + } |
| 163 | + |
| 164 | + p = multiprocessing.Process(target=run_training, args=(model, test_ds_path, real_test_ds_path), kwargs=training_kwargs) |
| 165 | + p.start() |
| 166 | + p.join() |
| 167 | + |
| 168 | + # train_ds + base line training |
| 169 | + training_kwargs = { |
| 170 | + 'additional_ds_path': test_ds_path, |
| 171 | + 'additional_ds_size': length, |
| 172 | + 'test_ds_paths': [real_test_ds_path], |
| 173 | + 'times': test_retrain_times, |
| 174 | + 'tag': f'{train_ds_path}+{test_ds_path}_RNG_{length}', |
| 175 | + 'classes': classes, |
| 176 | + 'sample_size': sample_size, |
| 177 | + 'epochs': epochs, |
| 178 | + 'batch_size': batch_size, |
| 179 | + 'lr': lr, |
| 180 | + 'optimizer': optimizer, |
| 181 | + 'loss': loss, |
| 182 | + 'metrics': metrics, |
| 183 | + 'num_gpus': num_gpus, |
| 184 | + 'training': True, |
| 185 | + 'seed': seed, |
| 186 | + 'use_saved_inital_weight': use_saved_inital_weight, |
| 187 | + 'enabled_transfer_learning': enabled_transfer_learning, |
| 188 | + 'enabled_transfer_learning_weights': voting_tags, |
| 189 | + 'verbose': verbose |
| 190 | + } |
| 191 | + p = multiprocessing.Process(target=run_training, args=(model, train_ds_path, val_ds_path), kwargs=training_kwargs) |
| 192 | + p.start() |
| 193 | + p.join() |
| 194 | + |
| 195 | + # train_ds + uncertain base line training |
| 196 | + training_kwargs = { |
| 197 | + 'additional_ds_path': test_ds_path, |
| 198 | + 'additional_ds_indexes': voting_rate_list[:length], |
| 199 | + 'test_ds_paths': [real_test_ds_path], |
| 200 | + 'times': test_retrain_times, |
| 201 | + 'tag': f'{train_ds_path}+{test_ds_path}_UNC_{length}', |
| 202 | + 'classes': classes, |
| 203 | + 'sample_size': sample_size, |
| 204 | + 'epochs': epochs, |
| 205 | + 'batch_size': batch_size, |
| 206 | + 'lr': lr, |
| 207 | + 'optimizer': optimizer, |
| 208 | + 'loss': loss, |
| 209 | + 'metrics': metrics, |
| 210 | + 'num_gpus': num_gpus, |
| 211 | + 'training': True, |
| 212 | + 'seed': seed, |
| 213 | + 'use_saved_inital_weight': use_saved_inital_weight, |
| 214 | + 'enabled_transfer_learning': enabled_transfer_learning, |
| 215 | + 'enabled_transfer_learning_weights': voting_tags, |
| 216 | + 'verbose': verbose |
| 217 | + } |
| 218 | + p = multiprocessing.Process(target=run_training, args=(model, train_ds_path, val_ds_path), kwargs=training_kwargs) |
| 219 | + p.start() |
| 220 | + p.join() |
| 221 | + |
| 222 | + end = time.time() |
| 223 | + elapsed = end - start |
| 224 | + LOG.info(f"Time taken: {elapsed:.3f} seconds.") |
| 225 | + |
| 226 | + |
| 227 | +_examples = '''examples: |
| 228 | + # Train SCNN 18Layers using the keras: |
| 229 | + python %(prog)s \\ |
| 230 | + --model SCNN18 \\ |
| 231 | + --voting_tag 2021-01-23/20210123-12_SCNN18_SCNN-Jamendo-train_h5 \\ |
| 232 | + --voting_times 21 \\ |
| 233 | + --train_ds_path SCNN-Jamendo-train.h5 \\ |
| 234 | + --val_ds_path SCNN-Jamendo-test.h5 \\ |
| 235 | + --test_ds_paths SCNN-Taiwanese-stream-train.h5 SCNN-Classical-test.h5 \\ |
| 236 | + --test_add_retrain_sizes 100 200 300 400 500 600 700 800 900 1000 \\ |
| 237 | + --test_retrain_times 1 \\ |
| 238 | + --test_retrain_has_random \\ |
| 239 | + --classes 2 \\ |
| 240 | + --sample_size 32000 1 \\ |
| 241 | + --epochs 160 \\ |
| 242 | + --batch_size 150 \\ |
| 243 | + --loss categorical_crossentropy \\ |
| 244 | + --optimizer adadelta \\ |
| 245 | + --metrics accuracy \\ |
| 246 | + --lr 1.0 \\ |
| 247 | + --seed 0 |
| 248 | +''' |
| 249 | + |
| 250 | + |
| 251 | +def main(): |
| 252 | + parser = argparse.ArgumentParser(description="Train SCNN 18Layers", epilog=_examples, formatter_class=argparse.RawTextHelpFormatter) |
| 253 | + parser.add_argument('--model', required=True, help="SCNN18,SCNN36,AutoEncoderRemoveVocal") |
| 254 | + parser.add_argument('--voting_tag', required=True, help="Trained Model tag") |
| 255 | + parser.add_argument('--voting_times', required=True, help="How many trained models?(default: %(default)s)", default=21, type=int) |
| 256 | + parser.add_argument('--train_ds_path', required=True, help='Training dataset path') |
| 257 | + parser.add_argument('--val_ds_path', required=True, help='validation dataset path') |
| 258 | + parser.add_argument( |
| 259 | + '--test_ds_paths', |
| 260 | + help='Testing dataset paths; Required pair dataset include train and test ; Use train in here; (default: %(default)s)', |
| 261 | + nargs='+', |
| 262 | + default=['train.h5'] |
| 263 | + ) |
| 264 | + parser.add_argument('--test_add_retrain_sizes', help='Add some test_set to train_set(default: %(default)s)', type=int, nargs='+', default=[100]) |
| 265 | + parser.add_argument('--test_retrain_times', required=True, help="How many times do you train?(default: %(default)s)", default=1, type=int) |
| 266 | + parser.add_argument('--test_retrain_has_random', help="Also trained in random?(default: %(default)s)", default=False, action='store_true') |
| 267 | + parser.add_argument('--classes', help='Output class number(default: %(default)s)', default=2, type=int) |
| 268 | + parser.add_argument('--sample_size', help='Audio sample size(default: %(default)s)', nargs='+', type=int, default=[32000, 1]) |
| 269 | + parser.add_argument('--epochs', help="epochs (default: %(default)s)", default=160, type=int) |
| 270 | + parser.add_argument('--batch_size', help="batch_size (default: %(default)s)", default=150, type=int) |
| 271 | + parser.add_argument('--loss', help="loss(default: %(default)s)", default='categorical_crossentropy', type=str) |
| 272 | + parser.add_argument('--optimizer', help="optimizer(default: %(default)s)", default='adadelta', type=str) |
| 273 | + parser.add_argument('--metrics', help="metrics(default: %(default)s)", nargs='+', default=['accuracy']) |
| 274 | + parser.add_argument('--lr', help="learning rate(default: %(default)s for optimizer default value)", default=0.0, type=float) |
| 275 | + parser.add_argument('--explainable', help="Run explainable?(default: %(default)s)", default=False, action='store_true') |
| 276 | + parser.add_argument('--filter_x', help="Explainable filter_x(default: %(default)s)", default=45, type=int) |
| 277 | + parser.add_argument('--filter_y', help="Explainable filter_y(default: %(default)s)", default=120, type=int) |
| 278 | + parser.add_argument('--magnification', help="Explainable magnification(default: %(default)s)", default=4, type=int) |
| 279 | + parser.add_argument('--num_gpus', help="Number of gpus(default: %(default)s)", default=2, type=int) |
| 280 | + parser.add_argument('--debug', help="Is debuging?(default: %(default)s)", default=False, action='store_true') |
| 281 | + parser.add_argument('--seed', help="Random seed (default: %(default)s)", type=int) |
| 282 | + parser.add_argument('--verbose', help="Verbose (default: %(default)s)", default=1, type=int) |
| 283 | + parser.add_argument('--use_saved_inital_weight', help="use saved inital weight(default: %(default)s)", default=False, action='store_true') |
| 284 | + parser.add_argument('--skip_origin', help="skip training origin weight(default: %(default)s)", default=False, action='store_true') |
| 285 | + parser.add_argument('--enabled_transfer_learning', help="enabled transfer learnning(default: %(default)s)", default=False, action='store_true') |
| 286 | + args = parser.parse_args() |
| 287 | + |
| 288 | + run(**vars(args)) |
| 289 | + |
| 290 | + |
| 291 | +if __name__ == "__main__": |
| 292 | + main() |
0 commit comments