diff --git a/dmriqcpy/analysis/stats.py b/dmriqcpy/analysis/stats.py index f59a439..f4daf65 100644 --- a/dmriqcpy/analysis/stats.py +++ b/dmriqcpy/analysis/stats.py @@ -1,9 +1,10 @@ # -*- coding: utf-8 -*- +import os import nibabel as nib import numpy as np -import os -import pandas as pd + +from dmriqcpy.analysis.utils import get_stats_dataframes def stats_mean_median(column_names, filenames): @@ -26,8 +27,6 @@ def stats_mean_median(column_names, filenames): across subjects. """ values = [] - import time - sub_filenames = [os.path.basename(curr_subj).split('.')[0] for curr_subj in filenames] for filename in filenames: data = nib.load(filename).get_data() @@ -40,24 +39,14 @@ def stats_mean_median(column_names, filenames): mean = np.mean(data[data > 0]) median = np.median(data[data > 0]) - values.append( - [mean, median]) - - stats_per_subjects = pd.DataFrame(values, index=sub_filenames, - columns=column_names) - - stats_across_subjects = pd.DataFrame([stats_per_subjects.mean(), - stats_per_subjects.std(), - stats_per_subjects.min(), - stats_per_subjects.max()], - index=['mean', 'std', 'min', 'max'], - columns=column_names) + values.append([mean, median]) - return stats_per_subjects, stats_across_subjects + return get_stats_dataframes(filenames, values, column_names) -def stats_mean_in_tissues(column_names, images, wm_images, gm_images, - csf_images): +def stats_mean_in_tissues( + column_names, images, wm_images, gm_images, csf_images +): """ Compute mean value in WM, GM and CSF mask. @@ -82,7 +71,6 @@ def stats_mean_in_tissues(column_names, images, wm_images, gm_images, DataFrame containing mean, std, min and max of mean across subjects. 
""" values = [] - sub_images = [os.path.basename(curr_subj).split('.')[0] for curr_subj in images] for i in range(len(images)): data = nib.load(images[i]).get_data() @@ -95,20 +83,9 @@ def stats_mean_in_tissues(column_names, images, wm_images, gm_images, data_csf = np.mean(data[csf > 0]) data_max = np.max(data[wm > 0]) - values.append( - [data_wm, data_gm, data_csf, data_max]) + values.append([data_wm, data_gm, data_csf, data_max]) - stats_per_subjects = pd.DataFrame(values, index=sub_images, - columns=column_names) - - stats_across_subjects = pd.DataFrame([stats_per_subjects.mean(), - stats_per_subjects.std(), - stats_per_subjects.min(), - stats_per_subjects.max()], - index=['mean', 'std', 'min', 'max'], - columns=column_names) - - return stats_per_subjects, stats_across_subjects + return get_stats_dataframes(images, values, column_names) def stats_frf(column_names, filenames): @@ -130,22 +107,12 @@ def stats_frf(column_names, filenames): DataFrame containing mean, std, min and max of mean across subjects. """ values = [] + for filename in filenames: frf = np.loadtxt(filename) values.append([frf[0], frf[1], frf[3]]) - sub_filenames = [os.path.basename(curr_subj).split('.')[0] for curr_subj in filenames] - stats_per_subjects = pd.DataFrame(values,index=sub_filenames, - columns=column_names) - - stats_across_subjects = pd.DataFrame([stats_per_subjects.mean(), - stats_per_subjects.std(), - stats_per_subjects.min(), - stats_per_subjects.max()], - index=['mean', 'std', 'min', 'max'], - columns=column_names) - - return stats_per_subjects, stats_across_subjects + return get_stats_dataframes(filenames, values, column_names) def stats_tractogram(column_names, tractograms): @@ -167,24 +134,12 @@ def stats_tractogram(column_names, tractograms): DataFrame containing mean, std, min and max of mean across subjects. 
""" values = [] - sub_tractograms = [os.path.basename(curr_subj).split('.')[0] for curr_subj in tractograms] + for tractogram_file in tractograms: tractogram = nib.streamlines.load(tractogram_file, lazy_load=True) + values.append([tractogram.header["nb_streamlines"]]) - values.append( - [tractogram.header['nb_streamlines']]) - - stats_per_subjects = pd.DataFrame(values, index=sub_tractograms, - columns=column_names) - - stats_across_subjects = pd.DataFrame([stats_per_subjects.mean(), - stats_per_subjects.std(), - stats_per_subjects.min(), - stats_per_subjects.max()], - index=['mean', 'std', 'min', 'max'], - columns=column_names) - - return stats_per_subjects, stats_across_subjects + return get_stats_dataframes(tractograms, values, column_names) def stats_mask_volume(column_names, images): @@ -206,24 +161,10 @@ def stats_mask_volume(column_names, images): DataFrame containing mean, std, min and max of mean across subjects. """ values = [] - sub_images = [os.path.basename(curr_subj).split('.')[0] for curr_subj in images] for image in images: img = nib.load(image) - data = img.get_data() - voxel_volume = np.prod(img.header['pixdim'][1:4]) - volume = np.count_nonzero(data) * voxel_volume - - values.append([volume]) - - stats_per_subjects = pd.DataFrame(values, index=sub_images, - columns=column_names) - - stats_across_subjects = pd.DataFrame([stats_per_subjects.mean(), - stats_per_subjects.std(), - stats_per_subjects.min(), - stats_per_subjects.max()], - index=['mean', 'std', 'min', 'max'], - columns=column_names) + voxel_volume = np.prod(img.header["pixdim"][1:4]) + values.append([np.count_nonzero(img.get_data()) * voxel_volume]) - return stats_per_subjects, stats_across_subjects + return get_stats_dataframes(images, values, column_names) diff --git a/dmriqcpy/analysis/utils.py b/dmriqcpy/analysis/utils.py index 3ba0a04..8af2f72 100644 --- a/dmriqcpy/analysis/utils.py +++ b/dmriqcpy/analysis/utils.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- import logging -import 
numpy as np import os + +import numpy as np import pandas as pd """ -Some functions comes from -https://github.com/scilus/scilpy/blob/master/scilpy/utils/bvec_bval_tools.py +Some functions come from Scilpy v1.2.0 +https://github.com/scilus/scilpy/blob/1.2.0/scilpy/utils/bvec_bval_tools.py """ @@ -14,6 +15,7 @@ def get_nearest_bval(bvals, curr_bval, tol=20): """ Get nearest bval in a list of bvals If not in the list, return the current bval + Parameters ---------- bvals: array @@ -28,11 +30,10 @@ def get_nearest_bval(bvals, curr_bval, tol=20): ------- bval: float Return the nearest bval or the current one. - - """ - indices = np.where(np.logical_and(bvals <= curr_bval + tol, - bvals >= curr_bval - tol))[0] + indices = np.where( + np.logical_and(bvals <= curr_bval + tol, bvals >= curr_bval - tol) + )[0] if len(indices) > 0: bval = bvals[indices[0]] else: @@ -46,7 +47,7 @@ def read_protocol(in_jsons, tags): Parameters ---------- - in_json : List + in_jsons : List List of jsons files tags: List List of tags to check @@ -60,13 +61,9 @@ def read_protocol(in_jsons, tags): dfs_for_graph: DataFrame DataFrame containing all valid for all subjects. 
""" - dfs = [] - for in_json in in_jsons: - data = pd.read_json(in_json, orient='index') - dfs.append(data.T) - + dfs = [pd.read_json(in_json, orient="index").T for in_json in in_jsons] temp = pd.concat(dfs, ignore_index=True) - index = [os.path.basename(item).split('.')[0] for item in in_jsons] + index = [os.path.basename(item).split(".")[0] for item in in_jsons] dfs = [] tmp_dfs_for_graph = [] dfs_for_graph_all = [] @@ -79,30 +76,34 @@ def read_protocol(in_jsons, tags): tdf = tdf.rename(columns={tag: "Number of subjects"}) tdf.reset_index(inplace=True) tdf = tdf.rename(columns={tag: "Value(s)"}) - tdf = tdf.sort_values(by=['Value(s)'], - ascending=False) + tdf = tdf.sort_values(by=["Value(s)"], ascending=False) dfs.append((tag, tdf)) t = temp[tag] t.index = index tdf = pd.DataFrame(t) - if isinstance(temp[tag][0], int) or\ - isinstance(temp[tag][0], float): + if isinstance(temp[tag][0], int) or isinstance( + temp[tag][0], float + ): tmp_dfs_for_graph.append(tdf) - dfs.append(('complete_' + tag, tdf)) + dfs.append(("complete_" + tag, tdf)) else: logging.warning("{} does not exist in the metadata.".format(tag)) if tmp_dfs_for_graph: dfs_for_graph = pd.concat(tmp_dfs_for_graph, axis=1, join="inner") - dfs_for_graph_all = pd.DataFrame([dfs_for_graph.mean(), - dfs_for_graph.std(), - dfs_for_graph.min(), - dfs_for_graph.max()], - index=['mean', 'std', 'min', 'max'], - columns=dfs_for_graph.columns) + dfs_for_graph_all = pd.DataFrame( + [ + dfs_for_graph.mean(), + dfs_for_graph.std(), + dfs_for_graph.min(), + dfs_for_graph.max(), + ], + index=["mean", "std", "min", "max"], + columns=dfs_for_graph.columns, + ) return dfs, dfs_for_graph, dfs_for_graph_all @@ -127,7 +128,7 @@ def dwi_protocol(bvals, tol=20): values_stats = [] column_names = ["Nbr shells", "Nbr directions"] shells = {} - index = [os.path.basename(item).split('.')[0] for item in bvals] + index = [os.path.basename(item).split(".")[0] for item in bvals] for i, filename in enumerate(bvals): values = [] @@ 
-135,17 +136,18 @@ def dwi_protocol(bvals, tol=20): centroids, shells_indices = identify_shells(bval, threshold=tol) s_centroids = sorted(centroids) - values.append(', '.join(str(x) for x in s_centroids)) + values.append(", ".join(str(x) for x in s_centroids)) values.append(len(shells_indices)) - columns = ["bvals"] - columns.append("Nbr directions") + columns = ["bvals", "Nbr directions"] for centroid in s_centroids: nearest_centroid = get_nearest_bval(list(shells.keys()), centroid) if np.int(nearest_centroid) not in shells: shells[np.int(nearest_centroid)] = {} - nb_directions = len(shells_indices[shells_indices == - np.where(centroids == centroid)[ - 0]]) + nb_directions = len( + shells_indices[ + shells_indices == np.where(centroids == centroid)[0] + ] + ) if filename not in shells[np.int(nearest_centroid)]: shells[np.int(nearest_centroid)][index[i]] = 0 shells[np.int(nearest_centroid)][index[i]] += nb_directions @@ -154,18 +156,17 @@ def dwi_protocol(bvals, tol=20): values_stats.append([len(centroids) - 1, len(shells_indices)]) - stats_per_subjects[filename] = pd.DataFrame([values], index=[index[i]], - columns=columns) + stats_per_subjects[filename] = pd.DataFrame( + [values], index=[index[i]], columns=columns + ) - stats = pd.DataFrame(values_stats, index=index, - columns=column_names) + stats = pd.DataFrame(values_stats, index=index, columns=column_names) - stats_across_subjects = pd.DataFrame([stats.mean(), - stats.std(), - stats.min(), - stats.max()], - index=['mean', 'std', 'min', 'max'], - columns=column_names) + stats_across_subjects = pd.DataFrame( + [stats.mean(), stats.std(), stats.min(), stats.max()], + index=["mean", "std", "min", "max"], + columns=column_names, + ) return stats_per_subjects, stats, stats_across_subjects, shells @@ -181,8 +182,6 @@ def identify_shells(bvals, threshold=40.0, roundCentroids=False, sort=False): alternative to K-means considering we don't already know the number of shells K. - Note. 
This function should be added in Dipy soon. - Parameters ---------- bvals: array (N,) @@ -204,7 +203,7 @@ def identify_shells(bvals, threshold=40.0, roundCentroids=False, sort=False): For each bval, the associated centroid K. """ if len(bvals) == 0: - raise ValueError('Empty b-values.') + raise ValueError("Empty b-values.") # Finding centroids bval_centroids = [bvals[0]] @@ -217,8 +216,9 @@ def identify_shells(bvals, threshold=40.0, roundCentroids=False, sort=False): centroids = np.array(bval_centroids) # Identifying shells - bvals_for_diffs = np.tile(bvals.reshape(bvals.shape[0], 1), - (1, centroids.shape[0])) + bvals_for_diffs = np.tile( + bvals.reshape(bvals.shape[0], 1), (1, centroids.shape[0]) + ) shell_indices = np.argmin(np.abs(bvals_for_diffs - centroids), axis=1) @@ -237,16 +237,16 @@ def identify_shells(bvals, threshold=40.0, roundCentroids=False, sort=False): return centroids, shell_indices -def build_ms_from_shell_idx(bvecs, shell_idx): +def get_bvecs_from_shells_idxs(bvecs, shell_idxs): """ - Get bvecs from indexes + Get bvecs associated to each shell Parameters ---------- bvecs: numpy.ndarray - bvecs - shell_idx: numpy.ndarray - index for each bval + bvecs (N, 3) + shell_idxs: numpy.ndarray + lists of indexes into bvecs for each shell (K, M_k) Return ------ @@ -254,12 +254,49 @@ def build_ms_from_shell_idx(bvecs, shell_idx): bvecs for each bval """ - S = len(set(shell_idx)) - if (-1 in set(shell_idx)): - S -= 1 + nb_shells = len(set(shell_idxs)) + if -1 in set(shell_idxs): + nb_shells -= 1 - ms = [] - for i_ms in range(S): - ms.append(bvecs[shell_idx == i_ms]) + return [bvecs[shell_idxs == i_ms] for i_ms in range(nb_shells)] + + +def get_stats_dataframes(filenames, stats, metrics_names): + """ + Create a DataFrame from a list of statistics estimated on a list of + subjects, as well as a DataFrame of summary statistics (Mean, std, + min, max) across subjects. 
+ + Parameters + ---------- + filenames: array of strings + Array of filenames used to compute the statistics. + stats: array + Array of statistics for each filename. + metrics_names: array of strings + Names of the metrics (statistics) available in stats. + Return + ------ + stats_per_subjects: DataFrame + DataFrame of statistics per subject. + stats_across_subjects: DataFrame + DataFrame of statistics across subjects (Mean, std, min, max). + """ - return ms + stats_per_subjects = pd.DataFrame( + stats, + index=[os.path.basename(f).split(".")[0] for f in filenames], + columns=metrics_names, + ) + stats_across_subjects = pd.DataFrame( + [ + stats_per_subjects.mean(), + stats_per_subjects.std(), + stats_per_subjects.min(), + stats_per_subjects.max(), + ], + index=["mean", "std", "min", "max"], + columns=metrics_names, + ) + + return stats_per_subjects, stats_across_subjects diff --git a/dmriqcpy/io/report.py b/dmriqcpy/io/report.py index 8cec342..cabe9e7 100644 --- a/dmriqcpy/io/report.py +++ b/dmriqcpy/io/report.py @@ -1,23 +1,26 @@ # -*- coding: utf-8 -*- - -from os.path import dirname, join, realpath import os -from shutil import copytree, copyfile +from os.path import dirname, join, realpath +from shutil import copyfile, copytree from jinja2 import Environment, FileSystemLoader -ONLINE_LIBS = ['js/FileSaver.js', - 'js/StreamSaver.min.js', - 'js/dark-mode-switch.js', - 'js/scripts.js', - 'css/style.css', - 'css/w3.css'] +ONLINE_LIBS = [ + "js/FileSaver.js", + "js/StreamSaver.min.js", + "js/dark-mode-switch.js", + "js/scripts.js", + "css/style.css", + "css/w3.css", +] -class Report(): + +class Report: """ Class to create html report for dmriqc. """ + def __init__(self, report_name): """ Initialise the Report Class. @@ -28,18 +31,25 @@ def __init__(self, report_name): Report name in html format. 
""" self.path = dirname(realpath(__file__)) - self.env = Environment(loader=FileSystemLoader( - join(self.path, "../template"))) + self.env = Environment( + loader=FileSystemLoader(join(self.path, "../template")) + ) self.report_name = report_name self.out_dir = dirname(report_name) if ".html" not in self.report_name: self.report_name += ".html" - def generate(self, title=None, nb_subjects=None, - summary_dict=None, graph_array=None, metrics_dict=None, - warning_dict=None, - online=False): + def generate( + self, + title=None, + nb_subjects=None, + summary_dict=None, + graph_array=None, + metrics_dict=None, + warning_dict=None, + online=False, + ): """ Generate and save the report. @@ -50,41 +60,49 @@ def generate(self, title=None, nb_subjects=None, nb_subjects : int Number of subjects. summary_dict : dict - Dictionnary of the statistic summaries for each metric. + Dictionary of the statistic summaries for each metric. summary_dict[METRIC_NAME] = HTML_CODE graph_array : array Array of graph div from plotly to display in the summary tab. metrics_dict : dict - Dictionnary of the subjects informations for each metric. - metrics_dict[METRIC_NAME] = {SUBJECT: - { 'stats': HTML_CODE, - 'screenshot': IMAGE_PATH} - } + Dictionary of the subjects' information for each metric. + metrics_dict[METRIC_NAME] = { + SUBJECT: { 'stats': HTML_CODE, 'screenshot': IMAGE_PATH } + } warning_dict : dict - Dictionnary of warning subjects for each metric. - warning_dict[METRIC_NAME] = { 'WANING_TYPE': ARRAY_OF_SUBJECTS, - 'nb_warnings': NUMBER_OF_SUBJECTS} + Dictionary of warning subjects for each metric. 
+ warning_dict[METRIC_NAME] = { + 'WANING_TYPE': ARRAY_OF_SUBJECTS, + 'nb_warnings': NUMBER_OF_SUBJECTS + } + online : bool + If true, will fetch the js and css libraries online """ if online: os.makedirs(join(self.out_dir, "libs/css")) os.makedirs(join(self.out_dir, "libs/js")) for curr_lib in ONLINE_LIBS: - copyfile(join(self.path, "../template/libs/", curr_lib), - join(self.out_dir, "libs/", curr_lib)) + copyfile( + join(self.path, "../template/libs/", curr_lib), + join(self.out_dir, "libs/", curr_lib), + ) else: - copytree(join(self.path, "../template/libs"), - join(self.out_dir, "libs")) + copytree( + join(self.path, "../template/libs"), join(self.out_dir, "libs") + ) - with open(self.report_name, 'w') as out_file: - template = self.env.get_template('template.html') + with open(self.report_name, "w") as out_file: + template = self.env.get_template("template.html") - rendered = template.render(title=title, - nb_subjects=nb_subjects, - summary_dict=summary_dict, - graph_summ=graph_array, - metrics_dict=metrics_dict, - warning_list=warning_dict, - online=online) + rendered = template.render( + title=title, + nb_subjects=nb_subjects, + summary_dict=summary_dict, + graph_summ=graph_array, + metrics_dict=metrics_dict, + warning_list=warning_dict, + online=online, + ) out_file.write(rendered) out_file.close() diff --git a/dmriqcpy/io/utils.py b/dmriqcpy/io/utils.py index c28f807..5f547e0 100644 --- a/dmriqcpy/io/utils.py +++ b/dmriqcpy/io/utils.py @@ -1,28 +1,18 @@ # -*- coding: utf-8 -*- -""" -Some functions comes from -https://github.com/scilus/scilpy/blob/master/scilpy/io/utils.py -""" - import glob import os +import shutil +import numpy as np -def add_overwrite_arg(parser): - """ - Add overwrite option to the parser. 
- Parameters - ---------- - parser: argparse.ArgumentParser object - """ - parser.add_argument( - '-f', dest='overwrite', action='store_true', - help='Force overwriting of the output files.') +""" +Some functions comes from +https://github.com/scilus/scilpy/blob/master/scilpy/io/utils.py +""" -def assert_inputs_exist(parser, required, optional=None, - are_directories=False): +def assert_inputs_exist(parser, required, optional=None, are_directories=False): """ Assert that all inputs exist. If not, print parser's usage and exit. @@ -34,11 +24,12 @@ def assert_inputs_exist(parser, required, optional=None, Each element will be ignored if None are_directories: bool """ + def check(path, are_directories): if not os.path.isfile(path) and not are_directories: - parser.error('Input file {} does not exist'.format(path)) + parser.error("Input file {} does not exist".format(path)) elif are_directories and not os.path.isdir(path): - parser.error('Input directory {} does not exist'.format(path)) + parser.error("Input directory {} does not exist".format(path)) if isinstance(required, str): required = [required] @@ -66,10 +57,14 @@ def assert_outputs_exist(parser, args, required, optional=None): optional: string or list of paths. Each element will be ignored if None """ + def check(path): if os.path.isfile(path) and not args.overwrite: - parser.error('Output file {} exists. Use -f to force ' - 'overwriting'.format(path)) + parser.error( + "Output file {} exists. Use -f to force overwriting".format( + path + ) + ) if isinstance(required, str): required = [required] @@ -84,10 +79,74 @@ def check(path): check(optional_file) +def assert_list_arguments_equal_size(parser, *args): + sizes = [len(arg) for arg in args] + if len(sizes) == 0: + parser.error("No input images provided.") + if not np.allclose(sizes, sizes[0]): + parser.error("Not the same number of images in input.") + + +def add_overwrite_arg(parser): + """ + Add overwrite option to the parser. 
+ + Parameters + ---------- + parser: argparse.ArgumentParser object + """ + parser.add_argument( + "-f", + dest="overwrite", + action="store_true", + help="Force overwriting of the output files.", + ) + + def add_online_arg(parser): - parser.add_argument('--online', action='store_true', - help='If set, the script will use the internet ' - 'connexion to grab the needed libraries.') + parser.add_argument( + "--online", + action="store_true", + help="If set, opening the generated HTML report will require an " + "internet connection to dynamically load JS and CSS dependencies", + ) + + +def add_nb_threads_arg(parser, default=1): + parser.add_argument( + "--nb_threads", + type=int, + default=default, + help="Number of threads. [%(default)s]", + ) + + +def add_skip_arg(parser, default=2): + parser.add_argument( + "--skip", + default=default, + type=int, + help="Number of images skipped to build the mosaic. [%(default)s]", + ) + + +def add_nb_columns_arg(parser, default=12): + parser.add_argument( + "--nb_columns", + default=default, + type=int, + help="Number of columns for the mosaic. 
[%(default)s]", + ) + + +def clean_output_directories(outputs_data=True): + if outputs_data: + if os.path.exists("data"): + shutil.rmtree("data") + os.makedirs("data") + + if os.path.exists("libs"): + shutil.rmtree("libs") def list_files_from_paths(paths): @@ -107,7 +166,7 @@ def list_files_from_paths(paths): out_images = [] for curr_path in paths: if os.path.isdir(curr_path): - curr_images = glob.glob(os.path.join(curr_path, '*')) + curr_images = glob.glob(os.path.join(curr_path, "*")) else: curr_images = [curr_path] diff --git a/dmriqcpy/reporting/__init__.py b/dmriqcpy/reporting/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dmriqcpy/reporting/report.py b/dmriqcpy/reporting/report.py new file mode 100644 index 0000000..9c2122d --- /dev/null +++ b/dmriqcpy/reporting/report.py @@ -0,0 +1,199 @@ +# -*- coding: utf-8 -*- +import os +from collections.abc import Iterable +from functools import partial +from multiprocessing import Pool + +import numpy as np + +from dmriqcpy.analysis.stats import ( + stats_frf, + stats_mask_volume, + stats_mean_in_tissues, + stats_mean_median, + stats_tractogram, +) +from dmriqcpy.viz.graph import ( + graph_frf_b0, + graph_frf_eigen, + graph_mask_volume, + graph_mean_in_tissues, + graph_mean_median, + graph_tractogram, +) +from dmriqcpy.viz.screenshot import ( + screenshot_mosaic_blend, + screenshot_mosaic_wrapper, + screenshot_tracking, +) +from dmriqcpy.viz.utils import analyse_qa, dataframe_to_html + + +def get_qa_report(summary, stats, qa_labels): + qa_report = analyse_qa(summary, stats, qa_labels) + files_flagged_warning = np.concatenate( + [filenames for filenames in qa_report.values()] + ) + qa_report["nb_warnings"] = len(np.unique(files_flagged_warning)) + return qa_report + + +def _get_stats_and_graphs( + metrics, + stats_fn, + stats_labels, + graph_fn, + graph_title, + qa_labels=None, + include_plotlyjs=False, +): + qa_labels = qa_labels or stats_labels + summary, stats = stats_fn(stats_labels, metrics) + 
qa_report = get_qa_report(summary, stats, qa_labels) + graphs = [ + fn(title, qa_labels, summary, include_plotlyjs) + for fn, title in zip( + graph_fn if isinstance(graph_fn, Iterable) else [graph_fn], + graph_title if isinstance(graph_title, Iterable) else [graph_title], + ) + ] + + return summary, stats, qa_report, graphs + + +def get_tractogram_qa_stats_and_graph(tractograms, report_is_online): + return _get_stats_and_graphs( + tractograms, + stats_tractogram, + ["Nb streamlines"], + graph_tractogram, + "Tracking", + include_plotlyjs=not report_is_online, + ) + + +def get_frf_qa_stats_and_graph(frfs, report_is_online): + return _get_stats_and_graphs( + frfs, + stats_frf, + ["Mean Eigen value 1", "Mean Eigen value 2", "Mean B0"], + [graph_frf_eigen, graph_frf_b0], + ["EigenValues", "Mean B0"], + include_plotlyjs=not report_is_online, + ) + + +def get_mask_qa_stats_and_graph(masks, name, report_is_online): + return _get_stats_and_graphs( + masks, + stats_mask_volume, + ["{} volume".format(name)], + graph_mask_volume, + "{} mean volume".format(name), + include_plotlyjs=not report_is_online, + ) + + +def get_qa_stats_and_graph_in_tissues( + metric, name, wm_masks, gm_masks, csf_masks, report_is_online +): + stats_labels = [ + "Mean {} in WM".format(name), + "Mean {} in GM".format(name), + "Mean {} in CSF".format(name), + "Max {} in WM".format(name), + ] + return _get_stats_and_graphs( + metric, + partial( + stats_mean_in_tissues, + wm_images=wm_masks, + gm_images=gm_masks, + csf_images=csf_masks, + ), + stats_labels, + graph_mean_in_tissues, + "Mean {}".format(name), + stats_labels[:3], + not report_is_online, + ) + + +def get_generic_qa_stats_and_graph(metrics, name, report_is_online): + return _get_stats_and_graphs( + metrics, + stats_mean_median, + ["Mean {}".format(name), "Median {}".format(name)], + graph_mean_median, + "Mean {}".format(name), + include_plotlyjs=not report_is_online, + ) + + +def generate_report_package( + metric_image_path, + 
blend_image_path=None, + stats_summary=None, + skip=1, + nb_columns=15, + duration=100, + cmap=None, + blend_val=0.5, + lut=None, + pad=20, + blend_is_mask=False, + metric_is_tracking=False, +): + if metric_is_tracking: + screenshot_path = screenshot_tracking( + metric_image_path, blend_image_path, "data" + ) + elif blend_image_path: + screenshot_path = screenshot_mosaic_blend( + metric_image_path, + blend_image_path, + directory="data", + skip=skip, + pad=pad, + nb_columns=nb_columns, + blend_val=blend_val, + cmap=cmap, + lut=lut, + is_mask=blend_is_mask, + ) + else: + screenshot_path = screenshot_mosaic_wrapper( + metric_image_path, + directory="data", + skip=skip, + pad=pad, + nb_columns=nb_columns, + duration=duration, + cmap=cmap, + lut=lut, + ) + + subject_data = {"screenshot": screenshot_path} + subj_metric_name = os.path.basename(metric_image_path).split(".")[0] + + if stats_summary is not None: + subject_data["stats"] = dataframe_to_html( + stats_summary.loc[subj_metric_name].to_frame() + ) + + return subj_metric_name, subject_data + + +def generate_metric_reports_parallel( + metrics_iterable, + nb_threads, + chunksize=1, + report_package_generation_fn=generate_report_package, +): + with Pool(nb_threads) as pool: + qc_tabs_data_pool = pool.starmap_async( + report_package_generation_fn, + metrics_iterable, + chunksize=chunksize, + ) + return {tag: data for tag, data in qc_tabs_data_pool.get()} diff --git a/dmriqcpy/template/libs/js/scripts.js b/dmriqcpy/template/libs/js/scripts.js index 7c35d85..bc196f0 100644 --- a/dmriqcpy/template/libs/js/scripts.js +++ b/dmriqcpy/template/libs/js/scripts.js @@ -775,7 +775,7 @@ function doMouseWheel(event) { } function update_status(object) { - document.getElementById(object.name + "_status").innerText = object.innerText; + document.getElementById(object.name + "_status").innerText = object.innerText.trim(" ").trim("\n"); document.getElementById(object.name + "_status").style.backgroundColor = 
object.style.backgroundColor; if (object.innerText != "Pending") { document.getElementById("curr_subj").style.backgroundColor = object.style.backgroundColor; diff --git a/dmriqcpy/version.py b/dmriqcpy/version.py index e0569da..95943dc 100644 --- a/dmriqcpy/version.py +++ b/dmriqcpy/version.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - from __future__ import absolute_import, division, print_function import glob @@ -16,15 +15,17 @@ if _version_extra: _ver.append(_version_extra) -__version__ = '.'.join(map(str, _ver)) +__version__ = ".".join(map(str, _ver)) -CLASSIFIERS = ["Development Status :: 3 - Alpha", - "Environment :: Console", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - "Programming Language :: Python", - "Topic :: Scientific/Engineering"] +CLASSIFIERS = [ + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Topic :: Scientific/Engineering", +] # Description should be a one-liner: description = "Diffusion MRI Quality Check in python " @@ -47,8 +48,16 @@ MINOR = _version_minor MICRO = _version_micro VERSION = __version__ -REQUIRES = ['numpy (>=1.18)', 'jinja2 (>=2.10.1)', 'pandas (>=0.25.1)', - 'nibabel (>=3.0)', 'plotly (>=3.0.0)', 'vtk (>=8.1.2)', - 'pillow (>=6.2.0)', 'fury (>=0.2.0)', - 'matplotlib (>=2.2.0)', 'scipy (>=1.4.1)'] +REQUIRES = [ + "numpy (>=1.18)", + "jinja2 (>=2.10.1)", + "pandas (>=0.25.1)", + "nibabel (>=3.0)", + "plotly (>=3.0.0)", + "vtk (>=8.1.2)", + "pillow (>=6.2.0)", + "fury (>=0.2.0)", + "matplotlib (>=2.2.0)", + "scipy (>=1.4.1)", +] SCRIPTS = glob.glob("scripts/*.py") diff --git a/dmriqcpy/viz/graph.py b/dmriqcpy/viz/graph.py index 1fff5e6..1061e2d 100644 --- a/dmriqcpy/viz/graph.py +++ b/dmriqcpy/viz/graph.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- - import numpy as np 
from plotly.graph_objs import Bar, Box, Figure import plotly.offline as off +from dmriqcpy.viz.utils import graph_to_html + -def graph_mean_median(title, column_names, summary, online=False): +def graph_mean_median(title, column_names, summary, include_plotlyjs=False): """ Compute plotly graph with mean and median stats @@ -17,18 +18,16 @@ def graph_mean_median(title, column_names, summary, online=False): Name of the columns in the summary DataFrame. summary : DataFrame DataFrame containing the mean and median stats. - online: Boolean - If false it will include plotlyjs + include_plotlyjs: Boolean + If True, javascript and css dependencies for plotting will + be injected in the graph's HTML code returned. If not, they + have to be included manually or via a CDN. Returns ------- div : html div (string) Graph as a HTML div. """ - include_plotlyjs = not online - - means = [] - medians = [] np.random.seed(1) image = summary.index means = np.array(summary[column_names[0]]) @@ -37,40 +36,31 @@ def graph_mean_median(title, column_names, summary, online=False): mean = Box( name="Mean", y=means, - boxpoints='all', + boxpoints="all", jitter=0.3, text=image, pointpos=-1.8, - hoverinfo="y+text" + hoverinfo="y+text", ) - median = Box( name="Median", y=medians, - boxpoints='all', + boxpoints="all", jitter=0.3, text=image, pointpos=-1.8, - hoverinfo="y+text" + hoverinfo="y+text", ) - data = [mean, median] - - fig = Figure(data=data) - max_value = max(np.max(means), np.max(medians)) - - range_yaxis = [0, max_value + 2 * max_value] - - fig['layout']['yaxis'].update(range=range_yaxis) - fig['layout'].update(title=title) - fig['layout'].update(width=500, height=500) - div = off.plot(fig, show_link=False, include_plotlyjs=include_plotlyjs, - output_type='div') - div = div.replace("
", "
") - return div + return graph_to_html( + [mean, median], + title, + [0, 3.0 * max(np.max(means), np.max(medians))], + include_plotlyjs=include_plotlyjs, + ) -def graph_mean_in_tissues(title, column_names, summary, online=False): +def graph_mean_in_tissues(title, column_names, summary, include_plotlyjs=False): """ Compute plotly graph with mean value in tissue masks @@ -82,70 +72,59 @@ def graph_mean_in_tissues(title, column_names, summary, online=False): Name of the columns in the summary DataFrame. summary : DataFrame DataFrame containing the mean stats. - online: Boolean - If false it will include plotlyjs + include_plotlyjs: Boolean + If True, javascript and css dependencies for plotting will + be injected in the graph's HTML code returned. If not, they + have to be included manually or via a CDN. Returns ------- div : html div (string) Graph as a HTML div. """ - include_plotlyjs = not online - - means_wm = [] - means_gm = [] - means_csf = [] np.random.seed(1) metric = summary.index means_wm = np.array(summary[column_names[0]]) means_gm = np.array(summary[column_names[1]]) means_csf = np.array(summary[column_names[2]]) + wm = Box( name="WM", y=means_wm, - boxpoints='all', + boxpoints="all", jitter=0.3, text=metric, pointpos=-1.8, - hoverinfo="y+text" + hoverinfo="y+text", ) - gm = Box( name="GM", y=means_gm, - boxpoints='all', + boxpoints="all", jitter=0.3, text=metric, pointpos=-1.8, - hoverinfo="y+text" + hoverinfo="y+text", ) - csf = Box( name="CSF", y=means_csf, - boxpoints='all', + boxpoints="all", jitter=0.3, text=metric, pointpos=-1.8, hoverinfo="y+text", ) - data = [wm, gm, csf] - - fig = Figure(data=data) - - range_yaxis = [0, np.max(means_wm) + 2 * np.max(means_wm)] - fig['layout']['yaxis'].update(range=range_yaxis) - fig['layout'].update(title=title) - fig['layout'].update(width=500, height=500) - - div = off.plot(fig, show_link=False, include_plotlyjs=include_plotlyjs, - output_type='div') - div = div.replace("
", "
") - return div + return graph_to_html( + [wm, gm, csf], + title, + [0, 3.0 * np.max(means_wm)], + include_plotlyjs=include_plotlyjs, + ) -def graph_frf_eigen(title, column_names, summary, online=False): +def graph_frf_eigen(title, column_names, summary, include_plotlyjs=False): """ Compute plotly graph with mean frf values @@ -157,16 +136,16 @@ def graph_frf_eigen(title, column_names, summary, online=False): Name of the columns in the summary DataFrame. summary : DataFrame DataFrame containing the mean stats. - online: Boolean - If false it will include plotlyjs + include_plotlyjs: Boolean + If True, javascript and css dependencies for plotting will + be injected in the graph's HTML code returned. If not, they + have to be included manually or via a CDN. Returns ------- div : html div (string) Graph as a HTML div. """ - include_plotlyjs = not online - np.random.seed(1) metric = summary.index e1 = np.array(summary[column_names[0]]) @@ -175,35 +154,28 @@ def graph_frf_eigen(title, column_names, summary, online=False): e1_graph = Box( name="Eigen value 1", y=e1, - boxpoints='all', + boxpoints="all", jitter=0.3, text=metric, pointpos=-1.8, - hoverinfo="y+text" + hoverinfo="y+text", ) - e2_graph = Box( name="Eigen value 2", y=e2, - boxpoints='all', + boxpoints="all", jitter=0.3, text=metric, pointpos=-1.8, - hoverinfo="y+text" + hoverinfo="y+text", ) - data = [e1_graph, e2_graph] - - fig = Figure(data=data) + return graph_to_html( + [e1_graph, e2_graph], title, include_plotlyjs=include_plotlyjs + ) - fig['layout'].update(title=title) - fig['layout'].update(width=500, height=500) - div = off.plot(fig, show_link=False, include_plotlyjs=include_plotlyjs, - output_type='div') - div = div.replace("
", "
") - return div -def graph_frf_b0(title, column_names, summary, online=False): +def graph_frf_b0(title, column_names, summary, include_plotlyjs=False): """ Compute plotly graph with mean b0 values @@ -215,41 +187,33 @@ def graph_frf_b0(title, column_names, summary, online=False): Name of the columns in the summary DataFrame. summary : DataFrame DataFrame containing the mean stats. - online: Boolean - If false it will include plotlyjs + include_plotlyjs: Boolean + If True, javascript and css dependencies for plotting will + be injected in the graph's HTML code returned. If not, they + have to be included manually or via a CDN. Returns ------- div : html div (string) Graph as a HTML div. """ - include_plotlyjs = not online - np.random.seed(1) metric = summary.index - e1_graph = Box( + + mean_b0 = Box( name="Mean B0", y=np.array(summary[column_names[2]]), - boxpoints='all', + boxpoints="all", jitter=0.3, text=metric, pointpos=-1.8, - hoverinfo="y+text" + hoverinfo="y+text", ) - data = [e1_graph] - - fig = Figure(data=data) - - fig['layout'].update(title=title) - fig['layout'].update(width=500, height=500) - div = off.plot(fig, show_link=False, include_plotlyjs=include_plotlyjs, - output_type='div') - div = div.replace("
", "
") - return div + return graph_to_html([mean_b0], title, include_plotlyjs=include_plotlyjs) -def graph_tractogram(title, column_names, summary, online=False): +def graph_tractogram(title, column_names, summary, include_plotlyjs=False): """ Compute plotly graph with mean number of streamlines @@ -261,17 +225,16 @@ def graph_tractogram(title, column_names, summary, online=False): Name of the columns in the summary DataFrame. summary : DataFrame DataFrame containing the mean stats. - online: Boolean - If false it will include plotlyjs + include_plotlyjs: Boolean + If True, javascript and css dependencies for plotting will + be injected in the graph's HTML code returned. If not, they + have to be included manually or via a CDN. Returns ------- div : html div (string) Graph as a HTML div. """ - include_plotlyjs = not online - - nb_streamlines = [] np.random.seed(1) metric = summary.index nb_streamlines = np.array(summary[column_names[0]]) @@ -279,26 +242,19 @@ def graph_tractogram(title, column_names, summary, online=False): nb_streamlines_graph = Box( name="Nb streamlines", y=nb_streamlines, - boxpoints='all', + boxpoints="all", jitter=0.3, text=metric, pointpos=-1.8, - hoverinfo="y+text" + hoverinfo="y+text", ) - data = [nb_streamlines_graph] - - fig = Figure(data=data) - - fig['layout'].update(title=title) - fig['layout'].update(width=500, height=500) - div = off.plot(fig, show_link=False, include_plotlyjs=include_plotlyjs, - output_type='div') - div = div.replace("
", "
") - return div + return graph_to_html( + [nb_streamlines_graph], title, include_plotlyjs=include_plotlyjs + ) -def graph_mask_volume(title, column_names, summary, online=False): +def graph_mask_volume(title, column_names, summary, include_plotlyjs=False): """ Compute plotly graph with mean mask volume @@ -310,16 +266,16 @@ def graph_mask_volume(title, column_names, summary, online=False): Name of the columns in the summary DataFrame. summary : DataFrame DataFrame containing the mean stats. - online: Boolean - If false it will include plotlyjs + include_plotlyjs: Boolean + If True, javascript and css dependencies for plotting will + be injected in the graph's HTML code returned. If not, they + have to be included manually or via a CDN. Returns ------- div : html div (string) Graph as a HTML div. """ - include_plotlyjs = not online - np.random.seed(1) metric = summary.index volume = np.array(summary[column_names[0]]) @@ -327,26 +283,19 @@ def graph_mask_volume(title, column_names, summary, online=False): volume_graph = Box( name="Volume", y=volume, - boxpoints='all', + boxpoints="all", jitter=0.3, text=metric, pointpos=-1.8, - hoverinfo="y+text" + hoverinfo="y+text", ) - data = [volume_graph] - - fig = Figure(data=data) - - fig['layout'].update(title=title) - fig['layout'].update(width=500, height=500) - div = off.plot(fig, show_link=False, include_plotlyjs=include_plotlyjs, - output_type='div') - div = div.replace("
", "
") - return div + return graph_to_html( + [volume_graph], title, include_plotlyjs=include_plotlyjs + ) -def graph_dwi_protocol(title, column_name, summary, online=False): +def graph_dwi_protocol(title, column_name, summary, include_plotlyjs=False): """ Compute plotly graph with mean mask volume @@ -354,20 +303,20 @@ def graph_dwi_protocol(title, column_name, summary, online=False): ---------- title : string Title of the graph. - column_names : array of strings + column_name : array of strings Name of the columns in the summary DataFrame. summary : DataFrame DataFrame containing the mean stats. - online: Boolean - If false it will include plotlyjs + include_plotlyjs: Boolean + If True, javascript and css dependencies for plotting will + be injected in the graph's HTML code returned. If not, they + have to be included manually or via a CDN. Returns ------- div : html div (string) Graph as a HTML div. """ - include_plotlyjs = not online - np.random.seed(1) metric = summary.index data = np.array(summary[column_name]) @@ -375,26 +324,17 @@ def graph_dwi_protocol(title, column_name, summary, online=False): graph = Box( name=column_name, y=data, - boxpoints='all', + boxpoints="all", jitter=0.3, text=metric, pointpos=-1.8, - hoverinfo="y+text" + hoverinfo="y+text", ) - data = [graph] - - fig = Figure(data=data) + return graph_to_html([graph], title, include_plotlyjs=include_plotlyjs) - fig['layout'].update(title=title) - fig['layout'].update(width=500, height=500) - div = off.plot(fig, show_link=False, include_plotlyjs=include_plotlyjs, - output_type='div') - div = div.replace("
", "
") - return div - -def graph_directions_per_shells(title, summary, online=False): +def graph_directions_per_shells(title, summary, include_plotlyjs=False): """ Compute plotly graph with mean mask volume @@ -404,45 +344,41 @@ def graph_directions_per_shells(title, summary, online=False): Title of the graph. summary : dict DataFrame containing the mean stats. - online: Boolean - If false it will include plotlyjs + include_plotlyjs: Boolean + If True, javascript and css dependencies for plotting will + be injected in the graph's HTML code returned. If not, they + have to be included manually or via a CDN. Returns ------- div : html div (string) Graph as a HTML div. """ - include_plotlyjs = not online - np.random.seed(1) data_graph = [] + for i in sorted(summary): metric = list(summary[i].keys()) data = list(summary[i].values()) graph = Box( - name="b=" + str(i), + name="b={}".format(i), y=data, - boxpoints='all', + boxpoints="all", jitter=0.3, text=metric, pointpos=-1.8, - hoverinfo="y+text" + hoverinfo="y+text", ) data_graph.append(graph) - fig = Figure(data=data_graph) - - fig['layout'].update(title=title) - fig['layout'].update(width=700, height=500) - div = off.plot(fig, show_link=False, include_plotlyjs=include_plotlyjs, - output_type='div') - div = div.replace("
", "
") - return div + return graph_to_html( + data_graph, title, width=700, include_plotlyjs=include_plotlyjs + ) -def graph_subjects_per_shells(title, summary, online=False): +def graph_subjects_per_shells(title, summary, include_plotlyjs=False): """ Compute plotly graph with mean mask volume @@ -452,36 +388,32 @@ def graph_subjects_per_shells(title, summary, online=False): Title of the graph. summary : dict DataFrame containing the mean stats. - online: Boolean - If false it will include plotlyjs + include_plotlyjs: Boolean + If True, javascript and css dependencies for plotting will + be injected in the graph's HTML code returned. If not, they + have to be included manually or via a CDN. Returns ------- div : html div (string) Graph as a HTML div. """ - include_plotlyjs = not online - np.random.seed(1) data_graph = [] + for i in sorted(summary): metric = list(summary[i].keys()) data = [len(metric)] graph = Bar( - name="b=" + str(i), + name="b={}".format(i), y=data, - x=["b=" + str(i)], - hoverinfo="y" + x=["b={}".format(i)], + hoverinfo="y", ) data_graph.append(graph) - fig = Figure(data=data_graph) - - fig['layout'].update(title=title) - fig['layout'].update(width=700, height=500) - div = off.plot(fig, show_link=False, include_plotlyjs=include_plotlyjs, - output_type='div') - div = div.replace("
", "
") - return div + return graph_to_html( + data_graph, title, width=700, include_plotlyjs=include_plotlyjs + ) diff --git a/dmriqcpy/viz/screenshot.py b/dmriqcpy/viz/screenshot.py index bc7e7ad..0033d40 100644 --- a/dmriqcpy/viz/screenshot.py +++ b/dmriqcpy/viz/screenshot.py @@ -1,35 +1,47 @@ # -*- coding: utf-8 -*- - import os from tempfile import mkstemp -from PIL import Image, ImageDraw, ImageFont from dipy.data import get_sphere +from dipy.io.streamline import load_tractogram import fury from fury import actor, window from matplotlib.cm import get_cmap import nibabel as nib import numpy as np +from PIL import Image, ImageDraw, ImageFont from dmriqcpy.viz.utils import compute_labels_map, renderer_to_arr -from dipy.io.streamline import load_tractogram -vtkcolors = [window.colors.blue, - window.colors.red, - window.colors.yellow, - window.colors.purple, - window.colors.cyan, - window.colors.green, - window.colors.orange, - window.colors.white, - window.colors.brown, - window.colors.grey] - - -def screenshot_mosaic_wrapper(filename, output_prefix="", directory=".", skip=1, - pad=20, nb_columns=15, axis=True, cmap=None, - return_path=True, duration=100, lut=None, - compute_lut=False): + +vtkcolors = [ + window.colors.blue, + window.colors.red, + window.colors.yellow, + window.colors.purple, + window.colors.cyan, + window.colors.green, + window.colors.orange, + window.colors.white, + window.colors.brown, + window.colors.grey, +] + + +def screenshot_mosaic_wrapper( + filename, + output_prefix="", + directory=".", + skip=1, + pad=20, + nb_columns=15, + axis=True, + cmap=None, + return_path=True, + duration=100, + lut=None, + compute_lut=False, +): """ Compute mosaic wrapper from an image @@ -44,7 +56,7 @@ def screenshot_mosaic_wrapper(filename, output_prefix="", directory=".", skip=1, skip : int Number of images to skip between 2 images in the mosaic. pad : int - Padding value between each images. + Padding value between each image. nb_columns : int Number of columns. 
axis : bool @@ -53,9 +65,11 @@ def screenshot_mosaic_wrapper(filename, output_prefix="", directory=".", skip=1, Colormap name in matplotlib format. return_path : bool Return path of the mosaic. + duration : int + Number of frames of the generated GIF lut : str Look up table. - Compute lut: bool + compute_lut: bool If set, will compute a look of table using compute_labels_map. Returns @@ -66,10 +80,16 @@ def screenshot_mosaic_wrapper(filename, output_prefix="", directory=".", skip=1, mosaic in array 2D """ data = nib.load(filename).get_data() + if len(data.dtype) > 0: + # Data is 4D and has been unpacked in a structured numpy array, + # we need to get the last dimension out for snapshotting. This + # is only a problem with some RGB images, for which the 4th + # dimension is encoded in the datatype of the array + data = data.view((data.dtype[0], len(data.dtype))) + data = np.nan_to_num(data) unique = np.unique(data) - - output_prefix = output_prefix.replace(' ', '_') + '_' + output_prefix = output_prefix.replace(" ", "_") + "_" if lut is not None or compute_lut: lut = compute_labels_map(lut, unique, compute_lut) @@ -82,21 +102,37 @@ def screenshot_mosaic_wrapper(filename, output_prefix="", directory=".", skip=1, if return_path: image_name = os.path.basename(str(filename)).split(".")[0] if isinstance(imgs_comb, list): - name = os.path.join(directory, output_prefix + image_name + '.gif') - imgs_comb[0].save(name, save_all=True, append_images=imgs_comb[1:], - duration=duration, loop=0) + name = os.path.join(directory, output_prefix + image_name + ".gif") + imgs_comb[0].save( + name, + save_all=True, + append_images=imgs_comb[1:], + duration=duration, + loop=0, + ) else: - name = os.path.join(directory, output_prefix + image_name + '.png') + name = os.path.join(directory, output_prefix + image_name + ".png") imgs_comb.save(name) + return name else: return imgs_comb -def screenshot_mosaic_blend(image, image_blend, output_prefix="", directory=".", - blend_val=0.5, skip=1, 
pad=20, nb_columns=15, - cmap=None, is_mask=False, lut=None, - compute_lut=False): +def screenshot_mosaic_blend( + image, + image_blend, + output_prefix="", + directory=".", + blend_val=0.5, + skip=1, + pad=20, + nb_columns=15, + cmap=None, + is_mask=False, + lut=None, + compute_lut=False, +): """ Compute a blend mosaic from an image and a mask @@ -124,40 +160,55 @@ def screenshot_mosaic_blend(image, image_blend, output_prefix="", directory=".", Image blend is a mask. lut : str Look up table - + compute_lut: bool + If set, will compute a look of table using compute_labels_map. Returns ------- name : string Path of the mosaic """ - mosaic_image = screenshot_mosaic_wrapper(image, skip=skip, pad=pad, - nb_columns=nb_columns, axis=False, - cmap=cmap, return_path=False) - mosaic_blend = screenshot_mosaic_wrapper(image_blend, skip=skip, pad=pad, - nb_columns=nb_columns, axis=False, - return_path=False, lut=lut, - compute_lut=compute_lut) - - output_prefix = output_prefix.replace(' ', '_') + '_' + mosaic_image = screenshot_mosaic_wrapper( + image, + skip=skip, + pad=pad, + nb_columns=nb_columns, + axis=False, + cmap=cmap, + return_path=False, + ) + mosaic_blend = screenshot_mosaic_wrapper( + image_blend, + skip=skip, + pad=pad, + nb_columns=nb_columns, + axis=False, + return_path=False, + lut=lut, + compute_lut=compute_lut, + ) + + output_prefix = output_prefix.replace(" ", "_") + "_" if is_mask: data = np.array(mosaic_blend) data[(data == (255, 255, 255)).all(axis=-1)] = (255, 0, 0) mosaic_blend = Image.fromarray(data, mode="RGB") + image_name = os.path.basename(str(image)).split(".")[0] if isinstance(mosaic_image, list): blend = [] for _, mosaic in enumerate(mosaic_image): - blend.append(Image.blend(mosaic, mosaic_blend, - alpha=blend_val)) - name = os.path.join(directory, output_prefix + image_name + '.gif') - blend[0].save(name, save_all=True, append_images=blend[1:], - duration=100, loop=0) + blend.append(Image.blend(mosaic, mosaic_blend, alpha=blend_val)) + name = 
os.path.join(directory, output_prefix + image_name + ".gif") + blend[0].save( + name, save_all=True, append_images=blend[1:], duration=100, loop=0 + ) else: blend = Image.blend(mosaic_image, mosaic_blend, alpha=blend_val) - name = os.path.join(directory, output_prefix + image_name + '.png') + name = os.path.join(directory, output_prefix + image_name + ".png") blend.save(name) + return name @@ -196,48 +247,61 @@ def screenshot_mosaic(data, skip, pad, nb_columns, axis, cmap): if max_val - min_val < 20 and max_val.is_integer(): min_val = data.min() max_val = np.percentile(data[data > 0], 99.99) - shape = ((data[:, :, 0].shape[1] + pad) * nb_rows + pad * nb_rows, - (data[:, :, 0].shape[0] + pad) * nb_columns + nb_columns * pad) + + shape = ( + (data[:, :, 0].shape[1] + 2 * pad) * nb_rows, + (data[:, :, 0].shape[0] + 2 * pad) * nb_columns, + ) padding = ((int(pad / 2), int(pad / 2)), (int(pad / 2), int(pad / 2))) + is_rgb = False if is_4d: time = data.shape[3] if time == 3: is_rgb = True + shape += (time,) padding += ((0, 0),) if not is_rgb: data = np.interp(data, xp=[min_val, max_val], fp=[0, 255]).astype( - dtype=np.uint8) + dtype=np.uint8 + ) mosaic = np.zeros(shape, dtype=np.uint8) - for i, idx in enumerate(range_row): corner = i % nb_columns row = int(i / nb_columns) curr_img = np.rot90(data[:, :, idx]) - curr_img = np.pad(curr_img, padding, 'constant').astype(dtype=np.uint8) + curr_img = np.pad(curr_img, padding, "constant").astype(dtype=np.uint8) curr_shape = curr_img.shape - mosaic[curr_shape[0] * row + row * pad: - row * curr_shape[0] + curr_shape[0] + row * pad, - curr_shape[1] * corner + corner * pad: - corner * curr_shape[1] + curr_shape[1] + corner * pad] = curr_img + row_pad = (curr_shape[0] + pad) * row + col_pad = (curr_shape[1] + pad) * corner + mosaic[ + row_pad : row_pad + curr_shape[0], col_pad : col_pad + curr_shape[1] + ] = curr_img + if axis and not is_4d: - mosaic = np.pad(mosaic, ((50, 50), (50, 50)), 'constant').astype( - dtype=np.uint8) + 
mosaic = np.pad(mosaic, ((50, 50), (50, 50)), "constant").astype( + dtype=np.uint8 + ) + img = Image.fromarray(mosaic) draw = ImageDraw.Draw(img) - fnt = ImageFont.truetype( - '/usr/share/fonts/truetype/freefont/FreeSans.ttf', 40) + "/usr/share/fonts/truetype/freefont/FreeSans.ttf", 40 + ) + draw.text([mosaic.shape[1] / 2, 0], "A", fill=255, font=fnt) - draw.text([mosaic.shape[1] / 2, mosaic.shape[0] - 40], "P", fill=255, - font=fnt) + draw.text( + [mosaic.shape[1] / 2, mosaic.shape[0] - 40], "P", fill=255, font=fnt + ) draw.text([0, mosaic.shape[0] / 2], "L", fill=255, font=fnt) - draw.text([mosaic.shape[1] - 40, mosaic.shape[0] / 2], "R", fill=255, - font=fnt) + draw.text( + [mosaic.shape[1] - 40, mosaic.shape[0] / 2], "R", fill=255, font=fnt + ) + mosaic = np.array(img, dtype=np.uint8) if cmap is not None: @@ -247,43 +311,51 @@ def screenshot_mosaic(data, skip, pad, nb_columns, axis, cmap): tmp = screenshot_3_axis(data, mosaic, cmap, is_4d) mosaic = np.vstack((tmp, mosaic)) del data + if is_4d and mosaic.shape[2] != 3: gif = [] for i in range(mosaic.shape[2]): img_t = np.uint8(np.clip(mosaic[:, :, i], 0, 255)) imgs_comb = Image.fromarray(img_t) + if mosaic[:, :, i].shape[1] > 1920: - basewidth = 1920 - wpercent = (basewidth / float(imgs_comb.size[0])) - hsize = int((float(imgs_comb.size[1]) * float(wpercent))) - imgs_comb = imgs_comb.resize((basewidth, hsize), - Image.ANTIALIAS) + base_width = 1920 + w_percent = base_width / float(imgs_comb.size[0]) + hsize = int((float(imgs_comb.size[1]) * float(w_percent))) + imgs_comb = imgs_comb.resize( + (base_width, hsize), Image.ANTIALIAS + ) draw = ImageDraw.Draw(imgs_comb) fnt = ImageFont.truetype( - '/usr/share/fonts/truetype/freefont/FreeSans.ttf', 40) - draw.text([0, 0], str(i) + "/" + str(mosaic.shape[2]), - fill=255, font=fnt) + "/usr/share/fonts/truetype/freefont/FreeSans.ttf", 40 + ) + draw.text( + [0, 0], str(i) + "/" + str(mosaic.shape[2]), fill=255, font=fnt + ) gif.append(imgs_comb.convert("RGB")) + return 
gif img = np.uint8(np.clip(mosaic, 0, 255)) imgs_comb = Image.fromarray(img) if mosaic[:, :].shape[1] > 1920: - basewidth = 1920 - wpercent = (basewidth / float(imgs_comb.size[0])) - hsize = int((float(imgs_comb.size[1]) * float(wpercent))) - imgs_comb = imgs_comb.resize((basewidth, hsize), Image.ANTIALIAS) - imgs_comb = imgs_comb.convert("RGB") - return imgs_comb + base_width = 1920 + w_percent = base_width / float(imgs_comb.size[0]) + hsize = int((float(imgs_comb.size[1]) * float(w_percent))) + imgs_comb = imgs_comb.resize((base_width, hsize), Image.ANTIALIAS) + + return imgs_comb.convert("RGB") def screenshot_3_axis(data, mosaic, cmap=None, is_4d=False): - middle = [data.shape[0] // 2 + 4, data.shape[1] // 2, - data.shape[2] // 2] - slice_display = [data[middle[0], :, :], data[:, middle[1], :], - data[:, :, middle[2]]] + middle = [data.shape[0] // 2 + 4, data.shape[1] // 2, data.shape[2] // 2] + slice_display = [ + data[middle[0], :, :], + data[:, middle[1], :], + data[:, :, middle[2]], + ] size = max(data.shape) image = np.array([]) @@ -296,50 +368,59 @@ def screenshot_3_axis(data, mosaic, cmap=None, is_4d=False): top = np.floor(pad_h / 2) bottom = pad_h - top padding = ((top, bottom), (left, right)) + if is_4d: padding += ((0, 0),) + img2 = np.pad(img, np.array(padding, dtype=int), "constant") img2 = np.rot90(img2) - if image.size == 0: - image = img2 - else: - image = np.hstack((image, img2)) + image = img2 if image.size == 0 else np.hstack((image, img2)) if is_4d: tmp = [] for i in range(image.shape[2]): three_axis = Image.fromarray(np.uint8(image[:, :, i])) three_axis_np = np.array(three_axis) - tmp.append(_resize_mosaic(mosaic[:, :, i], three_axis, - three_axis_np)) + tmp.append( + _resize_mosaic(mosaic[:, :, i], three_axis, three_axis_np) + ) + image = np.moveaxis(np.array(tmp), 0, 2) else: three_axis = Image.fromarray(np.uint8(image)) three_axis_np = np.array(three_axis) image = _resize_mosaic(mosaic, three_axis, three_axis_np) + if cmap is not None: 
colormap = get_cmap(cmap) image = np.array(colormap(image / 255.0) * 255).astype(dtype=np.uint8) + return np.array(image, dtype=np.uint8) def _resize_mosaic(mosaic, three_axis, three_axis_np): - ratio = min(mosaic.shape[0] / three_axis_np.shape[0], - mosaic.shape[1] / three_axis_np.shape[1]) + ratio = min( + mosaic.shape[0] / three_axis_np.shape[0], + mosaic.shape[1] / three_axis_np.shape[1], + ) three_axis = three_axis.resize( - (int(np.floor(three_axis_np.shape[1] * ratio)), - int(np.floor(three_axis_np.shape[0] * ratio)))) + ( + int(np.floor(three_axis_np.shape[1] * ratio)), + int(np.floor(three_axis_np.shape[0] * ratio)), + ) + ) three_axis_np = np.array(three_axis) tmp = np.zeros((three_axis_np.shape[0], mosaic.shape[1])) diff = np.abs(np.subtract(tmp.shape, three_axis_np.shape)) - tmp[diff[0]: three_axis_np.shape[0] + diff[0], - np.int(diff[1] / 2): three_axis_np.shape[1] + np.int( - diff[1] / 2)] = three_axis_np + tmp[ + diff[0] : three_axis_np.shape[0] + diff[0], + np.int(diff[1] / 2) : three_axis_np.shape[1] + np.int(diff[1] / 2), + ] = three_axis_np return tmp -def screenshot_fa_peaks(fa, peaks, directory='.'): +def screenshot_fa_peaks(fa, peaks, directory="."): """ Compute 3 view screenshot with peaks on FA. @@ -357,41 +438,41 @@ def screenshot_fa_peaks(fa, peaks, directory='.'): name : string Path of the mosaic """ - slice_name = ['sagittal', 'coronal', 'axial'] + slice_name = ["sagittal", "coronal", "axial"] data = nib.load(fa).get_data() evecs_data = nib.load(peaks).get_data() evecs = np.zeros(data.shape + (1, 3)) evecs[:, :, :, 0, :] = evecs_data[...] 
- - middle = [data.shape[0] // 2 + 4, data.shape[1] // 2, - data.shape[2] // 2] - - slice_display = [(middle[0], None, None), (None, middle[1], None), - (None, None, middle[2])] + middle = [data.shape[0] // 2 + 4, data.shape[1] // 2, data.shape[2] // 2] + slice_display = [ + (middle[0], None, None), + (None, middle[1], None), + (None, None, middle[2]), + ] concat = [] for j, slice_name in enumerate(slice_name): image_name = os.path.basename(str(peaks)).split(".")[0] - name = os.path.join(directory, image_name + '.png') - slice_actor = actor.slicer(data, interpolation='nearest', opacity=0.3) + name = os.path.join(directory, image_name + ".png") + slice_actor = actor.slicer(data, interpolation="nearest", opacity=0.3) peak_actor = actor.peak_slicer(evecs, colors=None) peak_actor.GetProperty().SetLineWidth(2.5) - - slice_actor.display(slice_display[j][0], slice_display[j][1], - slice_display[j][2]) - peak_actor.display(slice_display[j][0], slice_display[j][1], - slice_display[j][2]) + slice_actor.display( + slice_display[j][0], slice_display[j][1], slice_display[j][2] + ) + peak_actor.display( + slice_display[j][0], slice_display[j][1], slice_display[j][2] + ) renderer = window.Scene() - renderer.add(slice_actor) renderer.add(peak_actor) - center = slice_actor.GetCenter() pos = None viewup = None + center = slice_actor.GetCenter() if slice_name == "sagittal": pos = (center[0] - 350, center[1], center[2]) viewup = (0, 0, -1) @@ -404,15 +485,11 @@ def screenshot_fa_peaks(fa, peaks, directory='.'): camera = renderer.GetActiveCamera() camera.SetViewUp(viewup) - camera.SetPosition(pos) camera.SetFocalPoint(center) img = renderer_to_arr(renderer, (1080, 1080)) - if len(concat) == 0: - concat = img - else: - concat = np.hstack((concat, img)) + concat = img if len(concat) == 0 else np.hstack((concat, img)) imgs_comb = Image.fromarray(concat) imgs_comb.save(name) @@ -438,18 +515,34 @@ def screenshot_tracking(tracking, t1, directory="."): name : string Path of the mosaic """ - 
sft = load_tractogram(tracking, 'same') + sft = load_tractogram(tracking, "same") sft.to_vox() t1 = nib.load(t1) t1_data = t1.get_data() - slice_name = ['sagittal', 'coronal', 'axial'] - img_center = [(int(t1_data.shape[0] / 2) + 5, None, None), - (None, int(t1_data.shape[1] / 2), None), - (None, None, int(t1_data.shape[2] / 2))] - center = [(img_center[0][0] - 350 - (1 - t1.header.get_zooms()[0]) * 350, img_center[1][1], img_center[2][2]), - (img_center[0][0], img_center[1][1] + 350 + (1 - t1.header.get_zooms()[1]) * 350, img_center[2][2]), - (img_center[0][0], img_center[1][1], img_center[2][2] + 350 + (1 - t1.header.get_zooms()[2]) * 350)] + slice_name = ["sagittal", "coronal", "axial"] + img_center = [ + (int(t1_data.shape[0] / 2) + 5, None, None), + (None, int(t1_data.shape[1] / 2), None), + (None, None, int(t1_data.shape[2] / 2)), + ] + center = [ + ( + img_center[0][0] - (2 - t1.header.get_zooms()[0]) * 350, + img_center[1][1], + img_center[2][2], + ), + ( + img_center[0][0], + img_center[1][1] + (2 - t1.header.get_zooms()[1]) * 350, + img_center[2][2], + ), + ( + img_center[0][0], + img_center[1][1], + img_center[2][2] + (2 - t1.header.get_zooms()[2]) * 350, + ), + ] viewup = [(0, 0, -1), (0, 0, -1), (0, -1, 0)] size = (1920, 1080) @@ -462,76 +555,90 @@ def screenshot_tracking(tracking, t1, directory="."): for streamline in sft.streamlines: if it > 10000: break + if slice_idx in np.array(streamline, dtype=int)[:, i]: it += 1 - idx = np.where(np.array(streamline, dtype=int)[:, i] == \ - slice_idx)[0][0] - lower = idx - 2 - if lower < 0: - lower = 0 - upper = idx + 2 - if upper > len(streamline) - 1: - upper = len(streamline) - 1 + idx = np.where( + np.array(streamline, dtype=int)[:, i] == slice_idx + )[0][0] + + lower = max(idx - 2, 0) + upper = min(idx + 2, len(streamline) - 1) streamlines.append(streamline[lower:upper]) - ren = window.Scene() + min_val = np.min(t1_data[t1_data > 0]) + max_val = np.percentile(t1_data[t1_data > 0], 99) + t1_color = ( + 
np.float32(t1_data - min_val) + / np.float32(max_val - min_val) + * 255.0 + ) streamline_actor = actor.line(streamlines, linewidth=0.2) - ren.add(streamline_actor) + slice_actor = actor.slicer( + t1_color, opacity=0.8, value_range=(0, 255), interpolation="nearest" + ) + slice_actor.display( + img_center[i][0], img_center[i][1], img_center[i][2] + ) - min_val = np.min(t1_data[t1_data > 0]) - max_val = np.percentile(t1_data[t1_data > 0], 99) - t1_color = np.float32(t1_data - min_val) \ - / np.float32(max_val - min_val) * 255.0 - slice_actor = actor.slicer(t1_color, opacity=0.8, value_range=(0, 255), - interpolation='nearest') + ren = window.Scene() + ren.add(streamline_actor) ren.add(slice_actor) - slice_actor.display(img_center[i][0], img_center[i][1], - img_center[i][2]) camera = ren.GetActiveCamera() camera.SetViewUp(viewup[i]) center_cam = streamline_actor.GetCenter() camera.SetPosition(center[i]) - camera.SetFocalPoint((center_cam)) + camera.SetFocalPoint(center_cam) img2 = renderer_to_arr(ren, size) - if image.size == 0: - image = img2 - else: - image = np.hstack((image, img2)) + image = img2 if image.size == 0 else np.hstack((image, img2)) - streamlines = [] - it = 0 + streamlines, it = [], 0 for streamline in sft.streamlines: if it > 10000: break + it += 1 streamlines.append(streamline) - ren = window.Scene() streamline_actor = actor.line(streamlines, linewidth=0.2) + center = streamline_actor.GetCenter() + + ren = window.Scene() ren.add(streamline_actor) + camera = ren.GetActiveCamera() camera.SetViewUp(0, 0, -1) - center = streamline_actor.GetCenter() - camera.SetPosition(center[0], 350 + (1 - t1.header.get_zooms()[1]) * 350, center[2]) + camera.SetPosition( + center[0], 350 + (1 - t1.header.get_zooms()[1]) * 350, center[2] + ) camera.SetFocalPoint(center) + img2 = renderer_to_arr(ren, (3 * 1920, 1920)) image = np.vstack((image, img2)) imgs_comb = Image.fromarray(image) imgs_comb = imgs_comb.resize((3 * 1920, 1920 + 1080)) image_name = 
os.path.basename(str(tracking)).split(".")[0] - name = os.path.join(directory, image_name + '.png') + name = os.path.join(directory, image_name + ".png") imgs_comb.save(name) return name -def plot_proj_shell(ms, centroids, use_sym=True, use_sphere=True, - same_color=False, - rad=0.025, opacity=1.0, ofile=None, ores=(300, 300)): +def plot_proj_shell( + ms, + centroids, + use_sym=True, + use_sphere=True, + same_color=False, + rad=0.025, + opacity=1.0, + ofile=None, + ores=(300, 300), +): """ Plot each shell @@ -558,36 +665,51 @@ def plot_proj_shell(ms, centroids, use_sym=True, use_sphere=True, ------ """ global vtkcolors + if len(ms) > 10: vtkcolors = fury.colormap.distinguishable_colormap(nb_colors=len(ms)) - radius = np.interp(centroids, xp=[min(centroids), max(centroids)], - fp=[0, 1]) + + radius = np.interp( + centroids, xp=[min(centroids), max(centroids)], fp=[0, 1] + ) + ren = window.Scene() ren.SetBackground(1, 1, 1) if use_sphere: - sphere = get_sphere('symmetric724') + sphere = get_sphere("symmetric724") shape = (1, 1, 1, sphere.vertices.shape[0]) - fid, fname = mkstemp(suffix='_odf_slicer.mmap') - odfs = np.memmap(fname, dtype=np.float64, mode='w+', shape=shape) + fid, fname = mkstemp(suffix="_odf_slicer.mmap") + odfs = np.memmap(fname, dtype=np.float64, mode="w+", shape=shape) odfs[:] = 1 odfs[..., 0] = 1 affine = np.eye(4) + for i, shell in enumerate(ms): - sphere_actor = actor.odf_slicer(odfs, affine, sphere=sphere, - colormap='winter', scale=radius[i], - opacity=opacity) + sphere_actor = actor.odf_slicer( + odfs, + affine, + sphere=sphere, + colormap="winter", + scale=radius[i], + opacity=opacity, + ) ren.add(sphere_actor) for i, shell in enumerate(ms): if same_color: i = 0 - pts_actor = actor.point(shell * radius[i], vtkcolors[i], - point_radius=rad) + + pts_actor = actor.point( + shell * radius[i], vtkcolors[i], point_radius=rad + ) ren.add(pts_actor) + if use_sym: - pts_actor = actor.point(-shell * radius[i], vtkcolors[i], - point_radius=rad) + 
pts_actor = actor.point( + -shell * radius[i], vtkcolors[i], point_radius=rad + ) ren.add(pts_actor) + if ofile: - window.snapshot(ren, fname=ofile + '.png', size=ores) + window.snapshot(ren, fname=ofile + ".png", size=ores) diff --git a/dmriqcpy/viz/utils.py b/dmriqcpy/viz/utils.py index b4fc946..30c01bd 100644 --- a/dmriqcpy/viz/utils.py +++ b/dmriqcpy/viz/utils.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- - import fury import numpy as np import vtk +from plotly.graph_objs import Figure +import plotly.offline as off from vtk.util import numpy_support """ @@ -32,11 +33,13 @@ def analyse_qa(stats_per_subjects, stats_across_subjects, column_names): warning = {} for metric in column_names: warning[metric] = [] - std = stats_across_subjects.at['std', metric] - mean = stats_across_subjects.at['mean', metric] + std = stats_across_subjects.at["std", metric] + mean = stats_across_subjects.at["mean", metric] for name in stats_per_subjects.index: - if stats_per_subjects.at[name, metric] > mean + 2 * std or\ - stats_per_subjects.at[name, metric] < mean - 2 * std: + if ( + stats_per_subjects.at[name, metric] > mean + 2 * std + or stats_per_subjects.at[name, metric] < mean - 2 * std + ): warning[metric].append(name) return warning @@ -49,19 +52,39 @@ def dataframe_to_html(data_frame, index=True): ---------- data_frame : DataFrame DataFrame. - + index : bool + Print rows index labels Returns ------- data_frame_html : string HTML table. """ data_frame_html = data_frame.to_html(index=index).replace( - '', - '
') + '
', + '
', + ) return data_frame_html +def graph_to_html( + data, title, range_yaxis=None, width=500, height=500, include_plotlyjs=False +): + fig = Figure(data=data) + fig["layout"].update(title=title) + fig["layout"].update(width=width, height=height) + + if range_yaxis is not None: + fig["layout"]["yaxis"].update(range=range_yaxis) + + div = off.plot( + fig, + show_link=False, + include_plotlyjs=include_plotlyjs, + output_type="div", + ) + return div.replace("
", '
') + + def renderer_to_arr(ren, size): """ Convert DataFrame to HTML table. @@ -120,21 +143,26 @@ def renderer_to_arr(ren, size): def compute_labels_map(lut_fname, unique_vals, compute_lut): labels = {} if compute_lut: - labels[0] = np.array((0,0,0),dtype=np.int8) - vtkcolors = fury.colormap.distinguishable_colormap(nb_colors=len(unique_vals)) + labels[0] = np.array((0, 0, 0), dtype=np.int8) + vtkcolors = fury.colormap.distinguishable_colormap( + nb_colors=len(unique_vals) + ) for index, curr_label in enumerate(unique_vals[1:]): - labels[curr_label] = np.array((vtkcolors[index][0]*255, - vtkcolors[index][1]*255, - vtkcolors[index][2]*255), - dtype=np.int8) + labels[curr_label] = np.array( + ( + vtkcolors[index][0] * 255, + vtkcolors[index][1] * 255, + vtkcolors[index][2] * 255, + ), + dtype=np.int8, + ) else: with open(lut_fname) as f: for line in f: - tokens = ' '.join(line.split()).split() - if tokens and not tokens[0].startswith('#'): - labels[np.int(tokens[0])] = np.array((tokens[2], - tokens[3], - tokens[4]), - dtype=np.int8) + tokens = " ".join(line.split()).split() + if tokens and not tokens[0].startswith("#"): + labels[np.int(tokens[0])] = np.array( + (tokens[2], tokens[3], tokens[4]), dtype=np.int8 + ) return labels diff --git a/scripts/dmriqc_brain_extraction.py b/scripts/dmriqc_brain_extraction.py index f250425..3022406 100755 --- a/scripts/dmriqc_brain_extraction.py +++ b/scripts/dmriqc_brain_extraction.py @@ -2,21 +2,29 @@ # -*- coding: utf-8 -*- import argparse -import os -import shutil +from functools import partial -import itertools -from multiprocessing import Pool import numpy as np from dmriqcpy.io.report import Report -from dmriqcpy.io.utils import (add_online_arg, add_overwrite_arg, - assert_inputs_exist, assert_outputs_exist, - list_files_from_paths) -from dmriqcpy.analysis.stats import stats_mean_median -from dmriqcpy.viz.graph import graph_mean_median -from dmriqcpy.viz.screenshot import screenshot_mosaic_blend -from dmriqcpy.viz.utils 
import analyse_qa, dataframe_to_html +from dmriqcpy.io.utils import ( + add_online_arg, + add_overwrite_arg, + add_nb_columns_arg, + add_nb_threads_arg, + add_skip_arg, + assert_inputs_exist, + assert_list_arguments_equal_size, + assert_outputs_exist, + clean_output_directories, + list_files_from_paths, +) +from dmriqcpy.reporting.report import ( + generate_metric_reports_parallel, + generate_report_package, + get_generic_qa_stats_and_graph, +) +from dmriqcpy.viz.utils import dataframe_to_html DESCRIPTION = """ @@ -25,125 +33,82 @@ def _build_arg_parser(): - p = argparse.ArgumentParser(description=DESCRIPTION, - formatter_class=argparse.RawTextHelpFormatter) - p.add_argument('image_type', - help='Type of image (e.g. B0).') - - p.add_argument('output_report', - help='Filename of QC report (in html format).') - - p.add_argument('--no_bet', nargs='+', required=True, - help='A folder or a list of images with the skull in' - ' Nifti format.') - p.add_argument('--bet_mask', nargs='+', required=True, - help='Folder or a list of images of brain extraction masks' - ' in Nifti format.') - - p.add_argument('--skip', default=2, type=int, - help='Number of images skipped to build the ' - 'mosaic. [%(default)s]') - - p.add_argument('--nb_columns', default=12, type=int, - help='Number of columns for the mosaic. [%(default)s]') - - p.add_argument('--nb_threads', type=int, default=1, - help='Number of threads. [%(default)s]') - + p = argparse.ArgumentParser( + description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter + ) + + p.add_argument("image_type", help="Type of image (e.g. 
B0).") + p.add_argument("output_report", help="Filename of QC report (in html format).") + p.add_argument( + "--no_bet", + nargs="+", + required=True, + help="A folder or a list of images with the skull in Nifti format.", + ) + p.add_argument( + "--bet_mask", + nargs="+", + required=True, + help="A folder or a list of images of brain extraction masks in Nifti format.", + ) + + add_skip_arg(p) + add_nb_columns_arg(p) + add_nb_threads_arg(p) add_online_arg(p) add_overwrite_arg(p) return p -def _subj_parralel(images_no_bet, images_bet_mask, name, skip, - summary, nb_columns): - subjects_dict = {} - for subj_metric, mask in zip(images_no_bet, images_bet_mask): - curr_key = os.path.basename(subj_metric).split('.')[0] - screenshot_path = screenshot_mosaic_blend(subj_metric, mask, - output_prefix=name, - directory="data", - blend_val=0.3, - skip=skip, - nb_columns=nb_columns, - is_mask=True) - - summary_html = dataframe_to_html(summary.loc[curr_key].to_frame()) - subjects_dict[curr_key] = {} - subjects_dict[curr_key]['screenshot'] = screenshot_path - subjects_dict[curr_key]['stats'] = summary_html - return subjects_dict - - def main(): parser = _build_arg_parser() args = parser.parse_args() images_no_bet = list_files_from_paths(args.no_bet) images_bet_mask = list_files_from_paths(args.bet_mask) - - if not len(images_no_bet) == len(images_bet_mask): - parser.error("Not the same number of images in input.") + assert_list_arguments_equal_size(parser, images_no_bet, images_bet_mask) all_images = np.concatenate([images_no_bet, images_bet_mask]) assert_inputs_exist(parser, all_images) assert_outputs_exist(parser, args, [args.output_report, "data", "libs"]) + clean_output_directories() - if os.path.exists("data"): - shutil.rmtree("data") - os.makedirs("data") - - if os.path.exists("libs"): - shutil.rmtree("libs") - - metrics = images_no_bet - name = args.image_type - curr_metrics = ['Mean {}'.format(name), - 'Median {}'.format(name)] - - summary, stats = 
stats_mean_median(curr_metrics, metrics) - - warning_dict = {} - warning_dict[name] = analyse_qa(summary, stats, curr_metrics) - warning_images = [filenames for filenames in warning_dict[name].values()] - warning_list = np.concatenate([warning_images]) - warning_dict[name]['nb_warnings'] = len(np.unique(warning_list)) - - graphs = [] - graph = graph_mean_median('Mean {}'.format(name), curr_metrics, summary, - args.online) - graphs.append(graph) - - stats_html = dataframe_to_html(stats) - summary_dict = {} - summary_dict[name] = stats_html - - pool = Pool(args.nb_threads) - subjects_dict_pool = pool.starmap(_subj_parralel, - zip(np.array_split(np.array(images_no_bet), args.nb_threads), - np.array_split(np.array(images_bet_mask), args.nb_threads), - itertools.repeat(name), itertools.repeat(args.skip), - itertools.repeat(summary), itertools.repeat(args.nb_columns))) - - pool.close() - pool.join() - - metrics_dict = {} - subjects_dict = {} - for dict_sub in subjects_dict_pool: - for key in dict_sub: - subjects_dict[key] = dict_sub[key] - metrics_dict[name] = subjects_dict - + metrics, name = images_no_bet, args.image_type nb_subjects = len(images_no_bet) - report = Report(args.output_report) - report.generate(title="Quality Assurance BET " + args.image_type, - nb_subjects=nb_subjects, summary_dict=summary_dict, - graph_array=graphs, metrics_dict=metrics_dict, - warning_dict=warning_dict, - online=args.online) + summary, stats, qa_report, qa_graphs = get_generic_qa_stats_and_graph( + metrics, name, args.online + ) + warning_dict = {name: qa_report} + summary_dict = {name: dataframe_to_html(stats)} + + metrics_dict = { + name: generate_metric_reports_parallel( + zip(images_no_bet, images_bet_mask), + args.nb_threads, + nb_subjects // args.nb_threads, + report_package_generation_fn=partial( + generate_report_package, + stats_summary=summary, + skip=args.skip, + nb_columns=args.nb_columns, + blend_is_mask=True, + ), + ) + } -if __name__ == '__main__': + report = 
Report(args.output_report) + report.generate( + title="Quality Assurance BET " + args.image_type, + nb_subjects=nb_subjects, + summary_dict=summary_dict, + graph_array=qa_graphs, + metrics_dict=metrics_dict, + warning_dict=warning_dict, + online=args.online, + ) + + +if __name__ == "__main__": main() diff --git a/scripts/dmriqc_dti.py b/scripts/dmriqc_dti.py index fadfa14..0815f4a 100755 --- a/scripts/dmriqc_dti.py +++ b/scripts/dmriqc_dti.py @@ -2,22 +2,32 @@ # -*- coding: utf-8 -*- import argparse -import itertools -from multiprocessing import Pool +from functools import partial import os -import shutil import numpy as np -from dmriqcpy.analysis.stats import stats_mean_in_tissues from dmriqcpy.io.report import Report -from dmriqcpy.io.utils import (add_online_arg, add_overwrite_arg, - assert_inputs_exist, assert_outputs_exist, - list_files_from_paths) -from dmriqcpy.viz.graph import graph_mean_in_tissues -from dmriqcpy.viz.screenshot import (screenshot_fa_peaks, - screenshot_mosaic_wrapper) -from dmriqcpy.viz.utils import analyse_qa, dataframe_to_html +from dmriqcpy.io.utils import ( + add_online_arg, + add_overwrite_arg, + add_nb_columns_arg, + add_nb_threads_arg, + add_skip_arg, + assert_inputs_exist, + assert_list_arguments_equal_size, + assert_outputs_exist, + clean_output_directories, + list_files_from_paths, +) +from dmriqcpy.reporting.report import ( + generate_report_package, + generate_metric_reports_parallel, + get_generic_qa_stats_and_graph, +) +from dmriqcpy.viz.screenshot import screenshot_fa_peaks +from dmriqcpy.viz.utils import dataframe_to_html + DESCRIPTION = """ Compute the DTI report in HTML format. 
@@ -25,167 +35,137 @@ def _build_arg_parser(): - p = argparse.ArgumentParser(description=DESCRIPTION, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('output_report', - help='HTML report') - - p.add_argument('--fa', nargs='+', required=True, - help='Folder or FA images in Nifti format.') - - p.add_argument('--md', nargs='+', required=True, - help='Folder of MD images in Nifti format.') - - p.add_argument('--rd', nargs='+', required=True, - help='Folder or RD images in Nifti format.') - - p.add_argument('--ad', nargs='+', required=True, - help='Folder or AD images in Nifti format.') - - p.add_argument('--residual', nargs='+', required=True, - help='Folder or residual images in Nifti format.') - - p.add_argument('--evecs_v1', nargs='+', required=True, - help='Folder or evecs v1 images in Nifti format.') - - p.add_argument('--wm', nargs='+', required=True, - help='Folder or WM mask in Nifti format.') - - p.add_argument('--gm', nargs='+', required=True, - help='Folder or GM mask in Nifti format.') - - p.add_argument('--csf', nargs='+', required=True, - help='Folder or CSF mask in Nifti format.') - - p.add_argument('--skip', default=2, type=int, - help='Number of images skipped to build the ' - 'mosaic. [%(default)s]') - - p.add_argument('--nb_columns', default=12, type=int, - help='Number of columns for the mosaic. [%(default)s]') - - p.add_argument('--nb_threads', type=int, default=1, - help='Number of threads. [%(default)s]') - + p = argparse.ArgumentParser( + description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter + ) + + p.add_argument("output_report", help="Filename of QC report (in html format).") + + p.add_argument( + "--fa", nargs="+", required=True, help="Folder or list of FA images in Nifti format." + ) + + p.add_argument( + "--md", nargs="+", required=True, help="Folder or list of MD images in Nifti format." + ) + + p.add_argument( + "--rd", nargs="+", required=True, help="Folder or list of RD images in Nifti format." 
+ ) + + p.add_argument( + "--ad", nargs="+", required=True, help="Folder or list of AD images in Nifti format." + ) + + p.add_argument( + "--residual", + nargs="+", + required=True, + help="Folder or list of residual images in Nifti format.", + ) + p.add_argument( + "--evecs_v1", + nargs="+", + required=True, + help="Folder or list of evecs v1 images in Nifti format.", + ) + + p.add_argument( + "--wm", nargs="+", required=True, help="Folder or list of WM mask in Nifti format." + ) + + p.add_argument( + "--gm", nargs="+", required=True, help="Folder or list of GM mask in Nifti format." + ) + + p.add_argument( + "--csf", nargs="+", required=True, help="Folder or list of CSF mask in Nifti format." + ) + + add_skip_arg(p) + add_nb_columns_arg(p) + add_nb_threads_arg(p) add_online_arg(p) add_overwrite_arg(p) return p -def _subj_parralel(subj_metric, summary, name, skip, nb_columns): - subjects_dict = {} - curr_key = os.path.basename(subj_metric).split('.')[0] - cmap = None - if name == "Residual": - cmap = "hot" - screenshot_path = screenshot_mosaic_wrapper(subj_metric, - output_prefix=name, - directory="data", skip=skip, - nb_columns=nb_columns, - cmap=cmap) - - summary_html = dataframe_to_html(summary.loc[curr_key].to_frame()) - subjects_dict[curr_key] = {} - subjects_dict[curr_key]['screenshot'] = screenshot_path - subjects_dict[curr_key]['stats'] = summary_html - return subjects_dict - - def main(): parser = _build_arg_parser() args = parser.parse_args() - fa = list_files_from_paths(args.fa) - md = list_files_from_paths(args.md) - rd = list_files_from_paths(args.rd) - ad = list_files_from_paths(args.ad) - residual = list_files_from_paths(args.residual) - evecs_v1 = list_files_from_paths(args.evecs_v1) - wm = list_files_from_paths(args.wm) - gm = list_files_from_paths(args.gm) - csf = list_files_from_paths(args.csf) - - if not len(fa) == len(md) == len(rd) == len(ad) == \ - len(residual) == len(evecs_v1) == len(wm) == len(gm) == len(csf): - parser.error("Not the same 
number of images in input.") - - all_images = np.concatenate([fa, md, rd, ad, residual, evecs_v1, wm, - gm, csf]) + ( + fa, md, rd, ad, residual, evecs_v1, wm, gm, csf + ) = images = [ + list_files_from_paths(args.fa), + list_files_from_paths(args.md), + list_files_from_paths(args.rd), + list_files_from_paths(args.ad), + list_files_from_paths(args.residual), + list_files_from_paths(args.evecs_v1), + list_files_from_paths(args.wm), + list_files_from_paths(args.gm), + list_files_from_paths(args.csf), + ] + + assert_list_arguments_equal_size(parser, *images) + all_images = np.concatenate(images) assert_inputs_exist(parser, all_images) assert_outputs_exist(parser, args, [args.output_report, "data", "libs"]) + clean_output_directories() - if os.path.exists("data"): - shutil.rmtree("data") - os.makedirs("data") - - if os.path.exists("libs"): - shutil.rmtree("libs") - - metrics_names = [[fa, 'FA'], [md, 'MD'], [rd, 'RD'], - [ad, 'AD'], [residual, "Residual"]] metrics_dict = {} summary_dict = {} graphs = [] warning_dict = {} - for metrics, name in metrics_names: - subjects_dict = {} - curr_metrics = ['Mean {} in WM'.format(name), - 'Mean {} in GM'.format(name), - 'Mean {} in CSF'.format(name), - 'Max {} in WM'.format(name)] - - summary, stats = stats_mean_in_tissues(curr_metrics, metrics, wm, - gm, csf) - - warning_dict[name] = analyse_qa(summary, stats, curr_metrics[:3]) - warning_list = np.concatenate( - [filenames for filenames in warning_dict[name].values()]) - warning_dict[name]['nb_warnings'] = len(np.unique(warning_list)) - - graph = graph_mean_in_tissues('Mean {}'.format(name), curr_metrics[:3], - summary, args.online) - graphs.append(graph) - - stats_html = dataframe_to_html(stats) - summary_dict[name] = stats_html - - pool = Pool(args.nb_threads) - subjects_dict_pool = pool.starmap(_subj_parralel, - zip(metrics, - itertools.repeat(summary), - itertools.repeat(name), - itertools.repeat(args.skip), - itertools.repeat( - args.nb_columns))) - - pool.close() - 
pool.join() - - for dict_sub in subjects_dict_pool: - for key in dict_sub: - subjects_dict[key] = dict_sub[key] - metrics_dict[name] = subjects_dict - - subjects_dict = {} - name = "Peaks" - for curr_fa, curr_evecs in zip(fa, evecs_v1): - evecs_filename = os.path.basename(curr_evecs).split('.')[0] - screenshot_path = screenshot_fa_peaks(curr_fa, curr_evecs, "data") - - subjects_dict[evecs_filename] = {} - subjects_dict[evecs_filename]['screenshot'] = screenshot_path - metrics_dict[name] = subjects_dict + for metrics, name in [ + [fa, "FA"], + [md, "MD"], + [rd, "RD"], + [ad, "AD"], + [residual, "Residual"], + ]: + summary, stats, qa_report, qa_graphs = get_generic_qa_stats_and_graph( + metrics, name, args.online + ) + warning_dict[name] = qa_report + summary_dict[name] = dataframe_to_html(stats) + graphs.extend(qa_graphs) + + cmap = "hot" if name == "Residual" else None + metrics_dict[name] = generate_metric_reports_parallel( + zip(metrics), + args.nb_threads, + len(metrics) // args.nb_threads, + report_package_generation_fn=partial( + generate_report_package, + stats_summary=summary, + skip=args.skip, + nb_columns=args.nb_columns, + cmap=cmap, + ), + ) + + metrics_dict["Peaks"] = { + os.path.basename(evecs).split('.')[0]: { + "screenshot": screenshot_fa_peaks(fa, evecs, "data") + } + for fa, evecs in zip(fa, evecs_v1) + } nb_subjects = len(fa) report = Report(args.output_report) - report.generate(title="Quality Assurance DTI metrics", - nb_subjects=nb_subjects, summary_dict=summary_dict, - graph_array=graphs, metrics_dict=metrics_dict, - warning_dict=warning_dict, - online=args.online) - - -if __name__ == '__main__': + report.generate( + title="Quality Assurance DTI metrics", + nb_subjects=nb_subjects, + summary_dict=summary_dict, + graph_array=graphs, + metrics_dict=metrics_dict, + warning_dict=warning_dict, + online=args.online, + ) + + +if __name__ == "__main__": main() diff --git a/scripts/dmriqc_dwi_protocol.py b/scripts/dmriqc_dwi_protocol.py index 
e71855f..723ec3f 100755 --- a/scripts/dmriqc_dwi_protocol.py +++ b/scripts/dmriqc_dwi_protocol.py @@ -3,23 +3,35 @@ import argparse import os -import shutil import numpy as np import pandas as pd -from dmriqcpy.analysis.utils import (dwi_protocol, read_protocol, - identify_shells, - build_ms_from_shell_idx) +from dmriqcpy.analysis.utils import ( + dwi_protocol, + get_bvecs_from_shells_idxs, + identify_shells, + read_protocol, +) from dmriqcpy.io.report import Report -from dmriqcpy.io.utils import (add_online_arg, add_overwrite_arg, - assert_inputs_exist, assert_outputs_exist, - list_files_from_paths) -from dmriqcpy.viz.graph import (graph_directions_per_shells, - graph_dwi_protocol, - graph_subjects_per_shells) +from dmriqcpy.io.utils import ( + add_online_arg, + add_overwrite_arg, + assert_inputs_exist, + assert_outputs_exist, + assert_list_arguments_equal_size, + clean_output_directories, + list_files_from_paths, +) +from dmriqcpy.reporting.report import get_qa_report +from dmriqcpy.viz.graph import ( + graph_directions_per_shells, + graph_dwi_protocol, + graph_subjects_per_shells, +) from dmriqcpy.viz.screenshot import plot_proj_shell -from dmriqcpy.viz.utils import analyse_qa, dataframe_to_html +from dmriqcpy.viz.utils import dataframe_to_html + DESCRIPTION = """ Compute DWI protocol report. 
@@ -27,31 +39,45 @@ def _build_arg_parser(): - p = argparse.ArgumentParser(description=DESCRIPTION, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('output_report', - help='Filename of QC report (in html format).') - - p.add_argument('--bval', nargs='+', required=True, - help='Folder or list of bval files.') - - p.add_argument('--bvec', nargs='+', required=True, - help='Folder or list of bvec files.') - - p.add_argument('--metadata', nargs='+', - help='Folder or list of json files to get the metadata.') - - p.add_argument('--dicom_fields', nargs='+', - default=["EchoTime", "RepetitionTime", "SliceThickness", - "Manufacturer", "ManufacturersModelName"], - help='DICOM fields used to compare information. ' - '%(default)s') - - p.add_argument('--tolerance', '-t', - metavar='INT', type=int, default=20, - help='The tolerated gap between the b-values to ' - 'extract\nand the actual b-values. [%(default)s]') + p = argparse.ArgumentParser( + description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter + ) + + p.add_argument("output_report", help="Filename of QC report (in html format).") + + p.add_argument( + "--bval", nargs="+", required=True, help="Folder or list of bval files." + ) + + p.add_argument( + "--bvec", nargs="+", required=True, help="Folder or list of bvec files." + ) + + p.add_argument( + "--metadata", nargs="+", help="Folder or list of json files to get the metadata." + ) + + p.add_argument( + "--dicom_fields", + nargs="+", + default=[ + "EchoTime", + "RepetitionTime", + "SliceThickness", + "Manufacturer", + "ManufacturersModelName", + ], + help="DICOM fields used to compare information. %(default)s", + ) + p.add_argument( + "--tolerance", + "-t", + metavar="INT", + type=int, + default=20, + help="The tolerated gap between the b-values to extract " + "and the actual b-values. 
[%(default)s]", + ) add_online_arg(p) add_overwrite_arg(p) @@ -63,85 +89,60 @@ def main(): parser = _build_arg_parser() args = parser.parse_args() - if args.metadata: - metadata = list_files_from_paths(args.metadata) - bval = list_files_from_paths(args.bval) bvec = list_files_from_paths(args.bvec) - if not len(bval) == len(bvec): - parser.error("Not the same number of images in input.") + files_to_validate = [bval, bvec] - stats_tags = [] - stats_tags_for_graph = [] + metadata = None if args.metadata: - if not len(metadata) == len(bval): - parser.error('Number of metadata files: {}.\n' - 'Number of bval files: {}.\n' - 'Not the same number of images ' - 'in input'.format(len(metadata), - len(bval))) - else: - stats_tags, stats_tags_for_graph,\ - stats_tags_for_graph_all = read_protocol(metadata, - args.dicom_fields) + metadata = list_files_from_paths(args.metadata) + files_to_validate.append(metadata) + assert_list_arguments_equal_size(parser, *files_to_validate) all_data = np.concatenate([bval, bvec]) assert_inputs_exist(parser, all_data) assert_outputs_exist(parser, args, [args.output_report, "data", "libs"]) + clean_output_directories() - if os.path.exists("data"): - shutil.rmtree("data") - os.makedirs("data") - - if os.path.exists("libs"): - shutil.rmtree("libs") + stats_tags = [] + stats_tags_for_graph = [] + stats_tags_for_graph_all = [] + if args.metadata: + ( + stats_tags, + stats_tags_for_graph, + stats_tags_for_graph_all, + ) = read_protocol(metadata, args.dicom_fields) name = "DWI Protocol" summary, stats_for_graph, stats_all, shells = dwi_protocol(bval) if stats_tags: - for curr_column in stats_tags: - tag = curr_column[0] - curr_df = curr_column[1] - if 'complete_' in tag: + for tag, curr_df in stats_tags: + if "complete_" in tag: metric = curr_df.columns[0] for nSub in curr_df.index: - currKey = [nKey for nKey in summary.keys() if nSub in nKey] - summary[currKey[0]][metric] = curr_df[metric][nSub] + curr_key = [nKey for nKey in summary.keys() if nSub 
in nKey] + summary[curr_key[0]][metric] = curr_df[metric][nSub] if not isinstance(stats_tags_for_graph, list): - stats_for_graph = pd.concat([stats_for_graph, stats_tags_for_graph], - axis=1, join="inner") - stats_all = pd.concat([stats_all, stats_tags_for_graph_all], - axis=1, join="inner") - - warning_dict = {} - warning_dict[name] = analyse_qa(stats_for_graph, stats_all, - stats_all.columns) - warning_images = [filenames for filenames in warning_dict[name].values()] - warning_list = np.concatenate(warning_images) - warning_dict[name]['nb_warnings'] = len(np.unique(warning_list)) - - stats_html = dataframe_to_html(stats_all) - summary_dict = {} - summary_dict[name] = stats_html + stats_for_graph = pd.concat([stats_for_graph, stats_tags_for_graph], axis=1, join="inner") + stats_all = pd.concat([stats_all, stats_tags_for_graph_all], axis=1, join="inner") + + warning_dict = {name: get_qa_report(stats_for_graph, stats_all, stats_all.columns)} + summary_dict = {name: dataframe_to_html(stats_all)} if args.metadata: for curr_tag in stats_tags: - if 'complete_' not in curr_tag[0]: + if "complete_" not in curr_tag[0]: summary_dict[curr_tag[0]] = dataframe_to_html(curr_tag[1]) - graphs = [] - - graphs.append( - graph_directions_per_shells("Nbr directions per shell", - shells, args.online)) - - graphs.append(graph_subjects_per_shells("Nbr subjects per shell", - shells, args.online)) + graphs = [ + graph_directions_per_shells("Nbr directions per shell", shells, not args.online), + graph_subjects_per_shells("Nbr subjects per shell", shells, not args.online), + ] for c in stats_for_graph.keys(): - graph = graph_dwi_protocol(c, c, stats_for_graph, args.online) - graphs.append(graph) + graphs.append(graph_dwi_protocol(c, c, stats_for_graph, not args.online)) subjects_dict = {} for curr_bval, curr_bvec in zip(bval, bvec): @@ -150,35 +151,40 @@ def main(): points = np.genfromtxt(curr_bvec) if points.shape[0] == 3: points = points.T - bvals = np.genfromtxt(curr_bval) - centroids, 
shell_idx = identify_shells(bvals) - ms = build_ms_from_shell_idx(points, shell_idx) - plot_proj_shell(ms, centroids, use_sym=True, use_sphere=True, - same_color=False, rad=0.025, opacity=0.2, - ofile=os.path.join("data", name.replace(" ", "_") + - "_" + curr_subj), - ores=(800, 800)) - subjects_dict[curr_subj]['screenshot'] = os.path.join("data", - name.replace(" ", - "_") + - "_" + - curr_subj + - '.png') - metrics_dict = {} + centroids, shell_idx = identify_shells(np.genfromtxt(curr_bval)) + plot_proj_shell( + get_bvecs_from_shells_idxs(points, shell_idx), + centroids, + opacity=0.2, + ofile=os.path.join( + "data", + name.replace(" ", "_") + "_" + curr_subj + ), + ores=(800, 800), + ) + subjects_dict[curr_subj]["screenshot"] = os.path.join( + "data", + name.replace(" ", "_") + "_" + curr_subj + ".png" + ) + for subj in bval: curr_subj = os.path.basename(subj).split('.')[0] summary_html = dataframe_to_html(summary[subj]) - subjects_dict[curr_subj]['stats'] = summary_html - metrics_dict[name] = subjects_dict + subjects_dict[curr_subj]["stats"] = summary_html + metrics_dict = {name: subjects_dict} nb_subjects = len(bval) report = Report(args.output_report) - report.generate(title="Quality Assurance DWI protocol", - nb_subjects=nb_subjects, metrics_dict=metrics_dict, - summary_dict=summary_dict, graph_array=graphs, - warning_dict=warning_dict, - online=args.online) - - -if __name__ == '__main__': + report.generate( + title="Quality Assurance DWI protocol", + nb_subjects=nb_subjects, + metrics_dict=metrics_dict, + summary_dict=summary_dict, + graph_array=graphs, + warning_dict=warning_dict, + online=args.online, + ) + + +if __name__ == "__main__": main() diff --git a/scripts/dmriqc_fodf.py b/scripts/dmriqc_fodf.py index 3ff3ce6..cf0eb69 100755 --- a/scripts/dmriqc_fodf.py +++ b/scripts/dmriqc_fodf.py @@ -2,21 +2,29 @@ # -*- coding: utf-8 -*- import argparse -import os -import shutil -import itertools -from multiprocessing import Pool +from functools import partial 
import numpy as np -from dmriqcpy.analysis.stats import stats_mean_in_tissues from dmriqcpy.io.report import Report -from dmriqcpy.io.utils import (add_online_arg, add_overwrite_arg, - assert_inputs_exist, assert_outputs_exist, - list_files_from_paths) -from dmriqcpy.viz.graph import graph_mean_in_tissues -from dmriqcpy.viz.screenshot import screenshot_mosaic_wrapper -from dmriqcpy.viz.utils import analyse_qa, dataframe_to_html +from dmriqcpy.io.utils import ( + add_online_arg, + add_overwrite_arg, + add_nb_columns_arg, + add_nb_threads_arg, + add_skip_arg, + assert_inputs_exist, + assert_list_arguments_equal_size, + assert_outputs_exist, + clean_output_directories, + list_files_from_paths, +) +from dmriqcpy.reporting.report import ( + generate_metric_reports_parallel, + generate_report_package, + get_qa_stats_and_graph_in_tissues, +) +from dmriqcpy.viz.utils import dataframe_to_html DESCRIPTION = """ @@ -25,142 +33,126 @@ def _build_arg_parser(): - p = argparse.ArgumentParser(description=DESCRIPTION, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('output_report', - help='HTML report') - - p.add_argument('--afd_max', nargs='+', required=True, - help='Folder or list of AFD max images in Nifti format.') - - p.add_argument('--afd_sum', nargs='+', required=True, - help='Folder or list of AFD sum images in Nifti format.') - - p.add_argument('--afd_total', nargs='+', required=True, - help='Folder or list of AFD total images in Nifti format.') - - p.add_argument('--nufo', nargs='+', required=True, - help='Folder or list of NUFO max images in Nifti format.') - - p.add_argument('--wm', nargs='+', required=True, - help='Folder or list of WM mask in Nifti format.') - - p.add_argument('--gm', nargs='+', required=True, - help='Folder or list of GM mask in Nifti format.') - - p.add_argument('--csf', nargs='+', required=True, - help='Folder or list of CSF mask in Nifti format.') - - p.add_argument('--skip', default=2, type=int, - help='Number of images 
skipped to build the ' - 'mosaic. [%(default)s]') - - p.add_argument('--nb_columns', default=12, type=int, - help='Number of columns for the mosaic. [%(default)s]') - - p.add_argument('--nb_threads', type=int, default=1, - help='Number of threads. [%(default)s]') - + p = argparse.ArgumentParser( + description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter + ) + + p.add_argument("output_report", help="Filename of QC report (in html format).") + p.add_argument( + "--afd_max", + nargs="+", + required=True, + help="Folder or list of AFD max images in Nifti format.", + ) + p.add_argument( + "--afd_sum", + nargs="+", + required=True, + help="Folder or list of AFD sum images in Nifti format.", + ) + p.add_argument( + "--afd_total", + nargs="+", + required=True, + help="Folder or list of AFD total images in Nifti format.", + ) + p.add_argument( + "--nufo", + nargs="+", + required=True, + help="Folder or list of NUFO max images in Nifti format.", + ) + + p.add_argument( + "--wm", nargs="+", required=True, help="Folder or list of WM mask in Nifti format." + ) + + p.add_argument( + "--gm", nargs="+", required=True, help="Folder or list of GM mask in Nifti format." + ) + + p.add_argument( + "--csf", nargs="+", required=True, help="Folder or list of CSF mask in Nifti format." 
+ ) + + add_skip_arg(p) + add_nb_columns_arg(p) + add_nb_threads_arg(p) add_online_arg(p) add_overwrite_arg(p) return p -def _subj_parralel(subj_metric, summary, name, skip, nb_columns): - subjects_dict = {} - curr_key = os.path.basename(subj_metric).split('.')[0] - screenshot_path = screenshot_mosaic_wrapper(subj_metric, - output_prefix=name, - directory="data", skip=skip, - nb_columns=nb_columns) - - summary_html = dataframe_to_html(summary.loc[curr_key].to_frame()) - subjects_dict[curr_key] = {} - subjects_dict[curr_key]['screenshot'] = screenshot_path - subjects_dict[curr_key]['stats'] = summary_html - return subjects_dict - - def main(): parser = _build_arg_parser() args = parser.parse_args() - afd_max = list_files_from_paths(args.afd_max) - afd_sum = list_files_from_paths(args.afd_sum) - afd_total = list_files_from_paths(args.afd_total) - nufo = list_files_from_paths(args.nufo) - wm = list_files_from_paths(args.wm) - gm = list_files_from_paths(args.gm) - csf = list_files_from_paths(args.csf) - - if not len(afd_max) == len(afd_sum) == len(afd_total) ==\ - len(nufo) == len(wm) == len(gm) == len(csf): - parser.error("Not the same number of images in input.") - - all_images = np.concatenate([afd_max, afd_sum, afd_total, - nufo, wm, gm, csf]) + ( + afd_max, + afd_sum, + afd_total, + nufo, + wm, + gm, + csf + ) = images = [ + list_files_from_paths(args.afd_max), + list_files_from_paths(args.afd_sum), + list_files_from_paths(args.afd_total), + list_files_from_paths(args.nufo), + list_files_from_paths(args.wm), + list_files_from_paths(args.gm), + list_files_from_paths(args.csf) + ] + + assert_list_arguments_equal_size(parser, *images) + all_images = np.concatenate(images) assert_inputs_exist(parser, all_images) assert_outputs_exist(parser, args, [args.output_report, "data", "libs"]) + clean_output_directories() - if os.path.exists("data"): - shutil.rmtree("data") - os.makedirs("data") - - if os.path.exists("libs"): - shutil.rmtree("libs") - - metrics_names = 
[[afd_max, 'AFD_max'], [afd_sum, 'AFD_sum'], - [afd_total, 'AFD_total'], [nufo, 'NUFO']] metrics_dict = {} summary_dict = {} graphs = [] warning_dict = {} - for metrics, name in metrics_names: - subjects_dict = {} - curr_metrics = ['Mean {} in WM'.format(name), - 'Mean {} in GM'.format(name), - 'Mean {} in CSF'.format(name), - 'Max {} in WM'.format(name)] - - summary, stats = stats_mean_in_tissues(curr_metrics, metrics, wm, - gm, csf) - warning_dict[name] = analyse_qa(summary, stats, curr_metrics[:3]) - warning_list = np.concatenate([filenames for filenames in warning_dict[name].values()]) - warning_dict[name]['nb_warnings'] = len(np.unique(warning_list)) - - graph = graph_mean_in_tissues('Mean {}'.format(name), curr_metrics[:3], - summary, args.online) - graphs.append(graph) - - stats_html = dataframe_to_html(stats) - summary_dict[name] = stats_html - pool = Pool(args.nb_threads) - subjects_dict_pool = pool.starmap(_subj_parralel, - zip(metrics, - itertools.repeat(summary), - itertools.repeat(name), - itertools.repeat(args.skip), - itertools.repeat(args.nb_columns))) - - pool.close() - pool.join() - - for dict_sub in subjects_dict_pool: - for key in dict_sub: - curr_key = os.path.basename(key).split('.')[0] - subjects_dict[curr_key] = dict_sub[curr_key] - metrics_dict[name] = subjects_dict + for metrics, name in [ + [afd_max, "AFD_max"], + [afd_sum, "AFD_sum"], + [afd_total, "AFD_total"], + [nufo, "NUFO"], + ]: + summary, stats, qa_report, qa_graphs = get_qa_stats_and_graph_in_tissues( + metrics, name, wm, gm, csf, args.online + ) + warning_dict[name] = qa_report + summary_dict[name] = dataframe_to_html(stats) + graphs.extend(qa_graphs) + + metrics_dict[name] = generate_metric_reports_parallel( + zip(metrics), + args.nb_threads, + len(metrics) // args.nb_threads, + report_package_generation_fn=partial( + generate_report_package, + stats_summary=summary, + skip=args.skip, + nb_columns=args.nb_columns, + ), + ) nb_subjects = len(afd_max) report = 
Report(args.output_report) - report.generate(title="Quality Assurance FODF metrics", - nb_subjects=nb_subjects, summary_dict=summary_dict, - graph_array=graphs, metrics_dict=metrics_dict, - warning_dict=warning_dict, - online=args.online) - - -if __name__ == '__main__': + report.generate( + title="Quality Assurance FODF metrics", + nb_subjects=nb_subjects, + summary_dict=summary_dict, + graph_array=graphs, + metrics_dict=metrics_dict, + warning_dict=warning_dict, + online=args.online, + ) + + +if __name__ == "__main__": main() diff --git a/scripts/dmriqc_frf.py b/scripts/dmriqc_frf.py index 7101200..7bf13a4 100755 --- a/scripts/dmriqc_frf.py +++ b/scripts/dmriqc_frf.py @@ -5,15 +5,17 @@ import os import shutil -import numpy as np - -from dmriqcpy.analysis.stats import stats_frf from dmriqcpy.io.report import Report -from dmriqcpy.io.utils import (add_online_arg, add_overwrite_arg, - assert_inputs_exist, assert_outputs_exist, - list_files_from_paths) -from dmriqcpy.viz.graph import graph_frf_eigen, graph_frf_b0 -from dmriqcpy.viz.utils import analyse_qa, dataframe_to_html +from dmriqcpy.io.utils import ( + add_online_arg, + add_overwrite_arg, + assert_inputs_exist, + assert_outputs_exist, + clean_output_directories, + list_files_from_paths, +) +from dmriqcpy.reporting.report import get_frf_qa_stats_and_graph +from dmriqcpy.viz.utils import dataframe_to_html DESCRIPTION = """ @@ -22,15 +24,16 @@ def _build_arg_parser(): - p = argparse.ArgumentParser(description=DESCRIPTION, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('frf', nargs='+', - help='Folder or list of fiber response function (frf) ' - 'files (in txt format).') + p = argparse.ArgumentParser( + description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter + ) - p.add_argument('output_report', - help='Filename of QC report (in html format).') + p.add_argument( + "frf", + nargs="+", + help="Folder or list of fiber response function (frf) files (in txt format).", + ) + 
p.add_argument("output_report", help="Filename of QC report (in html format).") add_online_arg(p) add_overwrite_arg(p) @@ -46,46 +49,33 @@ def main(): assert_inputs_exist(parser, frf) assert_outputs_exist(parser, args, [args.output_report, "libs"]) - - if os.path.exists("libs"): - shutil.rmtree("libs") + clean_output_directories(False) name = "FRF" - metrics_names = ["Mean Eigen value 1", "Mean Eigen value 2", "Mean B0"] - - warning_dict = {} - summary, stats = stats_frf(metrics_names, frf) - warning_dict[name] = analyse_qa(summary, stats, metrics_names) - warning_list = np.concatenate([filenames for filenames in warning_dict[name].values()]) - warning_dict[name]['nb_warnings'] = len(set(warning_list)) - - graphs = [] - graphs.append(graph_frf_eigen("EigenValues", metrics_names, summary, - args.online)) - graphs.append(graph_frf_b0("Mean B0", metrics_names, summary, args.online)) - + nb_subjects = len(frf) - summary_dict = {} - stats_html = dataframe_to_html(stats) - summary_dict[name] = stats_html + summary, stats, qa_report, qa_graphs = get_frf_qa_stats_and_graph(frf, args.online) + warning_dict = {name: qa_report} + summary_dict = {name: dataframe_to_html(stats)} - metrics_dict = {} - subjects_dict = {} - for subj_metric in frf: - curr_subj = os.path.basename(subj_metric).split('.')[0] - summary_html = dataframe_to_html(summary.loc[curr_subj].to_frame()) - subjects_dict[curr_subj] = {} - subjects_dict[curr_subj]['stats'] = summary_html - metrics_dict[name] = subjects_dict + metrics_dict = { + name: { + subj_metric: {"stats": dataframe_to_html(summary.loc[subj_metric])} + for subj_metric in frf + } + } - nb_subjects = len(frf) report = Report(args.output_report) - report.generate(title="Quality Assurance FRF", - nb_subjects=nb_subjects, summary_dict=summary_dict, - graph_array=graphs, metrics_dict=metrics_dict, - warning_dict=warning_dict, - online=args.online) - - -if __name__ == '__main__': + report.generate( + title="Quality Assurance FRF", + 
nb_subjects=nb_subjects, + summary_dict=summary_dict, + graph_array=qa_graphs, + metrics_dict=metrics_dict, + warning_dict=warning_dict, + online=args.online, + ) + + +if __name__ == "__main__": main() diff --git a/scripts/dmriqc_from_screenshot.py b/scripts/dmriqc_from_screenshot.py index 20b0874..baac1bb 100755 --- a/scripts/dmriqc_from_screenshot.py +++ b/scripts/dmriqc_from_screenshot.py @@ -8,8 +8,13 @@ import shutil from dmriqcpy.io.report import Report -from dmriqcpy.io.utils import (add_online_arg, add_overwrite_arg, - assert_inputs_exist, assert_outputs_exist) +from dmriqcpy.io.utils import ( + add_online_arg, + add_overwrite_arg, + assert_inputs_exist, + assert_outputs_exist, + clean_output_directories, +) from dmriqcpy.viz.utils import dataframe_to_html DESCRIPTION = """ @@ -18,20 +23,24 @@ def _build_arg_parser(): - p = argparse.ArgumentParser(description=DESCRIPTION, - formatter_class=argparse.RawTextHelpFormatter) + p = argparse.ArgumentParser( + description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter + ) - p.add_argument('output_report', - help='HTML report') + p.add_argument("output_report", help="Filename of QC report (in html format).") - p.add_argument('--data', nargs='+', - help='Screenshot and stats (optional) folders.') + p.add_argument( + "--data", + nargs="+", + required=True, + help="Screenshot and stats (optional) folders." + ) - p.add_argument('--stats', action="store_true", - help='Use included csv files.') + p.add_argument( + "--stats", action="store_true", help="Use included csv files." 
+ ) - p.add_argument('--sym_link', action="store_true", - help='Use symlink instead of copy') + p.add_argument("--sym_link", action="store_true", help="Use symlink instead of copy.") add_online_arg(p) add_overwrite_arg(p) @@ -45,56 +54,60 @@ def main(): assert_inputs_exist(parser, args.data, are_directories=True) assert_outputs_exist(parser, args, [args.output_report, "data", "libs"]) + clean_output_directories() nb_subjects = len(os.listdir(args.data[0])) for folder in args.data[1:]: nb_subjects += len(os.listdir(folder)) - if os.path.exists("data"): - shutil.rmtree("data") - os.makedirs("data") - - if os.path.exists("libs"): - shutil.rmtree("libs") - metrics_dict = {} types = "" for folder in args.data: screenshot_files = [] stats_files = [] - for ext in ["png","jpeg","jpg"]: + for ext in ["png", "jpeg", "jpg"]: screenshot_files = screenshot_files + sorted(glob.glob(folder + '/*' + ext)) if args.stats: stats_files = sorted(glob.glob(folder + '/*.csv')) if len(screenshot_files) != len(stats_files): parser.error("Not same number of stats and screenshots") - name = os.path.basename(os.path.normpath(folder)) subjects_dict = {} for index, curr_screenshot in enumerate(screenshot_files): screenshot_basename = os.path.basename(curr_screenshot) if args.sym_link: - os.symlink(os.path.abspath(folder) + "/" + screenshot_basename, - "data/" + screenshot_basename) + os.symlink( + os.path.abspath(folder) + "/" + screenshot_basename, + "data/" + screenshot_basename, + ) else: - shutil.copyfile(curr_screenshot, - "data/" + screenshot_basename) + shutil.copyfile( + curr_screenshot, "data/" + screenshot_basename + ) subjects_dict[screenshot_basename] = {} - subjects_dict[screenshot_basename]['screenshot'] =\ + subjects_dict[screenshot_basename]['screenshot'] = ( "data/" + screenshot_basename + ) + if args.stats: - subjects_dict[screenshot_basename]['stats'] = dataframe_to_html(pd.read_csv(stats_files[index], index_col=False)) + stats = dataframe_to_html( + 
pd.read_csv(stats_files[index], index_col=False) + ) + subjects_dict[screenshot_basename]['stats'] = stats metrics_dict[name] = subjects_dict types += " {0}".format(name) report = Report(args.output_report) - report.generate(title="Quality Assurance" + types, - nb_subjects=nb_subjects, metrics_dict=metrics_dict, - online=args.online) + report.generate( + title="Quality Assurance" + types, + nb_subjects=nb_subjects, + metrics_dict=metrics_dict, + online=args.online, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/dmriqc_generic.py b/scripts/dmriqc_generic.py index c361c95..a75a733 100755 --- a/scripts/dmriqc_generic.py +++ b/scripts/dmriqc_generic.py @@ -2,21 +2,30 @@ # -*- coding: utf-8 -*- import argparse -import itertools -from multiprocessing import Pool -import os -import shutil +from functools import partial import numpy as np -from dmriqcpy.analysis.stats import stats_mean_in_tissues, stats_mean_median from dmriqcpy.io.report import Report -from dmriqcpy.io.utils import (add_online_arg, add_overwrite_arg, - assert_inputs_exist, assert_outputs_exist, - list_files_from_paths) -from dmriqcpy.viz.graph import graph_mean_in_tissues, graph_mean_median -from dmriqcpy.viz.screenshot import screenshot_mosaic_wrapper -from dmriqcpy.viz.utils import analyse_qa, dataframe_to_html +from dmriqcpy.io.utils import ( + add_online_arg, + add_overwrite_arg, + add_nb_columns_arg, + add_nb_threads_arg, + add_skip_arg, + assert_inputs_exist, + assert_list_arguments_equal_size, + assert_outputs_exist, + clean_output_directories, + list_files_from_paths, +) +from dmriqcpy.reporting.report import ( + generate_metric_reports_parallel, + generate_report_package, + get_generic_qa_stats_and_graph, + get_qa_stats_and_graph_in_tissues, +) +from dmriqcpy.viz.utils import dataframe_to_html DESCRIPTION = """ Compute report in HTML format from images. 
@@ -24,63 +33,54 @@ def _build_arg_parser(): - p = argparse.ArgumentParser(description=DESCRIPTION, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('image_type', - help='Type of image (e.g. B0 resample).') - - p.add_argument('output_report', - help='HTML report.') - - p.add_argument('--images', nargs='+', required=True, - help='Folder or list of images in Nifti format.') - - p.add_argument('--wm', nargs='+', - help='Folder or list of WM mask in Nifti format.') - - p.add_argument('--gm', nargs='+', - help='Folder or list of GM mask in Nifti format') - - p.add_argument('--csf', nargs='+', - help='Folder or list of CSF mask in Nifti format.') - - p.add_argument('--skip', default=2, type=int, - help='Number of images skipped to build the ' - 'mosaic. [%(default)s]') - - p.add_argument('--nb_columns', default=12, type=int, - help='Number of columns for the mosaic. [%(default)s]') - - p.add_argument('--duration', default=100, type=int, - help='Duration of each image in GIF in milliseconds.' - ' [%(default)s]') - - p.add_argument('--nb_threads', type=int, default=1, - help='Number of threads. [%(default)s]') - + p = argparse.ArgumentParser( + description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter + ) + + p.add_argument("image_type", help="Type of image (e.g. B0 resample).") + p.add_argument("output_report", help="Filename of QC report (in html format).") + + p.add_argument( + "--images", + nargs="+", + required=True, + help="Folder or list of images in Nifti format." + ) + + p.add_argument( + "--wm", + nargs="+", + help="Folder or list of WM mask in Nifti format." + ) + + p.add_argument( + "--gm", + nargs="+", + help="Folder or list of GM mask in Nifti format" + ) + + p.add_argument( + "--csf", + nargs="+", + help="Folder or list of CSF mask in Nifti format." + ) + + p.add_argument( + "--duration", + default=100, + type=int, + help="Duration of each image in GIF in milliseconds. 
[%(default)s]", + ) + + add_skip_arg(p) + add_nb_columns_arg(p) + add_nb_threads_arg(p) add_online_arg(p) add_overwrite_arg(p) return p -def _subj_parralel(subj_metric, summary, name, skip, nb_columns, duration): - subjects_dict = {} - curr_key = os.path.basename(subj_metric).split('.')[0] - screenshot_path = screenshot_mosaic_wrapper(subj_metric, - output_prefix=name, - directory="data", skip=skip, - nb_columns=nb_columns, - duration=duration) - - summary_html = dataframe_to_html(summary.loc[curr_key].to_frame()) - subjects_dict[curr_key] = {} - subjects_dict[curr_key]['screenshot'] = screenshot_path - subjects_dict[curr_key]['stats'] = summary_html - return subjects_dict - - def main(): parser = _build_arg_parser() args = parser.parse_args() @@ -93,79 +93,56 @@ def main(): wm = list_files_from_paths(args.wm) gm = list_files_from_paths(args.gm) csf = list_files_from_paths(args.csf) - if not len(images) == len(wm) == len(gm) == len(csf): - parser.error("Not the same number of images in input.") + assert_list_arguments_equal_size(parser, images, wm, gm, csf) with_tissues = True all_images = np.concatenate([images, wm, gm, csf]) assert_inputs_exist(parser, all_images) assert_outputs_exist(parser, args, [args.output_report, "data", "libs"]) - - if os.path.exists("data"): - shutil.rmtree("data") - os.makedirs("data") - - if os.path.exists("libs"): - shutil.rmtree("libs") + clean_output_directories() name = args.image_type + nb_subjects = len(images) if with_tissues: - curr_metrics = ['Mean {} in WM'.format(name), - 'Mean {} in GM'.format(name), - 'Mean {} in CSF'.format(name), - 'Max {} in WM'.format(name)] - summary, stats = stats_mean_in_tissues(curr_metrics, images, - wm, gm, csf) - graph = graph_mean_in_tissues('Mean {}'.format(name), curr_metrics[:3], - summary, args.online) + summary, stats, qa_report, qa_graphs = get_qa_stats_and_graph_in_tissues( + images, name, wm, gm, csf, args.online + ) else: - curr_metrics = ['Mean {}'.format(name), - 'Median 
{}'.format(name)] - summary, stats = stats_mean_median(curr_metrics, images) - graph = graph_mean_median('Mean {}'.format(name), curr_metrics, - summary, args.online) - - warning_dict = {} - warning_dict[name] = analyse_qa(summary, stats, curr_metrics[:3]) - warning_list = np.concatenate( - [filenames for filenames in warning_dict[name].values()]) - warning_dict[name]['nb_warnings'] = len(np.unique(warning_list)) - - graphs = [] - graphs.append(graph) - - stats_html = dataframe_to_html(stats) - summary_dict = {} - summary_dict[name] = stats_html - pool = Pool(args.nb_threads) - subjects_dict_pool = pool.starmap(_subj_parralel, - zip(images, - itertools.repeat(summary), - itertools.repeat(name), - itertools.repeat(args.skip), - itertools.repeat(args.nb_columns), - itertools.repeat(args.duration))) - pool.close() - pool.join() - - metrics_dict = {} - subjects_dict = {} - for dict_sub in subjects_dict_pool: - for key in dict_sub: - curr_key = os.path.basename(key).split('.')[0] - subjects_dict[curr_key] = dict_sub[curr_key] - metrics_dict[name] = subjects_dict + summary, stats, qa_report, qa_graphs = get_generic_qa_stats_and_graph( + images, name, args.online + ) + + warning_dict = {name: qa_report} + summary_dict = {name: dataframe_to_html(stats)} + + metrics_dict = { + name: generate_metric_reports_parallel( + zip(images), + args.nb_threads, + nb_subjects // args.nb_threads, + report_package_generation_fn=partial( + generate_report_package, + stats_summary=summary, + skip=args.skip, + nb_columns=args.nb_columns, + duration=args.duration, + ), + ) + } - nb_subjects = len(images) report = Report(args.output_report) - report.generate(title="Quality Assurance " + args.image_type, - nb_subjects=nb_subjects, summary_dict=summary_dict, - graph_array=graphs, metrics_dict=metrics_dict, - warning_dict=warning_dict, - online=args.online) - - -if __name__ == '__main__': + report.generate( + title="Quality Assurance " + args.image_type, + nb_subjects=nb_subjects, + 
summary_dict=summary_dict, + graph_array=qa_graphs, + metrics_dict=metrics_dict, + warning_dict=warning_dict, + online=args.online, + ) + + +if __name__ == "__main__": main() diff --git a/scripts/dmriqc_labels.py b/scripts/dmriqc_labels.py index 8c9cc52..f1b4e00 100755 --- a/scripts/dmriqc_labels.py +++ b/scripts/dmriqc_labels.py @@ -2,18 +2,26 @@ # -*- coding: utf-8 -*- import argparse -import os -import shutil +from functools import partial -import itertools -from multiprocessing import Pool import numpy as np from dmriqcpy.io.report import Report -from dmriqcpy.io.utils import (add_online_arg, add_overwrite_arg, - assert_inputs_exist, assert_outputs_exist, - list_files_from_paths) -from dmriqcpy.viz.screenshot import screenshot_mosaic_blend +from dmriqcpy.io.utils import ( + add_online_arg, + add_overwrite_arg, + add_nb_columns_arg, + add_nb_threads_arg, + add_skip_arg, + assert_inputs_exist, + assert_outputs_exist, + clean_output_directories, + list_files_from_paths, +) +from dmriqcpy.reporting.report import ( + generate_metric_reports_parallel, + generate_report_package, +) DESCRIPTION = """ @@ -25,60 +33,45 @@ def _build_arg_parser(): - p = argparse.ArgumentParser(description=DESCRIPTION, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('output_report', - help='HTML report.') - - p.add_argument('--t1', nargs='+', required=True, - help='Folder or list of T1 images in Nifti format.') - - p.add_argument('--label', nargs='+', required=True, - help='Folder or list of label images in Nifti format.') - - p.add_argument('--skip', default=2, type=int, - help='Number of images skipped to build the ' - 'mosaic. [%(default)s]') - - p.add_argument('--nb_columns', default=12, type=int, - help='Number of columns for the mosaic. 
[%(default)s]') - - p.add_argument('--lut', nargs=1, default="", - help='Look Up Table for RGB.') - - p.add_argument('--compute_lut', action='store_true', - help='Compute Look Up Table for RGB.') - - p.add_argument('--nb_threads', type=int, default=1, - help='Number of threads. [%(default)s]') - + p = argparse.ArgumentParser( + description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter + ) + + p.add_argument("output_report", help="Filename of QC report (in html format).") + + p.add_argument( + "--t1", + nargs="+", + required=True, + help="Folder or list of T1 images in Nifti format." + ) + + p.add_argument( + "--label", + nargs="+", + required=True, + help="Folder or list of label images in Nifti format." + ) + + p.add_argument( + "--lut", nargs=1, default="", help="Look Up Table for RGB." + ) + + p.add_argument( + "--compute_lut", + action="store_true", + help="Compute Look Up Table for RGB." + ) + + add_skip_arg(p) + add_nb_columns_arg(p) + add_nb_threads_arg(p) add_online_arg(p) add_overwrite_arg(p) return p -def _subj_parralel(t1, label, name, skip, nb_columns, lut, compute_lut): - subjects_dict = {} - if not lut: - lut = None - - screenshot_path = screenshot_mosaic_blend(t1, label, - output_prefix=name, - directory="data", - blend_val=0.4, - skip=skip, nb_columns=nb_columns, - lut=lut, - compute_lut=compute_lut) - - key = os.path.basename(t1).split('.')[0] - - subjects_dict[key] = {} - subjects_dict[key]['screenshot'] = screenshot_path - return subjects_dict - - def main(): parser = _build_arg_parser() args = parser.parse_args() @@ -90,7 +83,7 @@ def main(): parser.error("Not the same number of images in input.") if len(label) == 1: - label = label * len(args.t1) + label = label * len(t1) all_images = np.concatenate([t1, label]) if args.lut: @@ -98,42 +91,34 @@ def main(): assert_inputs_exist(parser, all_images) assert_outputs_exist(parser, args, [args.output_report, "data", "libs"]) - - if os.path.exists("data"): - shutil.rmtree("data") - 
os.makedirs("data") - - if os.path.exists("libs"): - shutil.rmtree("libs") + clean_output_directories() name = "Labels" + nb_subjects = len(t1) - pool = Pool(args.nb_threads) - subjects_dict_pool = pool.starmap(_subj_parralel, - zip(t1, - label, - itertools.repeat(name), - itertools.repeat(args.skip), - itertools.repeat(args.nb_columns), - itertools.repeat(args.lut), - itertools.repeat(args.compute_lut))) - pool.close() - pool.join() - - metrics_dict = {} - subjects_dict = {} - for dict_sub in subjects_dict_pool: - for key in dict_sub: - curr_key = os.path.basename(key).split('.')[0] - subjects_dict[curr_key] = dict_sub[curr_key] - metrics_dict[name] = subjects_dict + metrics_dict = { + name: generate_metric_reports_parallel( + zip(t1, label), + args.nb_threads, + nb_subjects // args.nb_threads, + report_package_generation_fn=partial( + generate_report_package, + skip=args.skip, + nb_columns=args.nb_columns, + lut=args.lut, + compute_lut=args.compute_lut + ), + ) + } - nb_subjects = len(t1) report = Report(args.output_report) - report.generate(title="Quality Assurance labels", - nb_subjects=nb_subjects, metrics_dict=metrics_dict, - online=args.online) + report.generate( + title="Quality Assurance labels", + nb_subjects=nb_subjects, + metrics_dict=metrics_dict, + online=args.online, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/dmriqc_registration.py b/scripts/dmriqc_registration.py index 0d704e8..1b2c2df 100755 --- a/scripts/dmriqc_registration.py +++ b/scripts/dmriqc_registration.py @@ -2,22 +2,29 @@ # -*- coding: utf-8 -*- import argparse -import os -import shutil +from functools import partial -import itertools -from multiprocessing import Pool import numpy as np - -from dmriqcpy.analysis.stats import stats_mean_in_tissues from dmriqcpy.io.report import Report -from dmriqcpy.io.utils import (add_online_arg, add_overwrite_arg, - assert_inputs_exist, assert_outputs_exist, - list_files_from_paths) -from dmriqcpy.viz.graph 
import graph_mean_in_tissues -from dmriqcpy.viz.screenshot import screenshot_mosaic_blend -from dmriqcpy.viz.utils import analyse_qa, dataframe_to_html +from dmriqcpy.io.utils import ( + add_online_arg, + add_overwrite_arg, + add_nb_columns_arg, + add_nb_threads_arg, + add_skip_arg, + assert_inputs_exist, + assert_list_arguments_equal_size, + assert_outputs_exist, + clean_output_directories, + list_files_from_paths, +) +from dmriqcpy.reporting.report import ( + generate_metric_reports_parallel, + generate_report_package, + get_qa_stats_and_graph_in_tissues, +) +from dmriqcpy.viz.utils import dataframe_to_html DESCRIPTION = """ @@ -26,60 +33,43 @@ def _build_arg_parser(): - p = argparse.ArgumentParser(description=DESCRIPTION, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('output_report', - help='HTML report') - - p.add_argument('--t1_warped', nargs='+', required=True, - help='Folder or list of T1 registered images ' - 'in Nifti format.') - - p.add_argument('--rgb', nargs='+', required=True, - help='Folder or list of RGB images in Nifti format.') - - p.add_argument('--wm', nargs='+', required=True, - help='Folder or list of WM mask in Nifti format.') - - p.add_argument('--gm', nargs='+', required=True, - help='Folder or list of GM mask in Nifti format.') - - p.add_argument('--csf', nargs='+', required=True, - help='Folder or list of CSF mask in Nifti format.') - - p.add_argument('--skip', default=2, type=int, - help='Number of images skipped to build the ' - 'mosaic. [%(default)s]') - - p.add_argument('--nb_columns', default=12, type=int, - help='Number of columns for the mosaic. [%(default)s]') - - p.add_argument('--nb_threads', type=int, default=1, - help='Number of threads. 
[%(default)s]') - + p = argparse.ArgumentParser( + description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter + ) + + p.add_argument("output_report", help="Filename of QC report (in html format).") + p.add_argument( + "--t1_warped", + nargs="+", + required=True, + help="Folder or list of T1 registered images in Nifti format.", + ) + + p.add_argument( + "--rgb", nargs="+", required=True, help="Folder or list of RGB images in Nifti format." + ) + + p.add_argument( + "--wm", nargs="+", required=True, help="Folder or list of WM mask in Nifti format." + ) + + p.add_argument( + "--gm", nargs="+", required=True, help="Folder or list of GM mask in Nifti format." + ) + + p.add_argument( + "--csf", nargs="+", required=True, help="Folder or list of CSF mask in Nifti format." + ) + + add_skip_arg(p) + add_nb_columns_arg(p) + add_nb_threads_arg(p) add_online_arg(p) add_overwrite_arg(p) return p -def _subj_parralel(t1_metric, rgb_metric, summary, name, skip, nb_columns): - subjects_dict = {} - curr_key = os.path.basename(t1_metric).split('.')[0] - screenshot_path = screenshot_mosaic_blend(t1_metric, rgb_metric, - output_prefix=name, - directory="data", - blend_val=0.5, - skip=skip, nb_columns=nb_columns) - - summary_html = dataframe_to_html(summary.loc[curr_key].to_frame()) - subjects_dict[curr_key] = {} - subjects_dict[curr_key]['screenshot'] = screenshot_path - subjects_dict[curr_key]['stats'] = summary_html - return subjects_dict - - def main(): parser = _build_arg_parser() args = parser.parse_args() @@ -90,70 +80,47 @@ def main(): gm = list_files_from_paths(args.gm) csf = list_files_from_paths(args.csf) - if not len(t1_warped) == len(rgb) == len(wm) ==\ - len(gm) == len(csf): - parser.error("Not the same number of images in input.") - + assert_list_arguments_equal_size(parser, t1_warped, rgb, wm, gm, csf) all_images = np.concatenate([t1_warped, rgb, wm, gm, csf]) assert_inputs_exist(parser, all_images) assert_outputs_exist(parser, args, [args.output_report, 
"data", "libs"]) - - if os.path.exists("data"): - shutil.rmtree("data") - os.makedirs("data") - - if os.path.exists("libs"): - shutil.rmtree("libs") + clean_output_directories() name = "Register T1" - curr_metrics = ['Mean {} in WM'.format(name), - 'Mean {} in GM'.format(name), - 'Mean {} in CSF'.format(name), - 'Max {} in WM'.format(name)] - - warning_dict = {} - summary, stats = stats_mean_in_tissues(curr_metrics, t1_warped, - wm, gm, csf) - warning_dict[name] = analyse_qa(summary, stats, curr_metrics[:3]) - warning_list = np.concatenate([filenames for filenames in warning_dict[name].values()]) - warning_dict[name]['nb_warnings'] = len(np.unique(warning_list)) - - graphs = [] - graph = graph_mean_in_tissues('Mean {}'.format(name), curr_metrics[:3], - summary, args.online) - graphs.append(graph) - - stats_html = dataframe_to_html(stats) - summary_dict = {} - summary_dict[name] = stats_html - - pool = Pool(args.nb_threads) - subjects_dict_pool = pool.starmap(_subj_parralel, - zip(t1_warped, - rgb, - itertools.repeat(summary), - itertools.repeat(name), - itertools.repeat(args.skip), - itertools.repeat(args.nb_columns))) - pool.close() - pool.join() - - metrics_dict = {} - subjects_dict = {} - for dict_sub in subjects_dict_pool: - for key in dict_sub: - curr_key = os.path.basename(key).split('.')[0] - subjects_dict[curr_key] = dict_sub[curr_key] - metrics_dict[name] = subjects_dict - nb_subjects = len(t1_warped) - report = Report(args.output_report) - report.generate(title="Quality Assurance registration", - nb_subjects=nb_subjects, summary_dict=summary_dict, - graph_array=graphs, metrics_dict=metrics_dict, - warning_dict=warning_dict, - online=args.online) + summary, stats, qa_report, qa_graphs = get_qa_stats_and_graph_in_tissues( + t1_warped, name, wm, gm, csf, args.online + ) + + warning_dict = {name: qa_report} + summary_dict = {name: dataframe_to_html(stats)} + + metrics_dict = { + name: generate_metric_reports_parallel( + zip(t1_warped, rgb), + args.nb_threads, 
+ nb_subjects // args.nb_threads, + report_package_generation_fn=partial( + generate_report_package, + stats_summary=summary, + skip=args.skip, + nb_columns=args.nb_columns, + ), + ) + } -if __name__ == '__main__': + report = Report(args.output_report) + report.generate( + title="Quality Assurance registration", + nb_subjects=nb_subjects, + summary_dict=summary_dict, + graph_array=qa_graphs, + metrics_dict=metrics_dict, + warning_dict=warning_dict, + online=args.online, + ) + + +if __name__ == "__main__": main() diff --git a/scripts/dmriqc_tissues.py b/scripts/dmriqc_tissues.py index b47b093..b898426 100755 --- a/scripts/dmriqc_tissues.py +++ b/scripts/dmriqc_tissues.py @@ -2,22 +2,29 @@ # -*- coding: utf-8 -*- import argparse -import os -import shutil +from functools import partial -import itertools -from multiprocessing import Pool import numpy as np - -from dmriqcpy.analysis.stats import stats_mask_volume from dmriqcpy.io.report import Report -from dmriqcpy.io.utils import (add_online_arg, add_overwrite_arg, - assert_inputs_exist, assert_outputs_exist, - list_files_from_paths) -from dmriqcpy.viz.graph import graph_mask_volume -from dmriqcpy.viz.screenshot import screenshot_mosaic_wrapper -from dmriqcpy.viz.utils import analyse_qa, dataframe_to_html +from dmriqcpy.io.utils import ( + add_online_arg, + add_overwrite_arg, + add_nb_columns_arg, + add_nb_threads_arg, + add_skip_arg, + assert_inputs_exist, + assert_list_arguments_equal_size, + assert_outputs_exist, + clean_output_directories, + list_files_from_paths, +) +from dmriqcpy.reporting.report import ( + generate_metric_reports_parallel, + generate_report_package, + get_mask_qa_stats_and_graph, +) +from dmriqcpy.viz.utils import dataframe_to_html DESCRIPTION = """ @@ -26,52 +33,24 @@ def _build_arg_parser(): - p = argparse.ArgumentParser(description=DESCRIPTION, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('output_report', - help='HTML report') - - p.add_argument('--wm', nargs='+', 
required=True, - help='WM mask in Nifti format') - - p.add_argument('--gm', nargs='+', required=True, - help='GM mask in Nifti format') - - p.add_argument('--csf', nargs='+', required=True, - help='CSF mask in Nifti format') - - p.add_argument('--skip', default=2, type=int, - help='Number of images skipped to build the ' - 'mosaic. [%(default)s]') - - p.add_argument('--nb_columns', default=12, type=int, - help='Number of columns for the mosaic. [%(default)s]') - - p.add_argument('--nb_threads', type=int, default=1, - help='Number of threads. [%(default)s]') - + p = argparse.ArgumentParser( + description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter + ) + + p.add_argument("output_report", help="Filename of QC report (in html format).") + p.add_argument("--wm", nargs="+", required=True, help="Folder or list of WM mask in Nifti format") + p.add_argument("--gm", nargs="+", required=True, help="Folder or list of GM mask in Nifti format") + p.add_argument("--csf", nargs="+", required=True, help="Folder or list of CSF mask in Nifti format") + + add_skip_arg(p) + add_nb_columns_arg(p) + add_nb_threads_arg(p) add_online_arg(p) add_overwrite_arg(p) return p -def _subj_parralel(subj_metric, summary, name, skip, nb_columns): - subjects_dict = {} - curr_key = os.path.basename(subj_metric).split('.')[0] - screenshot_path = screenshot_mosaic_wrapper(subj_metric, - output_prefix=name, - directory="data", skip=skip, - nb_columns=nb_columns) - - summary_html = dataframe_to_html(summary.loc[curr_key].to_frame()) - subjects_dict[curr_key] = {} - subjects_dict[curr_key]['screenshot'] = screenshot_path - subjects_dict[curr_key]['stats'] = summary_html - return subjects_dict - - def main(): parser = _build_arg_parser() args = parser.parse_args() @@ -80,67 +59,52 @@ def main(): gm = list_files_from_paths(args.gm) csf = list_files_from_paths(args.csf) - if not len(wm) == len(gm) == len(csf): - parser.error("Not the same number of images in input.") - + 
assert_list_arguments_equal_size(parser, wm, gm, csf) all_images = np.concatenate([wm, gm, csf]) assert_inputs_exist(parser, all_images) assert_outputs_exist(parser, args, [args.output_report, "data", "libs"]) + clean_output_directories() - if os.path.exists("data"): - shutil.rmtree("data") - os.makedirs("data") - - if os.path.exists("libs"): - shutil.rmtree("libs") - - metrics_names = [[wm, 'WM mask'], - [gm, 'GM mask'], - [csf, 'CSF mask']] metrics_dict = {} summary_dict = {} graphs = [] warning_dict = {} - for metrics, name in metrics_names: - columns = ["{} volume".format(name)] - summary, stats = stats_mask_volume(columns, metrics) - - warning_dict[name] = analyse_qa(summary, stats, columns) - warning_list = np.concatenate([filenames for filenames in warning_dict[name].values()]) - warning_dict[name]['nb_warnings'] = len(np.unique(warning_list)) - - graph = graph_mask_volume('{} mean volume'.format(name), - columns, summary, args.online) - graphs.append(graph) - - stats_html = dataframe_to_html(stats) - summary_dict[name] = stats_html - - subjects_dict = {} - pool = Pool(args.nb_threads) - subjects_dict_pool = pool.starmap(_subj_parralel, - zip(metrics, - itertools.repeat(summary), - itertools.repeat(name), - itertools.repeat(args.skip), - itertools.repeat(args.nb_columns))) - pool.close() - pool.join() - - for dict_sub in subjects_dict_pool: - for key in dict_sub: - curr_key = os.path.basename(key).split('.')[0] - subjects_dict[curr_key] = dict_sub[curr_key] - metrics_dict[name] = subjects_dict + for metrics, name in [ + [wm, "WM mask"], + [gm, "GM mask"], + [csf, "CSF mask"], + ]: + summary, stats, qa_report, qa_graphs = get_mask_qa_stats_and_graph( + metrics, name, args.online + ) + + warning_dict[name] = qa_report + summary_dict[name] = dataframe_to_html(stats) + graphs.extend(qa_graphs) + + metrics_dict[name] = generate_metric_reports_parallel( + zip(metrics), + args.nb_threads, + report_package_generation_fn=partial( + generate_report_package, + 
stats_summary=summary, + skip=args.skip, + nb_columns=args.nb_columns, + ), + ) nb_subjects = len(wm) report = Report(args.output_report) - report.generate(title="Quality Assurance tissue segmentation", - nb_subjects=nb_subjects, summary_dict=summary_dict, - graph_array=graphs, metrics_dict=metrics_dict, - warning_dict=warning_dict, - online=args.online) - - -if __name__ == '__main__': + report.generate( + title="Quality Assurance tissue segmentation", + nb_subjects=nb_subjects, + summary_dict=summary_dict, + graph_array=graphs, + metrics_dict=metrics_dict, + warning_dict=warning_dict, + online=args.online, + ) + + +if __name__ == "__main__": main() diff --git a/scripts/dmriqc_tracking_maps.py b/scripts/dmriqc_tracking_maps.py index c9c5eb9..b73c96e 100755 --- a/scripts/dmriqc_tracking_maps.py +++ b/scripts/dmriqc_tracking_maps.py @@ -2,22 +2,29 @@ # -*- coding: utf-8 -*- import argparse -import os -import shutil +from functools import partial -import itertools -from multiprocessing import Pool import numpy as np - -from dmriqcpy.analysis.stats import stats_mask_volume from dmriqcpy.io.report import Report -from dmriqcpy.io.utils import (add_online_arg, add_overwrite_arg, - assert_inputs_exist, assert_outputs_exist, - list_files_from_paths) -from dmriqcpy.viz.graph import graph_mask_volume -from dmriqcpy.viz.screenshot import screenshot_mosaic_wrapper -from dmriqcpy.viz.utils import analyse_qa, dataframe_to_html +from dmriqcpy.io.utils import ( + add_online_arg, + add_overwrite_arg, + add_nb_columns_arg, + add_nb_threads_arg, + add_skip_arg, + assert_inputs_exist, + assert_list_arguments_equal_size, + assert_outputs_exist, + clean_output_directories, + list_files_from_paths, +) +from dmriqcpy.reporting.report import ( + generate_report_package, + generate_metric_reports_parallel, + get_mask_qa_stats_and_graph, +) +from dmriqcpy.viz.utils import dataframe_to_html DESCRIPTION = """ @@ -26,58 +33,40 @@ def _build_arg_parser(): - p = 
argparse.ArgumentParser(description=DESCRIPTION, -                            formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('tracking_type', choices=["pft", "local"], - help='Tracking type') - - p.add_argument('output_report', - help='HTML report') - - p.add_argument('--seeding_mask', nargs='+', required=True, - help='Folder or list of seeding mask in Nifti format') - - p.add_argument('--tracking_mask', nargs='+', - help='Folder or list of tracking mask in Nifti format') - - p.add_argument('--map_include', nargs='+', - help='Folder or list of map include in Nifti format') - - p.add_argument('--map_exclude', nargs='+', - help='Folder or list of map exlude in Nifti format') - - p.add_argument('--skip', default=2, type=int, - help='Number of images skipped to build the ' - 'mosaic. [%(default)s]') - - p.add_argument('--nb_columns', default=12, type=int, - help='Number of columns for the mosaic. [%(default)s]') - - p.add_argument('--nb_threads', type=int, default=1, - help='Number of threads. [%(default)s]') - + p = argparse.ArgumentParser( + description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter + ) + + p.add_argument("tracking_type", choices=["pft", "local"], help="Tracking type.") + p.add_argument("output_report", help="Filename of QC report (in html format).") + p.add_argument( + "--seeding_mask", + nargs="+", + required=True, + help="Folder or list of seeding mask in Nifti format.", + ) + + p.add_argument( + "--tracking_mask", nargs="+", help="Folder or list of tracking mask in Nifti format." + ) + + p.add_argument( + "--map_include", nargs="+", help="Folder or list of map include in Nifti format." + ) + + p.add_argument( + "--map_exclude", nargs="+", help="Folder or list of map exclude in Nifti format." 
+ ) + + add_skip_arg(p) + add_nb_columns_arg(p) + add_nb_threads_arg(p) add_online_arg(p) add_overwrite_arg(p) return p -def _subj_parralel(subj_metric, summary, name, skip, nb_columns): - subjects_dict = {} - curr_key = os.path.basename(subj_metric).split('.')[0] - screenshot_path = screenshot_mosaic_wrapper(subj_metric, - output_prefix=name, - directory="data", skip=skip, - nb_columns=nb_columns) - - summary_html = dataframe_to_html(summary.loc[curr_key].to_frame()) - subjects_dict[curr_key] = {} - subjects_dict[curr_key]['screenshot'] = screenshot_path - subjects_dict[curr_key]['stats'] = summary_html - return subjects_dict - - def main(): parser = _build_arg_parser() args = parser.parse_args() @@ -86,79 +75,68 @@ def main(): if args.tracking_type == "local": tracking_mask = list_files_from_paths(args.tracking_mask) - if not len(seeding_mask) == len(tracking_mask): - parser.error("Not the same number of images in input.") - all_images = np.concatenate([args.seeding_mask, args.tracking_mask]) + assert_list_arguments_equal_size(parser, seeding_mask, tracking_mask) + all_images = np.concatenate([seeding_mask, tracking_mask]) + metrics_names = [ + [seeding_mask, "Seeding mask"], + [tracking_mask, "Tracking mask"], + ] else: map_include = list_files_from_paths(args.map_include) map_exclude = list_files_from_paths(args.map_exclude) - if not len(seeding_mask) == len(map_include) ==\ - len(map_exclude): - parser.error("Not the same number of images in input.") - all_images = np.concatenate([seeding_mask, map_include, - map_exclude]) + assert_list_arguments_equal_size( + parser, seeding_mask, map_include, map_exclude + ) + all_images = np.concatenate( + [seeding_mask, map_include, map_exclude] + ) + metrics_names = [ + [seeding_mask, "Seeding mask"], + [map_include, "Map include"], + [map_exclude, "Maps exclude"], + ] assert_inputs_exist(parser, all_images) assert_outputs_exist(parser, args, [args.output_report, "data", "libs"]) + clean_output_directories() - if 
os.path.exists("data"): - shutil.rmtree("data") - os.makedirs("data") - - if os.path.exists("libs"): - shutil.rmtree("libs") - - if args.tracking_type == "local": - metrics_names = [[seeding_mask, 'Seeding mask'], - [tracking_mask, 'Tracking mask']] - else: - metrics_names = [[seeding_mask, 'Seeding mask'], - [map_include, 'Map include'], - [map_exclude, 'Maps exclude']] metrics_dict = {} summary_dict = {} graphs = [] warning_dict = {} for metrics, name in metrics_names: - columns = ["{} volume".format(name)] - summary, stats = stats_mask_volume(columns, metrics) - - warning_dict[name] = analyse_qa(summary, stats, columns) - warning_list = np.concatenate([filenames for filenames in warning_dict[name].values()]) - warning_dict[name]['nb_warnings'] = len(np.unique(warning_list)) - - graph = graph_mask_volume('{} mean volume'.format(name), - columns, summary, args.online) - graphs.append(graph) - - stats_html = dataframe_to_html(stats) - summary_dict[name] = stats_html - - subjects_dict = {} - pool = Pool(args.nb_threads) - subjects_dict_pool = pool.starmap(_subj_parralel, - zip(metrics, - itertools.repeat(summary), - itertools.repeat(name), - itertools.repeat(args.skip), - itertools.repeat(args.nb_columns))) - pool.close() - pool.join() - - for dict_sub in subjects_dict_pool: - for key in dict_sub: - curr_key = os.path.basename(key).split('.')[0] - subjects_dict[curr_key] = dict_sub[curr_key] - metrics_dict[name] = subjects_dict + summary, stats, qa_report, qa_graphs = get_mask_qa_stats_and_graph( + metrics, name, args.online + ) + + warning_dict[name] = qa_report + summary_dict[name] = dataframe_to_html(stats) + graphs.extend(qa_graphs) + + metrics_dict[name] = generate_metric_reports_parallel( + zip(metrics), + args.nb_threads, + len(metrics) // args.nb_threads, + report_package_generation_fn=partial( + generate_report_package, + stats_summary=summary, + skip=args.skip, + nb_columns=args.nb_columns, + ), + ) nb_subjects = len(seeding_mask) report = 
Report(args.output_report) - report.generate(title="Quality Assurance tracking maps", - nb_subjects=nb_subjects, summary_dict=summary_dict, - graph_array=graphs, metrics_dict=metrics_dict, - warning_dict=warning_dict, - online=args.online) - - -if __name__ == '__main__': + report.generate( + title="Quality Assurance tracking maps", + nb_subjects=nb_subjects, + summary_dict=summary_dict, + graph_array=graphs, + metrics_dict=metrics_dict, + warning_dict=warning_dict, + online=args.online, + ) + + +if __name__ == "__main__": main() diff --git a/scripts/dmriqc_tractogram.py b/scripts/dmriqc_tractogram.py index 8795d77..26eaebb 100755 --- a/scripts/dmriqc_tractogram.py +++ b/scripts/dmriqc_tractogram.py @@ -2,20 +2,27 @@ # -*- coding: utf-8 -*- import argparse -import os -import shutil +from functools import partial import numpy as np - from dmriqcpy.io.report import Report -from dmriqcpy.io.utils import (add_online_arg, add_overwrite_arg, - assert_inputs_exist, assert_outputs_exist, - list_files_from_paths) -from dmriqcpy.analysis.stats import stats_tractogram -from dmriqcpy.viz.graph import graph_tractogram -from dmriqcpy.viz.screenshot import screenshot_tracking -from dmriqcpy.viz.utils import analyse_qa, dataframe_to_html +from dmriqcpy.io.utils import ( + add_nb_threads_arg, + add_online_arg, + add_overwrite_arg, + assert_inputs_exist, + assert_list_arguments_equal_size, + assert_outputs_exist, + clean_output_directories, + list_files_from_paths, +) +from dmriqcpy.reporting.report import ( + generate_metric_reports_parallel, + generate_report_package, + get_tractogram_qa_stats_and_graph, +) +from dmriqcpy.viz.utils import dataframe_to_html DESCRIPTION = """ @@ -24,20 +31,21 @@ def _build_arg_parser(): - p = argparse.ArgumentParser(description=DESCRIPTION, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('output_report', - help='HTML report') - - p.add_argument('--tractograms', nargs='+', - help='Folder or list of tractograms in format supported' 
- ' by Nibabel.') - - p.add_argument('--t1', nargs='+', - help='Folder or list of T1 images in Nifti format.') + p = argparse.ArgumentParser( + description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter + ) + + p.add_argument("output_report", help="Filename of QC report (in html format).") + p.add_argument( + "--tractograms", + nargs="+", + required=True, + help="Tractograms in format supported by Nibabel.", + ) + p.add_argument("--t1", nargs="+", required=True, help="Folder or list of T1 images in Nifti format.") add_online_arg(p) + add_nb_threads_arg(p) add_overwrite_arg(p) return p @@ -50,56 +58,46 @@ def main(): t1 = list_files_from_paths(args.t1) tractograms = list_files_from_paths(args.tractograms) - if not len(tractograms) == len(t1): - parser.error("Not the same number of images in input.") - + assert_list_arguments_equal_size(parser, t1, tractograms) all_images = np.concatenate([tractograms, t1]) assert_inputs_exist(parser, all_images) assert_outputs_exist(parser, args, [args.output_report, "data", "libs"]) - - if os.path.exists("data"): - shutil.rmtree("data") - os.makedirs("data") - - if os.path.exists("libs"): - shutil.rmtree("libs") + clean_output_directories() name = "Tracking" - columns = ["Nb streamlines"] - - warning_dict = {} - summary, stats = stats_tractogram(columns, tractograms) - warning_dict[name] = analyse_qa(summary, stats, ["Nb streamlines"]) - warning_list = np.concatenate([filenames for filenames in warning_dict[name].values()]) - warning_dict[name]['nb_warnings'] = len(np.unique(warning_list)) - - graphs = [] - graph = graph_tractogram("Tracking", columns, summary, args.online) - graphs.append(graph) - - summary_dict = {} - stats_html = dataframe_to_html(stats) - summary_dict[name] = stats_html - - metrics_dict = {} - subjects_dict = {} - for subj_metric, curr_t1 in zip(tractograms, t1): - curr_key = os.path.basename(subj_metric).split('.')[0] - screenshot_path = screenshot_tracking(subj_metric, curr_t1, "data") - 
summary_html = dataframe_to_html(summary.loc[curr_key].to_frame()) - subjects_dict[curr_key] = {} - subjects_dict[curr_key]['screenshot'] = screenshot_path - subjects_dict[curr_key]['stats'] = summary_html - metrics_dict[name] = subjects_dict - nb_subjects = len(tractograms) - report = Report(args.output_report) - report.generate(title="Quality Assurance tractograms", - nb_subjects=nb_subjects, summary_dict=summary_dict, - graph_array=graphs, metrics_dict=metrics_dict, - warning_dict=warning_dict, - online=args.online) + summary, stats, qa_report, qa_graphs = get_tractogram_qa_stats_and_graph( + tractograms, args.online + ) + + warning_dict = {name: qa_report} + summary_dict = {name: dataframe_to_html(stats)} + + metrics_dict = { + name: generate_metric_reports_parallel( + zip(tractograms, t1), + args.nb_threads, + nb_subjects // args.nb_threads, + report_package_generation_fn=partial( + generate_report_package, + stats_summary=summary, + metric_is_tracking=True + ), + ) + } -if __name__ == '__main__': + report = Report(args.output_report) + report.generate( + title="Quality Assurance tractograms", + nb_subjects=nb_subjects, + summary_dict=summary_dict, + graph_array=qa_graphs, + metrics_dict=metrics_dict, + warning_dict=warning_dict, + online=args.online, + ) + + +if __name__ == "__main__": main() diff --git a/setup.py b/setup.py index e1a019c..955b348 100644 --- a/setup.py +++ b/setup.py @@ -2,32 +2,35 @@ import os from setuptools import setup, find_packages + PACKAGES = find_packages() # Get version and release info, which is all stored in dmriqc/version.py -ver_file = os.path.join('dmriqcpy', 'version.py') +ver_file = os.path.join("dmriqcpy", "version.py") with open(ver_file) as f: exec(f.read()) -opts = dict(name=NAME, - maintainer=MAINTAINER, - maintainer_email=MAINTAINER_EMAIL, - description=DESCRIPTION, - long_description=LONG_DESCRIPTION, - url=URL, - download_url=DOWNLOAD_URL, - license=LICENSE, - classifiers=CLASSIFIERS, - author=AUTHOR, - 
author_email=AUTHOR_EMAIL, - platforms=PLATFORMS, - version=VERSION, - packages=PACKAGES, - install_requires=REQUIRES, - requires=REQUIRES, - scripts=SCRIPTS, - include_package_data=True) +opts = dict( + name=NAME, + maintainer=MAINTAINER, + maintainer_email=MAINTAINER_EMAIL, + description=DESCRIPTION, + long_description=LONG_DESCRIPTION, + url=URL, + download_url=DOWNLOAD_URL, + license=LICENSE, + classifiers=CLASSIFIERS, + author=AUTHOR, + author_email=AUTHOR_EMAIL, + platforms=PLATFORMS, + version=VERSION, + packages=PACKAGES, + install_requires=REQUIRES, + requires=REQUIRES, + scripts=SCRIPTS, + include_package_data=True, +) -if __name__ == '__main__': +if __name__ == "__main__": setup(**opts)