diff --git a/desidlas/data_model/Data.py b/desidlas/data_model/Data.py index d3cc8f5..50640f7 100644 --- a/desidlas/data_model/Data.py +++ b/desidlas/data_model/Data.py @@ -2,8 +2,8 @@ from abc import ABCMeta -from dla_cnn.data_model import Id -from dla_cnn.data_model import Sightline +from desidlas.data_model import Id +from desidlas.data_model import Sightline class Data(object): diff --git a/desidlas/data_model/Sightline.py b/desidlas/data_model/Sightline.py index 47813b0..fc69380 100644 --- a/desidlas/data_model/Sightline.py +++ b/desidlas/data_model/Sightline.py @@ -1,5 +1,6 @@ import numpy as np from desidlas.dla_cnn.spectra_utils import get_lam_data +from desidlas.datasets.datasetting import split_sightline_into_samples class Sightline(object): @@ -77,17 +78,16 @@ def is_lyb(self, peakix): """ assert self.prediction is not None and peakix in self.prediction.peaks_ixs - lam, lam_rest, ix_dla_range = get_lam_data(self.loglam, self.z_qso) - kernelrangepx = 200 - cut=((np.nonzero(ix_dla_range)[0])>=kernelrangepx)&((np.nonzero(ix_dla_range)[0])<=(len(lam)-kernelrangepx-1)) - lam_analyse=lam[ix_dla_range][cut] + data_split=split_sightline_into_samples(self) + lam_analyse=data_split[5] + lambda_higher = (lam_analyse[peakix]) / (1025.722/1215.67)#找这个peak对应的dla # An array of how close each peak is to beign the ly-b of peakix in spectrum reference frame peak_difference_spectrum = np.abs(lam_analyse[self.prediction.peaks_ixs] - lambda_higher) - nearest_peak_ix = np.argmin(peak_difference_spectrum)#找距离这个dla最近的peak + nearest_peak_ix = np.argmin(peak_difference_spectrum) - # get the column density of the identfied nearest peak算这两个的nhi + # get the column density of the identfied nearest peak _, potential_lya_nhi, _, _ = \ self.prediction.get_coldensity_for_peak(self.prediction.peaks_ixs[nearest_peak_ix]) _, potential_lyb_nhi, _, _ = \ @@ -95,10 +95,10 @@ def is_lyb(self, peakix): # Validations: check that the nearest peak is close enough to match # sanity check that the LyB is at least 0.3 less than the DLA - is_nearest_peak_within_range = peak_difference_spectrum[nearest_peak_ix] <= 15#两者距离小于15 - is_nearest_peak_larger_coldensity = potential_lyb_nhi < potential_lya_nhi - 0.3#nhi差距0.3以上? + is_nearest_peak_within_range = peak_difference_spectrum[nearest_peak_ix] <= 15 + is_nearest_peak_larger_coldensity = potential_lyb_nhi < potential_lya_nhi - 0.3 - return is_nearest_peak_within_range and is_nearest_peak_larger_coldensity#true为lyb,false为lya + return is_nearest_peak_within_range and is_nearest_peak_larger_coldensity#true lyb,false lya def get_lyb_index(self, peakix): diff --git a/desidlas/datasets/datasetting.py b/desidlas/datasets/datasetting.py index 42cbe96..e2247b5 100644 --- a/desidlas/datasets/datasetting.py +++ b/desidlas/datasets/datasetting.py @@ -16,7 +16,7 @@ from desidlas.dla_cnn.spectra_utils import get_lam_data from desidlas.dla_cnn.defs import REST_RANGE,kernel,best_v -def pad_sightline(sightline, lam, lam_rest, ix_dla_range,kernelrangepx,v=best_v['b']): +def pad_sightline(sightline, lam, lam_rest, ix_dla_range,kernelrangepx,v=best_v['all']): """ padding the left and right sides of the spectra @@ -60,7 +60,7 @@ def pad_sightline(sightline, lam, lam_rest, ix_dla_range,kernelrangepx,v=best_v[ lam_padded = np.hstack((pad_lam_left,lam,pad_lam_right)) return flux_padded,lam_padded,pixel_num_left -def split_sightline_into_samples(sightline, REST_RANGE=REST_RANGE, kernel=kernel): +def split_sightline_into_samples(sightline, REST_RANGE=REST_RANGE, kernel=kernel,v=best_v['all']): """ Split the sightline into a series of snippets, each with length kernel @@ -78,7 +78,7 @@ def split_sightline_into_samples(sightline, REST_RANGE=REST_RANGE, kernel=kernel kernelrangepx = int(kernel/2) #200 #padding the sightline: - flux_padded,lam_padded,pixel_num_left=pad_sightline(sightline,lam,lam_rest,ix_dla_range,kernelrangepx,v=best_v['b']) + flux_padded,lam_padded,pixel_num_left=pad_sightline(sightline,lam,lam_rest,ix_dla_range,kernelrangepx,v=v) diff --git a/desidlas/datasets/get_dataset.py b/desidlas/datasets/get_dataset.py index d4b8f55..f4e0824 100644 --- a/desidlas/datasets/get_dataset.py +++ b/desidlas/datasets/get_dataset.py @@ -7,8 +7,9 @@ REST_RANGE = defs.REST_RANGE kernel = defs.kernel smooth_kernel= defs.smooth_kernel +best_v = defs.best_v -def make_datasets(sightlines,validate=True,kernel=kernel, REST_RANGE=REST_RANGE, output=None): +def make_datasets(sightlines, kernel=kernel, REST_RANGE=REST_RANGE, v=best_v['all'],output=None, validate=True): """ Generate training set or validation set for DESI. @@ -28,7 +29,7 @@ def make_datasets(sightlines,validate=True,kernel=kernel, REST_RANGE=REST_RANGE, wavelength_dlas=[dla.central_wavelength for dla in sightline.dlas] coldensity_dlas=[dla.col_density for dla in sightline.dlas] label_sightline(sightline, kernel=kernel, REST_RANGE=REST_RANGE) - data_split=split_sightline_into_samples(sightline,REST_RANGE=REST_RANGE, kernel=kernel) + data_split=split_sightline_into_samples(sightline,REST_RANGE=REST_RANGE, kernel=kernel,v=v) if validate: flux=np.vstack([data_split[0]]) labels_classifier=np.hstack([data_split[1]]) @@ -38,7 +39,7 @@ def make_datasets(sightlines,validate=True,kernel=kernel, REST_RANGE=REST_RANGE, dataset[sightline.id]={'FLUX':flux,'lam':lam,'labels_classifier': labels_classifier, 'labels_offset':labels_offset , 'col_density': col_density,'wavelength_dlas':wavelength_dlas,'coldensity_dlas':coldensity_dlas} else: sample_masks=select_samples_50p_pos_neg(sightline, kernel=kernel) - if len(sample_masks) >0: + if sample_masks !=[]: flux=np.vstack([data_split[0][m] for m in sample_masks]) labels_classifier=np.hstack([data_split[1][m] for m in sample_masks]) labels_offset=np.hstack([data_split[2][m] for m in sample_masks]) @@ -69,7 +70,7 @@ def smooth_flux(flux): return flux_matrix #smooth flux for low S/N sightlines -def make_smoothdatasets(sightlines,validate=True,kernel=smooth_kernel, REST_RANGE=REST_RANGE, output=None): +def make_smoothdatasets(sightlines,kernel=smooth_kernel, REST_RANGE=REST_RANGE, v=best_v['all'], output=None, validate=True): """ Generate smoothed training set or validation set for DESI. @@ -88,7 +89,7 @@ def make_smoothdatasets(sightlines,validate=True,kernel=smooth_kernel, REST_RANG wavelength_dlas=[dla.central_wavelength for dla in sightline.dlas] coldensity_dlas=[dla.col_density for dla in sightline.dlas] label_sightline(sightline, kernel=kernel, REST_RANGE=REST_RANGE) - data_split=split_sightline_into_samples(sightline, REST_RANGE=REST_RANGE, kernel=kernel) + data_split=split_sightline_into_samples(sightline, REST_RANGE=REST_RANGE, kernel=kernel,v=v) if validate: flux=np.vstack([data_split[0]]) labels_classifier=np.hstack([data_split[1]]) @@ -96,7 +97,7 @@ def make_smoothdatasets(sightlines,validate=True,kernel=smooth_kernel, REST_RANG col_density=np.hstack([data_split[3]]) lam=np.vstack([data_split[4]]) flux_matrix=smooth_flux(flux) - dataset[sightline.id]={'FLUXMATRIX':flux_matrix,'lam':lam,'labels_classifier': labels_classifier, 'labels_offset':labels_offset , 'col_density': col_density,'wavelength_dlas':wavelength_dlas,'coldensity_dlas':coldensity_dlas} + dataset[sightline.id]={'FLUX':flux_matrix,'lam':lam,'labels_classifier': labels_classifier, 'labels_offset':labels_offset , 'col_density': col_density,'wavelength_dlas':wavelength_dlas,'coldensity_dlas':coldensity_dlas} else: sample_masks=select_samples_50p_pos_neg(sightline,kernel=kernel) if sample_masks !=[]: @@ -105,7 +106,7 @@ def make_smoothdatasets(sightlines,validate=True,kernel=smooth_kernel, REST_RANG labels_offset=np.hstack([data_split[2][m] for m in sample_masks]) col_density=np.hstack([data_split[3][m] for m in sample_masks]) flux_matrix=smooth_flux(flux) - dataset[sightline.id]={'FLUXMATRIX':flux_matrix,'labels_classifier':labels_classifier,'labels_offset':labels_offset,'col_density': col_density} + dataset[sightline.id]={'FLUX':flux_matrix,'labels_classifier':labels_classifier,'labels_offset':labels_offset,'col_density': col_density} np.save(output,dataset) return dataset diff --git a/desidlas/datasets/get_sightlines.py b/desidlas/datasets/get_sightlines.py index 5c09dc6..64bd794 100644 --- a/desidlas/datasets/get_sightlines.py +++ b/desidlas/datasets/get_sightlines.py @@ -35,7 +35,7 @@ def get_sightlines(spectra,truth,zbest,outpath): sightline.flux = sightline.flux[0:sightline.split_point_br] sightline.error = sightline.error[0:sightline.split_point_br] sightline.loglam = sightline.loglam[0:sightline.split_point_br] - rebin(sightline, best_v['b']) + rebin(sightline, best_v['all']) sightlines.append(sightline) np.save(outpath,sightlines) diff --git a/desidlas/datasets/preprocess.py b/desidlas/datasets/preprocess.py index d8fe688..d18e2f6 100644 --- a/desidlas/datasets/preprocess.py +++ b/desidlas/datasets/preprocess.py @@ -112,7 +112,6 @@ def rebin(sightline, v): ------- :class:`dla_cnn.data_model.Sightline.Sightline`: """ - # TODO -- Add inline comments c = 2.9979246e8 # Set a constant dispersion @@ -209,8 +208,9 @@ def normalize(sightline, full_wavelength, full_flux): assert blue_limit <= red_limit,"No Lymann-alpha forest, Please check this spectra: %i"%sightline.id#when no lymann alpha forest exists, assert error. #use the slice we chose above to normalize this spectra, normalize both flux and error array using the same factor to maintain the s/n. good_pix = (rest_wavelength>=blue_limit)&(rest_wavelength<=red_limit) - sightline.flux = sightline.flux/np.median(full_flux[good_pix]) - sightline.error = sightline.error/np.median(full_flux[good_pix]) + normalizer=np.abs(np.nanmedian(full_flux[good_pix])) + sightline.flux = sightline.flux/normalizer + sightline.error = sightline.error/normalizer def estimate_s2n(sightline): """ @@ -237,9 +237,9 @@ def estimate_s2n(sightline): #for dla in sightline.dlas: #test = test&((wavelength>dla.central_wavelength+delta)|(wavelength0, "this sightline doesn't contain lymann forest, sightline id: %i"%sightline.id - s2n = sightline.flux/sightline.error + s2n = np.abs(sightline.flux/sightline.error) #return s/n - return np.median(s2n[test]) + return np.nanmedian(s2n[test]) def generate_summary_table(sightlines, output_dir, mode = "w"): """ diff --git a/desidlas/notebook/training_prediction.ipynb b/desidlas/notebook/training_prediction.ipynb index ebe28fe..8dfb403 100644 --- a/desidlas/notebook/training_prediction.ipynb +++ b/desidlas/notebook/training_prediction.ipynb @@ -67,7 +67,7 @@ "source": [ "#the codes used for detect DLAs is in desidlas/prediction/get_partprediction.py\n", "#to get the prediction for every part(400 or 600 pixels) , all you need is to run:\n", - "python get_partprediction.py -p 'pre_dataset.npy' -o 'partpre.npy' -m high\n", + "python get_partprediction.py -p 'pre_dataset.npy' -o 'partpre.npy' -model high\n", "\n", "\n", "# -p : path to the dataset used to detect DLAs\n", diff --git a/desidlas/prediction/get_partprediction.py b/desidlas/prediction/get_partprediction.py index afc71b0..be3e3d2 100644 --- a/desidlas/prediction/get_partprediction.py +++ b/desidlas/prediction/get_partprediction.py @@ -40,12 +40,14 @@ def t(tensor_name): -def predictions_ann(hyperparameters, flux, checkpoint_filename, TF_DEVICE=''): +def predictions_ann(hyperparameters, INPUT_SIZE,matrix_size,flux, checkpoint_filename, TF_DEVICE=''): ''' Perform training Parameters ---------- hyperparameters:hyperparameters for the CNN model structure + INPUT_SIZE: pixels numbers for each window , 400 for high SNR and 600 for low SNR + matrix_size: 1 if without smoothing, 4 if smoothing for low SNR flux:list (400 or 600 length), flux from sightline checkpoint_filename: CNN model file used to detect DLAs TF_DEVICE: use which gpu to train, default is '/gpu:1' @@ -69,7 +71,7 @@ def predictions_ann(hyperparameters, flux, checkpoint_filename, TF_DEVICE=''): with tf.Graph().as_default(): - build_model(hyperparameters) # build the CNN model according to hyperparameters + build_model(hyperparameters,INPUT_SIZE,matrix_size) # build the CNN model according to hyperparameters with tf.device(TF_DEVICE), tf.compat.v1.Session() as sess: tf.compat.v1.train.Saver().restore(sess, checkpoint_filename+".ckpt") #load model files @@ -96,12 +98,17 @@ def predictions_ann(hyperparameters, flux, checkpoint_filename, TF_DEVICE=''): parser = argparse.ArgumentParser() parser.add_argument('-p', '--preddataset', help='Datasets to detect DLAs , npy format', required=True, default=False) parser.add_argument('-o', '--output_file', help='output files to save the prediction result, npy format', required=False, default=False) - parser.add_argument('-m', '--modelfiles', help='CNN models for prediction, high snr model or mid snr model', required=False, default=False) + parser.add_argument('-model', '--modelfiles', help='CNN models for prediction, high snr model or mid snr model', required=False, default=False) + parser.add_argument('-t', '--INPUT_SIZE', help='set the input data size', required=False, default=400) + parser.add_argument('-m', '--matrix_size', help='set the matrix size when using smooth', required=False, default=1) + args = vars(parser.parse_args()) - RUN_SINGLE_ITERATION = not args['hyperparamsearch'] - checkpoint_filename = args['checkpoint_file'] if RUN_SINGLE_ITERATION else None + batch_results_file = args['output_file'] + INPUT_SIZE = args['INPUT_SIZE'] + matrix_size = args['matrix_size'] + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG) @@ -111,8 +118,7 @@ def predictions_ann(hyperparameters, flux, checkpoint_filename, TF_DEVICE=''): from desidlas.training.parameterset import parameter_names from desidlas.training.parameterset import parameters hyperparameters = {} - for k in range(0,len(parameter_names)): - hyperparameters[parameter_names[k]] = parameters[k][0] + pred_dataset=args['preddataset'] @@ -123,8 +129,18 @@ def predictions_ann(hyperparameters, flux, checkpoint_filename, TF_DEVICE=''): modelfile=args['modelfiles'] if modelfile == 'high': checkpoint_filename='desidlas/prediction/model/train_highsnr/current_99999' + for k in range(0,len(parameter_names)): + hyperparameters[parameter_names[k]] = parameters[k][1] if modelfile == 'mid': checkpoint_filename='desidlas/prediction/model/train_midsnr/current_99999' + for k in range(0,len(parameter_names)): + hyperparameters[parameter_names[k]] = parameters[k][0] + if modelfile == 'low': + checkpoint_filename='desidlas/prediction/model/train_lowsnr/current_99999' + for k in range(0,len(parameter_names)): + hyperparameters[parameter_names[k]] = parameters[k][0] + INPUT_SIZE = 600 + matrix_size = 4 dataset={} @@ -140,7 +156,7 @@ def predictions_ann(hyperparameters, flux, checkpoint_filename, TF_DEVICE=''): flux=np.array(r[sight_id]['FLUX']) - (pred, conf, offset, coldensity)=predictions_ann(hyperparameters, flux, checkpoint_filename, TF_DEVICE='') + (pred, conf, offset, coldensity)=predictions_ann(hyperparameters, INPUT_SIZE,matrix_size,flux, checkpoint_filename, TF_DEVICE='') dataset[sight_id]={'pred':pred,'conf':conf,'offset': offset, 'coldensity':coldensity } diff --git a/desidlas/prediction/model/train_lowsnr/_init_.py b/desidlas/prediction/model/train_lowsnr/_init_.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/desidlas/prediction/model/train_lowsnr/_init_.py @@ -0,0 +1 @@ + diff --git a/desidlas/prediction/model/train_lowsnr/lowsnrmodel b/desidlas/prediction/model/train_lowsnr/lowsnrmodel new file mode 100644 index 0000000..bd745fe --- /dev/null +++ b/desidlas/prediction/model/train_lowsnr/lowsnrmodel @@ -0,0 +1,2 @@ +The model files are too large to upload to github, you can find the model files here : +https://drive.google.com/drive/folders/15iX-R0o2HmUeLGBKPHjT94xuqaI2tJHr?usp=sharing diff --git a/desidlas/training/model.py b/desidlas/training/model.py index 5aea498..e3c5fea 100644 --- a/desidlas/training/model.py +++ b/desidlas/training/model.py @@ -173,7 +173,10 @@ def build_model(hyperparameters,INPUT_SIZE,matrix_size): #tf.compat.v1.placeholder:claim a tensor that needs to be filled (the data type, shape and name) #x: the empty tensor need to be filled with the input data - x = tf.compat.v1.placeholder(tf.float32, shape=[None,matrix_size, INPUT_SIZE], name='x') + if matrix_size == 1: + x = tf.compat.v1.placeholder(tf.float32, shape=[None,INPUT_SIZE], name='x') + if matrix_size == 4: + x = tf.compat.v1.placeholder(tf.float32, shape=[None,matrix_size, INPUT_SIZE], name='x') #claim the tensor for three labels @@ -189,7 +192,7 @@ def build_model(hyperparameters,INPUT_SIZE,matrix_size): # Stride (4,1) # number of filters = 4 (features?) # Neuron activation = ReLU (rectified linear unit) - W_conv1 = weight_variable([conv1_kernel, 1, 4, conv1_filters]) + W_conv1 = weight_variable([conv1_kernel, 1, matrix_size, conv1_filters]) b_conv1 = bias_variable([conv1_filters]) # https://www.tensorflow.org/versions/r0.10/api_docs/python/nn.html#convolution diff --git a/desidlas/training/parameterset.py b/desidlas/training/parameterset.py index 4badb15..caf0aab 100644 --- a/desidlas/training/parameterset.py +++ b/desidlas/training/parameterset.py @@ -19,7 +19,7 @@ # batch_size [400,700, 400, 500, 600, 700, 850, 1000], # l2_regularization_penalty - [0.005, 0.01, 0.008, 0.005, 0.003], + [0.005,0.005, 0.01, 0.008, 0.005, 0.003], # dropout_keep_prob [0.9,0.98, 0.75, 0.9, 0.95, 0.98, 1], # fc1_n_neurons @@ -29,7 +29,7 @@ # fc2_2_n_neurons [500,350, 200, 350, 500, 700, 900, 1500], # fc2_3_n_neurons - [150, 200, 350, 500, 700, 900, 1500], + [150,150, 200, 350, 500, 700, 900, 1500], # conv1_kernel [40,32, 20, 22, 24, 26, 28, 32, 40, 48, 54], # conv2_kernel @@ -37,7 +37,7 @@ # conv3_kernel [20,16, 10, 14, 16, 20, 24, 28, 32, 34], # conv1_filters - [100, 64, 80, 90, 100, 110, 120, 140, 160, 200], + [100,100, 64, 80, 90, 100, 110, 120, 140, 160, 200], # conv2_filters [256,96, 80, 96, 128, 192, 256], # conv3_filters @@ -49,11 +49,11 @@ # conv3_stride [1,1, 1, 2, 3, 4, 5, 6], # pool1_kernel - [7, 3, 4, 5, 6, 7, 8, 9], + [7,7, 3, 4, 5, 6, 7, 8, 9], # pool2_kernel [4,6, 4, 5, 6, 7, 8, 9, 10], # pool3_kernel - [6, 4, 5, 6, 7, 8, 9, 10], + [6,6, 4, 5, 6, 7, 8, 9, 10], # pool1_stride [1,4, 1, 2, 4, 5, 6], # pool2_stride diff --git a/desidlas/training/training.py b/desidlas/training/training.py index 99aac8e..e62cbd7 100644 --- a/desidlas/training/training.py +++ b/desidlas/training/training.py @@ -11,6 +11,7 @@ import tensorflow as tf import os from pathlib import Path +from pkg_resources import resource_filename from tensorflow.python.framework import ops ops.reset_default_graph() @@ -140,7 +141,7 @@ def train_ann_test_batch(sess, ixs, data, summary_writer=None): -def train_ann(hyperparameters, train_dataset, test_dataset, save_filename=None, load_filename=None, tblogs = "../tmp/tblogs", TF_DEVICE='/gpu:1'): +def train_ann(hyperparameters, train_dataset, test_dataset, INPUT_SIZE,matrix_size,save_filename=None,load_filename=None,tblogs = "../tmp/tblogs",TF_DEVICE='/gpu:1'): """ Perform training @@ -296,6 +297,8 @@ def calc_normalized_score(best_accuracy, best_offset_rmse, best_coldensity_rmse) # Execute batch mode # from desidlas.data_model.Dataset import Dataset + from desidlas.training.parameterset import parameter_names + from desidlas.training.parameterset import parameters datafile_path = os.path.join(resource_filename('desidlas', 'tests'), 'datafile') traindata_path=os.path.join(datafile_path, 'sightlines-16-1375.npy') @@ -314,8 +317,8 @@ def calc_normalized_score(best_accuracy, best_offset_rmse, best_coldensity_rmse) parser.add_argument('-c', '--checkpoint_file', help='Name of the checkpoint file to save (without file extension)', required=False, default=savemodel_path) #../models/training/current parser.add_argument('-r', '--train_dataset_filename', help='File name of the training dataset without extension', required=False, default=traindata_path) parser.add_argument('-e', '--test_dataset_filename', help='File name of the testing dataset without extension', required=False, default=testdata_path) - parser.add_argument('-t', '--INPUT_SIZE', help='set the input data size', required=False, default=600) - parser.add_argument('-m', '--matrix_size', help='set the matrix size when using smooth', required=False, default=4) + parser.add_argument('-t', '--INPUT_SIZE', help='set the input data size', required=False, default=400) + parser.add_argument('-m', '--matrix_size', help='set the matrix size when using smooth', required=False, default=1) args = vars(parser.parse_args()) RUN_SINGLE_ITERATION = not args['hyperparamsearch'] @@ -338,10 +341,6 @@ def calc_normalized_score(best_accuracy, best_offset_rmse, best_coldensity_rmse) os.remove(batch_results_file) if os.path.exists(batch_results_file) else None with open(batch_results_file, "a") as csvoutput: csvoutput.write("iteration_num,normalized_score,best_accuracy,last_accuracy,last_objective,best_offset_rmse,last_offset_rmse,best_coldensity_rmse,last_coldensity_rmse," + ",".join(parameter_names) + "\n") - - - from desidlas.training.parameterset import parameter_names - from desidlas.training.parameterset import parameters #hyperparameter search @@ -352,7 +351,7 @@ def calc_normalized_score(best_accuracy, best_offset_rmse, best_coldensity_rmse) #start the training (best_accuracy, last_accuracy, last_objective, best_offset_rmse, last_offset_rmse, best_coldensity_rmse, - last_coldensity_rmse) = train_ann(hyperparameters, train_dataset, test_dataset, + last_coldensity_rmse) = train_ann(hyperparameters, train_dataset, test_dataset,INPUT_SIZE,matrix_size, save_filename=checkpoint_filename, load_filename=args['loadmodel']) diff --git a/docs/installing.rst b/docs/installing.rst index 877fd76..8821f15 100644 --- a/docs/installing.rst +++ b/docs/installing.rst @@ -83,20 +83,18 @@ Do these for docs:: Get The Model File ============== - The model files are too large to upload to github, you can find the model files for high S/N spectra here: + The model files are too large to upload to github, you can find the model files here: - https://drive.google.com/drive/folders/1DYOE_k9S_F0JmnAdFbTmHkVqyxFlc4t-?usp=sharing - - The model files are too large to upload to github, you can find the model files for low S/N spectra here : - - https://drive.google.com/drive/folders/1s5km1NAg5j0Y-tWI1q58Y09hjj0Jjc8C?usp=sharing + https://drive.google.com/drive/folders/1Cl07CuRBE9ljtvIoTWexEVNSd8Zzwyvg?usp=sharing + + The folders are different models for different S/N spectra. (high: >6. mid:3-6. low:<3) Test CNN ============== - When you finish the installing and want to test the CNN model (training and prediction), you can firstly download all the model files here: + When you finish the installing and want to test the CNN model (training and prediction), you can firstly download all the model files here(same link as above): - https://drive.google.com/drive/folders/1Cl07CuRBE9ljtvIoTWexEVNSd8Zzwyvg?usp=sharing + https://drive.google.com/drive/folders/1Cl07CuRBE9ljtvIoTWexEVNSd8Zzwyvg?usp=sharing And then add the environmental variable CNN_MODEL as the path to the model files like this: diff --git a/docs/notebook/training_prediction.ipynb b/docs/notebook/training_prediction.ipynb index ebe28fe..8dfb403 100644 --- a/docs/notebook/training_prediction.ipynb +++ b/docs/notebook/training_prediction.ipynb @@ -67,7 +67,7 @@ "source": [ "#the codes used for detect DLAs is in desidlas/prediction/get_partprediction.py\n", "#to get the prediction for every part(400 or 600 pixels) , all you need is to run:\n", - "python get_partprediction.py -p 'pre_dataset.npy' -o 'partpre.npy' -m high\n", + "python get_partprediction.py -p 'pre_dataset.npy' -o 'partpre.npy' -model high\n", "\n", "\n", "# -p : path to the dataset used to detect DLAs\n",