diff --git a/src/main/java/org/broadinstitute/hellbender/tools/spark/transforms/BaseRecalibratorSparkFn.java b/src/main/java/org/broadinstitute/hellbender/tools/spark/transforms/BaseRecalibratorSparkFn.java index 695544ee704..35cc726485a 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/spark/transforms/BaseRecalibratorSparkFn.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/spark/transforms/BaseRecalibratorSparkFn.java @@ -11,8 +11,12 @@ import org.broadinstitute.hellbender.utils.io.IOUtils; import org.broadinstitute.hellbender.utils.read.GATKRead; import org.broadinstitute.hellbender.utils.recalibration.*; -import org.broadinstitute.hellbender.utils.recalibration.covariates.StandardCovariateList; +import org.broadinstitute.hellbender.utils.recalibration.covariates.BQSRCovariateList; +import org.broadinstitute.hellbender.utils.reference.ReferenceBases; import org.broadinstitute.hellbender.utils.variant.GATKVariant; +import scala.Tuple2; + +import java.util.Arrays; public final class BaseRecalibratorSparkFn { @@ -34,7 +38,7 @@ public static RecalibrationReport apply(final JavaPairRDD * It updates the base qualities of the read with the new recalibrated qualities (for all event types) *

* Implements a serial recalibration of the reads using the combinational table. - * First, we perform a positional recalibration, and then a subsequent dinuc correction. + * First, we perform a positional recalibration, and then a subsequent dinucleotide (dinuc) correction. *

* Given the full recalibration table, we perform the following preprocessing steps: *

* - calculate the global quality score shift across all data [DeltaQ] * - calculate for each of cycle and dinuc the shift of the quality scores relative to the global shift - * -- i.e., DeltaQ(dinuc) = Sum(pos) Sum(Qual) Qempirical(pos, qual, dinuc) - Qreported(pos, qual, dinuc) / Npos * Nqual + * -- i.e., DeltaQ(dinuc) = Sum(pos) Sum(Qual) Q_empirical(pos, qual, dinuc) - Q_reported(pos, qual, dinuc) / Npos * Nqual * - The final shift equation is: *

- * Qrecal = Qreported + DeltaQ + DeltaQ(pos) + DeltaQ(dinuc) + DeltaQ( ... any other covariate ... ) + * Q_recal = Q_reported + DeltaQ + DeltaQ(pos) + DeltaQ(dinuc) + DeltaQ( ... any other covariate ... ) * * @param originalRead the read to recalibrate */ @@ -203,26 +207,26 @@ public GATKRead apply(final GATKRead originalRead) { } // clear and reuse this array to save space - Arrays.fill(recalDatumsForSpecialCovariates, null); + Arrays.fill(recalDatumsForAdditionalCovariates, null); final int[] covariatesAtOffset = covariatesForRead[offset]; - final int reportedBaseQualityAtOffset = covariatesAtOffset[StandardCovariateList.BASE_QUALITY_COVARIATE_DEFAULT_INDEX]; + final int reportedBaseQualityAtOffset = covariatesAtOffset[BQSRCovariateList.BASE_QUALITY_COVARIATE_DEFAULT_INDEX]; // Datum for the tuple (read group, reported quality score). final RecalDatum qualityScoreDatum = recalibrationTables.getQualityScoreTable() .get3Keys(rgKey, reportedBaseQualityAtOffset, BASE_SUBSTITUTION_INDEX); - for (int j = StandardCovariateList.NUM_REQUIRED_COVARITES; j < totalCovariateCount; j++) { + for (int j = BQSRCovariateList.NUM_REQUIRED_COVARITES; j < totalCovariateCount; j++) { // If the covariate is -1 (e.g. the first base in each read should have -1 for the context covariate), // we simply leave the corresponding Datum to be null, which will subsequently be ignored when it comes time to recalibrate. if (covariatesAtOffset[j] >= 0) { - recalDatumsForSpecialCovariates[j - StandardCovariateList.NUM_REQUIRED_COVARITES] = recalibrationTables.getTable(j) + recalDatumsForAdditionalCovariates[j - BQSRCovariateList.NUM_REQUIRED_COVARITES] = recalibrationTables.getTable(j) .get4Keys(rgKey, reportedBaseQualityAtOffset, covariatesAtOffset[j], BASE_SUBSTITUTION_INDEX); } } // Use the reported quality score of the read group as the prior, which can be non-integer because of collapsing. final double priorQualityScore = constantQualityScorePrior > 0.0 ? 
constantQualityScorePrior : readGroupDatum.getReportedQuality(); - final double rawRecalibratedQualityScore = hierarchicalBayesianQualityEstimate(priorQualityScore, readGroupDatum, qualityScoreDatum, recalDatumsForSpecialCovariates); + final double rawRecalibratedQualityScore = hierarchicalBayesianQualityEstimate(priorQualityScore, readGroupDatum, qualityScoreDatum, recalDatumsForAdditionalCovariates); final byte quantizedQualityScore = quantizedQuals.get(getBoundedIntegerQual(rawRecalibratedQualityScore)); // TODO: as written the code quantizes *twice* if the static binning is enabled (first time to the dynamic bin). It should be quantized once. diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/BaseRecalibrationEngine.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/BaseRecalibrationEngine.java index 7a3aa3262a5..c765422faa4 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/BaseRecalibrationEngine.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/BaseRecalibrationEngine.java @@ -25,7 +25,7 @@ import org.broadinstitute.hellbender.utils.recalibration.covariates.Covariate; import org.broadinstitute.hellbender.utils.recalibration.covariates.CovariateKeyCache; import org.broadinstitute.hellbender.utils.recalibration.covariates.PerReadCovariateMatrix; -import org.broadinstitute.hellbender.utils.recalibration.covariates.StandardCovariateList; +import org.broadinstitute.hellbender.utils.recalibration.covariates.BQSRCovariateList; import java.io.Serializable; import java.util.Arrays; @@ -68,7 +68,7 @@ public SimpleInterval apply( GATKRead read ) { /** * list to hold the all the covariate objects that were requested (required + standard + experimental) */ - private StandardCovariateList covariates; + private BQSRCovariateList covariates; private BAQ baq; // BAQ the reads on the fly to generate the alignment uncertainty vector private static final byte NO_BAQ_UNCERTAINTY = 
(byte)'@'; @@ -90,7 +90,7 @@ public BaseRecalibrationEngine( final RecalibrationArgumentCollection recalArgs, baq = null; } - covariates = new StandardCovariateList(recalArgs, readsHeader); + covariates = new BQSRCovariateList(recalArgs, readsHeader); final int numReadGroups = readsHeader.getReadGroups().size(); if ( numReadGroups < 1 ) { @@ -236,7 +236,7 @@ public RecalibrationTables getFinalRecalibrationTables() { return recalTables; } - public StandardCovariateList getCovariates() { + public BQSRCovariateList getCovariates() { return covariates; } @@ -260,6 +260,7 @@ private void updateRecalTablesForRead( final ReadRecalibrationInfo recalInfo ) { final NestedIntegerArray qualityScoreTable = recalTables.getQualityScoreTable(); final int nCovariates = covariates.size(); + final int nSpecialCovariates = BQSRCovariateList.numberOfRequiredCovariates(); final int readLength = read.getLength(); for( int offset = 0; offset < readLength; offset++ ) { if( ! recalInfo.skip(offset) ) { @@ -270,8 +271,8 @@ private void updateRecalTablesForRead( final ReadRecalibrationInfo recalInfo ) { final byte qual = recalInfo.getQual(eventType, offset); final double isError = recalInfo.getErrorFraction(eventType, offset); - final int readGroup = covariatesAtOffset[StandardCovariateList.READ_GROUP_COVARIATE_DEFAULT_INDEX]; - final int baseQuality = covariatesAtOffset[StandardCovariateList.BASE_QUALITY_COVARIATE_DEFAULT_INDEX]; + final int readGroup = covariatesAtOffset[BQSRCovariateList.READ_GROUP_COVARIATE_DEFAULT_INDEX]; + final int baseQuality = covariatesAtOffset[BQSRCovariateList.BASE_QUALITY_COVARIATE_DEFAULT_INDEX]; RecalUtils.incrementDatum3keys(qualityScoreTable, qual, isError, readGroup, baseQuality, eventIndex); diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/ReadRecalibrationInfo.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/ReadRecalibrationInfo.java index 81e14593b80..2aca6710ed9 100644 --- 
a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/ReadRecalibrationInfo.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/ReadRecalibrationInfo.java @@ -118,8 +118,8 @@ public boolean skip(final int offset) { } /** - * Get the ReadCovariates object carrying the mapping from offsets -> covariate key sets - * @return a non-null ReadCovariates object + * Get the PerReadCovariateMatrix object carrying the mapping from offsets -> covariate key sets + * @return a non-null PerReadCovariateMatrix object */ public PerReadCovariateMatrix getCovariatesValues() { return covariates; diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalDatum.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalDatum.java index c4ce436754d..93e4b767dea 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalDatum.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalDatum.java @@ -365,4 +365,31 @@ protected static double getLogBinomialLikelihood(final double qualityScore, long final double logLikelihood = MathUtils.logBinomialProbability((int) nObservations, (int) nErrors, QualityUtils.qualToErrorProb(qualityScore)); return ( Double.isInfinite(logLikelihood) || Double.isNaN(logLikelihood) ) ? 
-Double.MAX_VALUE : logLikelihood; } + + @Override + public boolean equals(Object obj) { + if (this == obj){ + return true; + } + + if (!(obj instanceof RecalDatum) || obj == null) { + return false; + } + + RecalDatum other = (RecalDatum) obj; + return numObservations == other.numObservations + && Double.compare(numMismatches, other.numMismatches) == 0 + && Double.compare(reportedQuality, other.reportedQuality) == 0 + && empiricalQuality == other.empiricalQuality; + } + + @Override + public int hashCode() { + int result = 17; + result = 31 * result + Long.hashCode(numObservations); + result = 31 * result + Double.hashCode(numMismatches); + result = 31 * result + Double.hashCode(reportedQuality); + result = 31 * result + Integer.hashCode(empiricalQuality); + return result; + } } diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalUtils.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalUtils.java index 102730e4e20..04b17da231f 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalUtils.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalUtils.java @@ -17,7 +17,7 @@ import org.broadinstitute.hellbender.utils.recalibration.covariates.Covariate; import org.broadinstitute.hellbender.utils.recalibration.covariates.CovariateKeyCache; import org.broadinstitute.hellbender.utils.recalibration.covariates.PerReadCovariateMatrix; -import org.broadinstitute.hellbender.utils.recalibration.covariates.StandardCovariateList; +import org.broadinstitute.hellbender.utils.recalibration.covariates.BQSRCovariateList; import org.broadinstitute.hellbender.utils.report.GATKReport; import org.broadinstitute.hellbender.utils.report.GATKReportTable; @@ -79,7 +79,7 @@ public final class RecalUtils { private static class CsvPrinter { private final PrintStream ps; - private final StandardCovariateList covariates; + private final BQSRCovariateList covariates; /** * Constructs a printer 
redirected to an output file. @@ -87,7 +87,7 @@ private static class CsvPrinter { * @param covs covariates to print out. * @throws FileNotFoundException if the file could not be created anew. */ - protected CsvPrinter(final File out, final StandardCovariateList covs) + protected CsvPrinter(final File out, final BQSRCovariateList covs) throws FileNotFoundException { this(new FileOutputStream(out), covs); } @@ -97,7 +97,7 @@ protected CsvPrinter(final File out, final StandardCovariateList covs) * @param os the output. * @param covs covariates to print out. */ - protected CsvPrinter(final OutputStream os, final StandardCovariateList covs) { + protected CsvPrinter(final OutputStream os, final BQSRCovariateList covs) { covariates = covs; ps = new PrintStream(os); printHeader(); @@ -144,7 +144,7 @@ public void close() { * * @return never null */ - protected static CsvPrinter csvPrinter(final File out, final StandardCovariateList covs) throws FileNotFoundException { + protected static CsvPrinter csvPrinter(final File out, final BQSRCovariateList covs) throws FileNotFoundException { Utils.nonNull(covs, "the input covariate array cannot be null"); return new CsvPrinter(out,covs); } @@ -165,7 +165,32 @@ public static void generateCsv(final File out, final Map rit = reports.values().iterator(); + final RecalibrationReport first = rit.next(); + final Covariate[] firstCovariates = first.getRequestedCovariates(); + final Set covariates = new LinkedHashSet<>(); + Utils.addAll(covariates,firstCovariates); + while (rit.hasNext() && covariates.size() > 0) { + final Covariate[] nextCovariates = rit.next().getRequestedCovariates(); + final Set nextCovariateNames = new LinkedHashSet<>(nextCovariates.length); + for (final Covariate nc : nextCovariates) { + nextCovariateNames.add(nc.getClass().getSimpleName()); + } + final Iterator cit = covariates.iterator(); + while (cit.hasNext()) { + if (!nextCovariateNames.contains(cit.next().getClass().getSimpleName())) { + cit.remove(); + } + } + 
} + writeCsv(out, reports, covariates.toArray(new Covariate[covariates.size()])); + */ + writeCsv(out, reports, covariates); } @@ -178,7 +203,7 @@ public static void generateCsv(final File out, final Mapout could not be created anew. */ - private static void writeCsv(final File out, final Map reports, final StandardCovariateList covs) + private static void writeCsv(final File out, final Map reports, final BQSRCovariateList covs) throws FileNotFoundException { final CsvPrinter p = csvPrinter(out, covs); for (final Map.Entry e : reports.entrySet()) { @@ -187,7 +212,7 @@ private static void writeCsv(final File out, final Map generateReportTables(final RecalibrationTables recalibrationTables, final StandardCovariateList covariates) { + public static List generateReportTables(final RecalibrationTables recalibrationTables, final BQSRCovariateList covariates) { final List result = new LinkedList<>(); int rowIndex = 0; @@ -297,7 +322,7 @@ private static GATKReportTable makeNewTableWithColumns(ArrayList deltaTable = createDeltaTable(recalibrationTables, covariates.size()); @@ -447,7 +472,7 @@ private static NestedIntegerArray createDeltaTable(final Recalibrati return new NestedIntegerArray<>(dimensionsForDeltaTable); } - static List generateValuesFromKeys(final int[] keys, final StandardCovariateList covariates) { + static List generateValuesFromKeys(final int[] keys, final BQSRCovariateList covariates) { final List values = new ArrayList<>(4); values.add(covariates.getReadGroupCovariate().formatKey(keys[0])); @@ -525,10 +550,30 @@ public static void updatePlatformForRead(final GATKRead read, final SAMFileHeade * * @return a matrix with all the covariates calculated for every base in the read */ - public static PerReadCovariateMatrix computeCovariates(final GATKRead read, final SAMFileHeader header, final StandardCovariateList covariates, final boolean recordIndelValues, final CovariateKeyCache keyCache) { - final PerReadCovariateMatrix covariateTable = new 
PerReadCovariateMatrix(read.getLength(), covariates.size(), keyCache); - covariates.populatePerReadCovariateMatrix(read, header, covariateTable, recordIndelValues); - return covariateTable; + public static PerReadCovariateMatrix computeCovariates(final GATKRead read, final SAMFileHeader header, final BQSRCovariateList covariates, final boolean recordIndelValues, final CovariateKeyCache keyCache) { + final PerReadCovariateMatrix covariateMatrix = new PerReadCovariateMatrix(read.getLength(), covariates.size(), keyCache); + // tsato: this seems redundant; comment out for now + // covariates.populatePerReadCovariateMatrix(read, header, covariateMatrix, recordIndelValues); + computeCovariates(read, header, covariates, covariateMatrix, recordIndelValues); + return covariateMatrix; + } + + /** + * Computes all requested covariates for every offset in the given read + * by calling covariate.getValues(..). + * + * It populates an array of covariate values where result[i][j] is the covariate + * value for the ith position in the read and the jth covariate in + * covariates list. + * + * @param read The read for which to compute covariate values. + * @param header SAM header for the read + * @param covariates The list of covariates. + * @param covariateMatrix The object to store the covariate values + * @param recordIndelValues should we compute covariates for indel BQSR? 
+ */ + public static void computeCovariates( final GATKRead read, final SAMFileHeader header, final BQSRCovariateList covariates, final PerReadCovariateMatrix covariateMatrix, final boolean recordIndelValues) { + covariates.populatePerReadCovariateMatrix(read, header, covariateMatrix, recordIndelValues); } /** diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationArgumentCollection.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationArgumentCollection.java index 01f1d853432..52939bdaa6a 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationArgumentCollection.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationArgumentCollection.java @@ -9,8 +9,12 @@ import java.io.File; import java.io.Serializable; +import java.util.ArrayList; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.HashSet; /** * A collection of the arguments that are used for BQSR. Used to be common to both CovariateCounterWalker and TableRecalibrationWalker. @@ -20,9 +24,6 @@ public final class RecalibrationArgumentCollection implements Serializable { private static final long serialVersionUID = 1L; - // We always use the same covariates. The field is retained for compatibility with GATK3 reports. - public static final boolean DO_NOT_USE_STANDARD_COVARIATES = false; - //We don't support SOLID. The field is retained for compatibility with GATK3 reports. public static final String SOLID_RECAL_MODE = "SET_Q_ZERO"; public static final String SOLID_NOCALL_STRATEGY = "THROW_EXCEPTION"; @@ -30,6 +31,32 @@ public final class RecalibrationArgumentCollection implements Serializable { //It makes no sense to run BQSR without sites. so we remove this option. 
public static final boolean RUN_WITHOUT_DBSNP = false; + + public static final String LIST_ONLY_LONG_NAME = "list-covariates"; + public static final String COVARIATES_LONG_NAME = "covariate"; + public static final String COVARIATES_SHORT_NAME = "cov"; + public static final String DO_NOT_USE_STANDARD_COVARIATES_LONG_NAME = "no-standard-covariates"; + /** + * Note that the --list-covariates argument requires a fully resolved and correct command-line to work. + */ + @Argument(fullName = LIST_ONLY_LONG_NAME, doc = "List the available covariates and exit", optional = true) + public boolean LIST_ONLY = false; + + /** + * Note that the ReadGroup and QualityScore covariates are required and do not need to be specified. + * Also, unless --no-standard-covariates is specified, the Cycle and Context covariates are standard and are included by default. + * Use the --list-covariates argument to see the available covariates. + * + */ + @Argument(fullName = COVARIATES_LONG_NAME, shortName = COVARIATES_SHORT_NAME, doc = "One or more covariates to be used in the recalibration. Can be specified multiple times", optional = true) + public List COVARIATES = new ArrayList<>(); + + /** + * The Cycle and Context covariates are standard and are included by default unless this argument is provided. + * Note that the ReadGroup and QualityScore covariates are required and cannot be excluded. + */ + @Argument(fullName = DO_NOT_USE_STANDARD_COVARIATES_LONG_NAME, doc = "Do not use the standard set of covariates, but rather just the ones listed using the -cov argument", optional = true) + public boolean DO_NOT_USE_STANDARD_COVARIATES = false; + /** * The context covariate will use a context of this size to calculate its covariate value for base mismatches. Must be between 1 and 13 (inclusive). Note that higher values will increase runtime and required java heap size. 
*/ @@ -207,7 +234,8 @@ public GATKReportTable generateReportTable(final String covariateNames) { */ public Map compareReportArguments(final RecalibrationArgumentCollection other,final String thisRole, final String otherRole) { final Map result = new LinkedHashMap<>(15); - compareSimpleReportArgument(result,"no_standard_covs", DO_NOT_USE_STANDARD_COVARIATES, DO_NOT_USE_STANDARD_COVARIATES, thisRole, otherRole); + compareRequestedCovariates(result, other, thisRole, otherRole); + compareSimpleReportArgument(result,"no_standard_covs", DO_NOT_USE_STANDARD_COVARIATES, other.DO_NOT_USE_STANDARD_COVARIATES, thisRole, otherRole); compareSimpleReportArgument(result,"run_without_dbsnp",RUN_WITHOUT_DBSNP, RUN_WITHOUT_DBSNP,thisRole,otherRole); compareSimpleReportArgument(result,"solid_recal_mode", SOLID_RECAL_MODE, SOLID_RECAL_MODE,thisRole,otherRole); compareSimpleReportArgument(result,"solid_nocall_strategy", SOLID_NOCALL_STRATEGY, SOLID_NOCALL_STRATEGY,thisRole,otherRole); @@ -224,6 +252,52 @@ public Map compareReportArguments(final Recalibra return result; } + /** + * A helper method for {@link #compareReportArguments}. + * Compares the covariate report lists and updates diffs with + * key = "covariate" and + * value = a message that explains the difference to the end user. + * + * @param diffs the map to be updated by side-effect by this method. + * @param other the argument collection to compare against. + * @param thisRole the name for this argument collection that makes sense to the user. + * @param otherRole the name for the other argument collection that makes sense to the end user. + * + * @return true if a difference was found. 
+ */ + private boolean compareRequestedCovariates(final Map diffs, final RecalibrationArgumentCollection other, + final String thisRole, final String otherRole) { + + final Set beforeNames = new HashSet<>(this.COVARIATES.size()); + final Set afterNames = new HashSet<>(other.COVARIATES.size()); + beforeNames.addAll(this.COVARIATES); + afterNames.addAll(other.COVARIATES); + final Set intersection = new HashSet<>(Math.min(beforeNames.size(), afterNames.size())); + intersection.addAll(beforeNames); + intersection.retainAll(afterNames); + + String diffMessage = null; + if (intersection.size() == 0) { // In practice this is not possible due to required covariates but... + diffMessage = String.format("There are no common covariates between '%s' and '%s'" + + " recalibrator reports. Covariates in '%s': {%s}. Covariates in '%s': {%s}.", thisRole, otherRole, + thisRole, String.join(", ", this.COVARIATES), + otherRole, String.join(",", other.COVARIATES)); + } else if (intersection.size() != beforeNames.size() || intersection.size() != afterNames.size()) { + beforeNames.removeAll(intersection); + afterNames.removeAll(intersection); + diffMessage = String.format("There are differences in the set of covariates requested in the" + + " '%s' and '%s' recalibrator reports. " + + " Exclusive to '%s': {%s}. 
Exclusive to '%s': {%s}.", thisRole, otherRole, + thisRole, String.join(", ", beforeNames), + otherRole, String.join(", ", afterNames)); + } + if (diffMessage != null) { + diffs.put("covariate",diffMessage); + return true; + } else { + return false; + } + } /** * Annotates a map with any difference encountered in a simple value report argument that differs between this an diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationReport.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationReport.java index 32b50648b31..f8ebbbace20 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationReport.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationReport.java @@ -10,7 +10,7 @@ import org.broadinstitute.hellbender.utils.Utils; import org.broadinstitute.hellbender.utils.collections.NestedIntegerArray; import org.broadinstitute.hellbender.utils.recalibration.covariates.Covariate; -import org.broadinstitute.hellbender.utils.recalibration.covariates.StandardCovariateList; +import org.broadinstitute.hellbender.utils.recalibration.covariates.BQSRCovariateList; import org.broadinstitute.hellbender.utils.report.GATKReport; import org.broadinstitute.hellbender.utils.report.GATKReportTable; @@ -35,7 +35,7 @@ public final class RecalibrationReport { private static final Logger logger = LogManager.getLogger(RecalibrationReport.class); private QuantizationInfo quantizationInfo; // histogram containing the counts for qual quantization (calculated after recalibration is done) private final RecalibrationTables recalibrationTables; // quick access reference to the tables - private final StandardCovariateList covariates; // list of all covariates to be used in this calculation + private final BQSRCovariateList covariates; // list of all covariates to be used in this calculation private final GATKReportTable argumentTable; // keep the argument table untouched 
just for output purposes private final RecalibrationArgumentCollection RAC; // necessary for quantizing qualities with the same parameter @@ -59,7 +59,7 @@ public RecalibrationReport(final GATKReport report, final SortedSet allR final GATKReportTable quantizedTable = report.getTable(RecalUtils.QUANTIZED_REPORT_TABLE_TITLE); quantizationInfo = initializeQuantizationTable(quantizedTable); - covariates = new StandardCovariateList(RAC, new ArrayList<>(allReadGroups)); + covariates = new BQSRCovariateList(RAC, new ArrayList<>(allReadGroups)); recalibrationTables = new RecalibrationTables(covariates, allReadGroups.size()); @@ -175,7 +175,7 @@ public RecalibrationTables getRecalibrationTables() { return recalibrationTables; } - public StandardCovariateList getCovariates() { + public BQSRCovariateList getCovariates() { return covariates; } @@ -325,7 +325,7 @@ private static QuantizationInfo initializeQuantizationTable(GATKReportTable tabl private static RecalibrationArgumentCollection initializeArgumentCollectionTable(GATKReportTable table) { final RecalibrationArgumentCollection RAC = new RecalibrationArgumentCollection(); - final List standardCovariateClassNames = new StandardCovariateList(RAC, Collections.emptyList()).getStandardCovariateClassNames(); + final List standardCovariateClassNames = new BQSRCovariateList(RAC, Collections.emptyList()).getCovariateClassNames(); for ( int i = 0; i < table.getNumRows(); i++ ) { final String argument = table.get(i, "Argument").toString(); @@ -335,15 +335,9 @@ private static RecalibrationArgumentCollection initializeArgumentCollectionTable } if (argument.equals("covariate") && value != null) { - final List covs = new ArrayList<>(Arrays.asList(value.toString().split(","))); - if (!covs.equals(standardCovariateClassNames)) { - throw new UserException("Non-standard covariates are not supported. 
Only the following are supported " + standardCovariateClassNames + " but was " + covs); - } + RAC.COVARIATES = Arrays.asList(value.toString().split(",")); } else if (argument.equals("no_standard_covs")) { - final boolean no_standard_covs = decodeBoolean(value); - if (no_standard_covs){ - throw new UserException("Non-standard covariates are not supported. Only the following are supported " + standardCovariateClassNames + " but no_standard_covs was true"); - } + RAC.DO_NOT_USE_STANDARD_COVARIATES = decodeBoolean(value); } else if (argument.equals("solid_recal_mode")) { final String solid_recal_mode = (String) value; if (!RecalibrationArgumentCollection.SOLID_RECAL_MODE.equals(solid_recal_mode)){ diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationTables.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationTables.java index 057cdbc56e9..84a10a8a438 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationTables.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationTables.java @@ -1,7 +1,7 @@ package org.broadinstitute.hellbender.utils.recalibration; import org.broadinstitute.hellbender.utils.recalibration.covariates.Covariate; -import org.broadinstitute.hellbender.utils.recalibration.covariates.StandardCovariateList; +import org.broadinstitute.hellbender.utils.recalibration.covariates.BQSRCovariateList; import org.broadinstitute.hellbender.utils.Utils; import org.broadinstitute.hellbender.utils.collections.NestedIntegerArray; @@ -21,7 +21,7 @@ public final class RecalibrationTables implements Serializable, Iterable readGroupTable; @@ -36,11 +36,11 @@ public final class RecalibrationTables implements Serializable, Iterable(); this.allTables = new ArrayList<>(); diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/BQSRCovariateList.java 
b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/BQSRCovariateList.java new file mode 100644 index 00000000000..23244d77fb4 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/BQSRCovariateList.java @@ -0,0 +1,264 @@ +package org.broadinstitute.hellbender.utils.recalibration.covariates; + +import htsjdk.samtools.SAMFileHeader; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.broadinstitute.barclay.argparser.ClassFinder; +import org.broadinstitute.hellbender.exceptions.GATKException; +import org.broadinstitute.hellbender.exceptions.UserException; +import org.broadinstitute.hellbender.utils.recalibration.RecalibrationArgumentCollection; +import org.broadinstitute.hellbender.utils.read.GATKRead; +import org.reflections.Reflections; + +import java.io.Serializable; +import java.util.*; +import java.util.stream.Collectors; + +/** + * Represents a list of BQSR covariates. Formerly called StandardCovariateList. + * + * Note: the first two covariates ({@link ReadGroupCovariate} and {@link QualityScoreCovariate}) + * are special in the way that they are represented in the BQSR recalibration table, and are + * always required. We call these the "required covariates." + * + * The remaining covariates are called "additional covariates". The default additional covariates + * are the context and cycle covariates, but the client can request others and/or disable the default + * additional covariates. + * + * Also see the documentation in: {@link CustomCovariate}. 
+ */ +public final class BQSRCovariateList implements Iterable, Serializable { + private static final long serialVersionUID = 1L; + + private static final Logger logger = LogManager.getLogger(BQSRCovariateList.class); + + private final ReadGroupCovariate readGroupCovariate; + private final QualityScoreCovariate qualityScoreCovariate; + private final List additionalCovariates; // additional covariates are [context, cycle, { custom covariates }] + private final List allCovariates; + + private final Map, Integer> indexByClass; + + private static final List REQUIRED_COVARIATE_NAMES = + Collections.unmodifiableList(Arrays.asList("ReadGroupCovariate", "QualityScoreCovariate")); + + private static final List STANDARD_COVARIATE_NAMES = + Collections.unmodifiableList(Arrays.asList("ContextCovariate", "CycleCovariate")); + + public static final List COVARIATE_PACKAGES = + Collections.unmodifiableList(Arrays.asList("org.broadinstitute.hellbender.utils.recalibration.covariates")); + public static final Set> DISCOVERED_COVARIATES; + + public static final int READ_GROUP_COVARIATE_DEFAULT_INDEX = 0; + public static final int BASE_QUALITY_COVARIATE_DEFAULT_INDEX = 1; + public static final int CONTEXT_COVARIATE_DEFAULT_INDEX = 2; + public static final int CYCLE_COVARIATE_DEFAULT_INDEX = 3; + public static final int NUM_REQUIRED_COVARITES = 2; + + static { + final ClassFinder classFinder = new ClassFinder(); + + for ( final String covariatePackage : COVARIATE_PACKAGES ) { + classFinder.find(covariatePackage, Covariate.class); + } + + DISCOVERED_COVARIATES = Collections.unmodifiableSet(classFinder.getConcreteClasses().stream() + .filter(cl -> !cl.getSimpleName().isEmpty()).collect(Collectors.toSet())); // Filter out the anonymous UnitTest classes. 
+ } + + public static List getAllDiscoveredCovariateNames() { + return DISCOVERED_COVARIATES.stream().map(Class::getSimpleName).collect(Collectors.toList()); + } + + /** + * Creates a new list of BQSR covariates and initializes each covariate. + */ + public BQSRCovariateList(final RecalibrationArgumentCollection rac, final SAMFileHeader header) { + this(rac, ReadGroupCovariate.getReadGroupIDs(header)); + } + + /** + * Creates a new list of BQSR covariates and initializes each covariate. + */ + public BQSRCovariateList(final RecalibrationArgumentCollection rac, final List allReadGroups) { + readGroupCovariate = new ReadGroupCovariate(); + readGroupCovariate.initialize(rac, allReadGroups); + qualityScoreCovariate = new QualityScoreCovariate(); + qualityScoreCovariate.initialize(rac, allReadGroups); + + this.additionalCovariates = Collections.unmodifiableList(createNonrequiredCovariates(rac, allReadGroups)); + + final List allCovariatesList = new ArrayList<>(); + allCovariatesList.add(readGroupCovariate); + allCovariatesList.add(qualityScoreCovariate); + additionalCovariates.forEach(allCovariatesList::add); + this.allCovariates = Collections.unmodifiableList(allCovariatesList); + + //precompute for faster lookup (shows up on profile) + indexByClass = new LinkedHashMap<>(); + for(int i = 0; i < allCovariates.size(); i++){ + indexByClass.put(allCovariates.get(i).getClass(), i); + } + } + + public static boolean isRequiredCovariate(final String covariateName) { + return REQUIRED_COVARIATE_NAMES.contains(covariateName); + } + + public static boolean isStandardCovariate(final String covariateName) { + return STANDARD_COVARIATE_NAMES.contains(covariateName); + } + + /** + * Create the list of covariate objects containing all non-required covariates. + * i.e. Cycle, Context, and any custom covariates. 
+ */ + private List createNonrequiredCovariates(final RecalibrationArgumentCollection rac, final List allReadGroups) { + final List result = new ArrayList<>(); + + // Add the standard covariates i.e. Cycle and Context. + if ( ! rac.DO_NOT_USE_STANDARD_COVARIATES ) { + result.addAll(createStandardCovariates(rac, allReadGroups)); + } + + for ( final String customCovariates : rac.COVARIATES ) { + if ( isRequiredCovariate(customCovariates) ) { + logger.warn("Covariate " + customCovariates + " is a required covariate that is always on. Ignoring explicit request for it."); + } + else if ( ! rac.DO_NOT_USE_STANDARD_COVARIATES && isStandardCovariate(customCovariates) ) { + logger.warn("Covariate " + customCovariates + " is a standard covariate that is always on when not running with --no-standard-covariates. Ignoring explicit request for it."); + } + else { + if ( result.stream().anyMatch(cov -> cov.getClass().getSimpleName().equals(customCovariates)) ) { + throw new UserException("Covariate " + customCovariates + " was requested multiple times"); + } + + result.add(createCovariate(customCovariates, rac, allReadGroups)); + } + } + + return result; + } + + // Initialize the standard covariates (ContextCovariate, CycleCovariate) and add to the list. 
+ private List createStandardCovariates(final RecalibrationArgumentCollection rac, final List allReadGroups) { + final List result = new ArrayList<>(); + + for ( final String standardCovariateName : STANDARD_COVARIATE_NAMES ) { + result.add(createCovariate(standardCovariateName, rac, allReadGroups)); + } + + return result; + } + + /** + * Instantiates the discovered covariate class whose simple name equals covariateName (via its no-argument constructor) and initializes it. + * + * @throws GATKException if the matching class cannot be instantiated or initialized; the underlying exception is attached as the cause + * @throws UserException if no discovered covariate class has that simple name + */ + private Covariate createCovariate(final String covariateName, final RecalibrationArgumentCollection rac, final List allReadGroups) { + for ( final Class covariateClass : DISCOVERED_COVARIATES ) { + if ( covariateName.equals(covariateClass.getSimpleName()) ) { + try { + @SuppressWarnings("unchecked") + final Covariate covariate = ((Class)covariateClass).getDeclaredConstructor().newInstance(); + covariate.initialize(rac, allReadGroups); + return covariate; + } + catch ( Exception e ) { + // keep the original exception as the cause so the real failure is not lost + throw new GATKException("Error instantiating covariate class " + covariateClass.getSimpleName(), e); + } + } + } + + // list simple class names, since that is the form in which covariates are matched above + throw new UserException("No covariate with the name " + covariateName + " was found. " + + "Available covariates are: " + getAllDiscoveredCovariateNames()); + } + + /** + * Returns 2. ReadGroupCovariate and QualityScoreCovariate are always required + */ + public static int numberOfRequiredCovariates() { + return REQUIRED_COVARIATE_NAMES.size(); + } + + /** + * Returns the list of simple class names of our covariates. The returned list is unmodifiable. + * For example "CycleCovariate". + */ + public List getCovariateClassNames() { + return Collections.unmodifiableList(allCovariates.stream().map(cov -> cov.getClass().getSimpleName()).collect(Collectors.toList())); + } + + /** + * Returns the total number of covariates in this list (required plus additional). + */ + public int size(){ + return allCovariates.size(); + } + + /** + * Returns a new iterator over all covariates in this list. + * Note: the list is unmodifiable and the iterator does not support modifying the list. 
+ */ + @Override + public Iterator iterator() { + return allCovariates.iterator(); + } + + public ReadGroupCovariate getReadGroupCovariate() { + return readGroupCovariate; + } + + public QualityScoreCovariate getQualityScoreCovariate() { + return qualityScoreCovariate; + } + + /** + * returns an unmodifiable view of the nonrequired covariates stored in this list. + */ + public List getAdditionalCovariates() { + return additionalCovariates; + } + + /** + * Return a human-readable string representing the used covariates + * + * @return a non-null comma-separated string + */ + public String covariateNames() { + return String.join(",", getCovariateClassNames()); + } + + /** + * Get the covariate by the index. + * @throws IndexOutOfBoundsException if the index is out of range + * (index < 0 || index >= size()) + */ + public Covariate get(final int covIndex) { + return allCovariates.get(covIndex); + } + + /** + * Returns the index of the covariate by class name or -1 if not found. + */ + public int indexByClass(final Class clazz){ + return indexByClass.getOrDefault(clazz, -1); + } + + /** + * For each covariate compute the values for all positions in this read and + * record the values in the provided storage object. + */ + public void populatePerReadCovariateMatrix(final GATKRead read, final SAMFileHeader header, final PerReadCovariateMatrix resultsStorage, final boolean recordIndelValues) { + for (int i = 0, n = allCovariates.size(); i < n; i++) { + final Covariate cov = allCovariates.get(i); + resultsStorage.setCovariateIndex(i); + cov.recordValues(read, header, resultsStorage, recordIndelValues); + } + } + + /** + * Retrieves a covariate by the parsed name {@link Covariate#parseNameForReport()} or null + * if no covariate with that name exists in the list. 
+ */ + public Covariate getCovariateByParsedName(final String covName) { + return allCovariates.stream().filter(cov -> cov.parseNameForReport().equals(covName)).findFirst().orElse(null); + } + +} diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ContextCovariate.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ContextCovariate.java index 8f6f5affa27..82e816a7f39 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ContextCovariate.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ContextCovariate.java @@ -13,19 +13,20 @@ import org.broadinstitute.hellbender.utils.clipping.ReadClipper; import org.broadinstitute.hellbender.utils.read.GATKRead; import org.broadinstitute.hellbender.utils.recalibration.RecalibrationArgumentCollection; +import java.util.List; /** * The read bases preceding and including the base in question (as opposed to reference bases). */ -public final class ContextCovariate implements Covariate { +public final class ContextCovariate implements StandardCovariate { private static final long serialVersionUID = 1L; private static final Logger logger = LogManager.getLogger(ContextCovariate.class); - private final int mismatchesContextSize; - private final int indelsContextSize; + private int mismatchesContextSize; + private int indelsContextSize; - private final int mismatchesKeyMask; - private final int indelsKeyMask; + private int mismatchesKeyMask; + private int indelsKeyMask; private static final int LENGTH_BITS = 4; private static final int LENGTH_MASK = 15; @@ -34,11 +35,12 @@ public final class ContextCovariate implements Covariate { // the maximum context size (number of bases) permitted; we need to keep the leftmost base free so that values are // not negative and we reserve 4 more bits to represent the length of the context; it takes 2 bits to encode one base. 
private static final int MAX_DNA_CONTEXT = 13; - private final byte lowQualTail; + private byte lowQualTail; public static final int UNKNOWN_OR_ERROR_CONTEXT_CODE = -1; - public ContextCovariate(final RecalibrationArgumentCollection RAC){ + @Override + public void initialize(final RecalibrationArgumentCollection RAC, final List readGroups) { mismatchesContextSize = RAC.MISMATCHES_CONTEXT_SIZE; indelsContextSize = RAC.INDELS_CONTEXT_SIZE; logger.debug("\t\tContext sizes: base substitution model " + mismatchesContextSize + ", indel substitution model " + indelsContextSize); @@ -75,7 +77,7 @@ public void recordValues(final GATKRead read, final SAMFileHeader header, final final int readLengthAfterClipping = strandedClippedBases.length; - // this is necessary to ensure that we don't keep historical data in the ReadCovariates values + // this is necessary to ensure that we don't keep historical data in the PerReadCovariateMatrix values // since the context covariate may not span the entire set of values in read covariates // due to the clipping of the low quality bases if ( readLengthAfterClipping != originalReadLength) { diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/Covariate.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/Covariate.java index dd69bb99b93..208e0f27807 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/Covariate.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/Covariate.java @@ -2,18 +2,27 @@ import htsjdk.samtools.SAMFileHeader; import org.broadinstitute.hellbender.utils.read.GATKRead; +import org.broadinstitute.hellbender.utils.recalibration.RecalibrationArgumentCollection; import java.io.Serializable; +import java.util.List; /** * The Covariate interface. A Covariate is a feature used in the recalibration that can be picked out of the read. 
* In general most error checking and adjustments to the data are done before the call to the covariates getValue methods in order to speed up the code. * This unfortunately muddies the code, but most of these corrections can be done per read while the covariates get called per base, resulting in a big speed up. - * - * Covariates are immutable objects after construction. All state setting and parameterization must happen during the construction call. */ public interface Covariate extends Serializable { public static long serialVersionUID = 1L; + + /** + * Initialize any member variables using the command-line arguments passed to the walker + * + * @param RAC the recalibration argument collection + * @param readGroups (only used by the ReadGroup covariate --- consider refactoring) + */ + public void initialize(final RecalibrationArgumentCollection RAC, final List readGroups); + /** * Calculates covariate values for all positions in the read. * diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/CustomCovariate.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/CustomCovariate.java new file mode 100644 index 00000000000..6dc7102ce73 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/CustomCovariate.java @@ -0,0 +1,13 @@ +package org.broadinstitute.hellbender.utils.recalibration.covariates; + +/** + * An interface to classify Covariate classes into: + * + * 1. Required (ReadGroup, QualityScore) + * 2. Standard (Cycle, Context) + * 3. Custom (any covariates defined by the user e.g. RepeatLength) + * + * We call 2 and 3 together the "additional" covariates. 
+ */ +public interface CustomCovariate extends Covariate { +} diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/CycleCovariate.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/CycleCovariate.java index ed56308ec87..4d887fd3fcd 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/CycleCovariate.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/CycleCovariate.java @@ -5,17 +5,21 @@ import org.broadinstitute.hellbender.utils.read.GATKRead; import org.broadinstitute.hellbender.utils.recalibration.RecalibrationArgumentCollection; +import java.util.List; + /** * The Cycle covariate. * For ILLUMINA the cycle is simply the position in the read (counting backwards if it is a negative strand read) */ -public final class CycleCovariate implements Covariate { +public final class CycleCovariate implements StandardCovariate { private static final long serialVersionUID = 1L; - private final int MAXIMUM_CYCLE_VALUE; + private int MAXIMUM_CYCLE_VALUE; public static final int CUSHION_FOR_INDELS = 4; - public CycleCovariate(final RecalibrationArgumentCollection RAC){ + + @Override + public void initialize(final RecalibrationArgumentCollection RAC, final List readGroups) { this.MAXIMUM_CYCLE_VALUE = RAC.MAXIMUM_CYCLE_VALUE; } diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/PerReadCovariateMatrix.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/PerReadCovariateMatrix.java index 7fbf7bfb204..9f90a403619 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/PerReadCovariateMatrix.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/PerReadCovariateMatrix.java @@ -58,7 +58,7 @@ public void setCovariateIndex(final int index) { * @param mismatch the mismatch key value * @param insertion the insertion key value * 
@param deletion the deletion key value - * @param readOffset the read offset, must be >= 0 and <= the read length used to create this ReadCovariates + * @param readOffset the read offset, must be >= 0 and <= the read length used to create this PerReadCovariateMatrix */ public void addCovariate(final int mismatch, final int insertion, final int deletion, final int readOffset) { covariates[EventType.BASE_SUBSTITUTION.ordinal()][readOffset][currentCovariateIndex] = mismatch; diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/QualityScoreCovariate.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/QualityScoreCovariate.java index 5b51596d947..90d51795146 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/QualityScoreCovariate.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/QualityScoreCovariate.java @@ -6,13 +6,17 @@ import org.broadinstitute.hellbender.utils.read.GATKRead; import org.broadinstitute.hellbender.utils.read.ReadUtils; +import java.util.List; + /** * The Reported Quality Score covariate. 
*/ -public final class QualityScoreCovariate implements Covariate { +public final class QualityScoreCovariate implements RequiredCovariate { private static final long serialVersionUID = 1L; + public static final int MAX_QUAL_SCORE_KEY = QualityUtils.MAX_SAM_QUAL_SCORE; - public QualityScoreCovariate(final RecalibrationArgumentCollection RAC){ + @Override + public void initialize(final RecalibrationArgumentCollection RAC, final List readGroups) { //nothing to initialize } diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ReadGroupCovariate.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ReadGroupCovariate.java index dc46b1f6cc2..835683d1650 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ReadGroupCovariate.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ReadGroupCovariate.java @@ -5,6 +5,7 @@ import org.broadinstitute.hellbender.utils.Utils; import org.broadinstitute.hellbender.utils.read.GATKRead; import org.broadinstitute.hellbender.utils.read.ReadUtils; +import org.broadinstitute.hellbender.utils.recalibration.RecalibrationArgumentCollection; import java.util.*; import java.util.stream.Collectors; @@ -12,24 +13,23 @@ /** * The Read Group covariate. */ -public final class ReadGroupCovariate implements Covariate { +public final class ReadGroupCovariate implements RequiredCovariate { private static final long serialVersionUID = 1L; - public static final int MISSING_READ_GROUP_KEY = -1; //Note: these maps are initialized and made umodifiable at construction so the whole covariate is an immutable object once it's constructed. - /* * Stores the mapping from read group id to a number. */ - private final Map readGroupLookupTable; + private Map readGroupLookupTable; /* * Stores the reverse mapping, from number to read group id. 
*/ - private final Map readGroupReverseLookupTable; + private Map readGroupReverseLookupTable; - public ReadGroupCovariate(final List readGroups){ + @Override + public void initialize(final RecalibrationArgumentCollection RAC, final List readGroups) { final Map rgLookupTable = new LinkedHashMap<>(); final Map rgReverseLookupTable = new LinkedHashMap<>(); diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/RepeatLengthCovariate.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/RepeatLengthCovariate.java new file mode 100644 index 00000000000..3b6b218bd2e --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/RepeatLengthCovariate.java @@ -0,0 +1,197 @@ +package org.broadinstitute.hellbender.utils.recalibration.covariates; + +import htsjdk.samtools.SAMFileHeader; +import org.apache.commons.lang3.tuple.MutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.broadinstitute.hellbender.utils.BaseUtils; +import org.broadinstitute.hellbender.utils.read.GATKRead; +import org.broadinstitute.hellbender.utils.recalibration.RecalibrationArgumentCollection; +import org.broadinstitute.hellbender.utils.variant.GATKVariantContextUtils; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; + +public class RepeatLengthCovariate implements CustomCovariate { + private static final long serialVersionUID = 1L; + + protected int MAX_REPEAT_LENGTH; + protected int MAX_STR_UNIT_LENGTH; + private final HashMap repeatLookupTable = new HashMap<>(); + private final HashMap repeatReverseLookupTable = new HashMap<>(); + private int nextId = 0; + + // Initialize any member variables using the command-line arguments passed to the walkers + @Override + public void initialize(final RecalibrationArgumentCollection RAC, final List readGroups) { + MAX_STR_UNIT_LENGTH = 8; + MAX_REPEAT_LENGTH = 20; + } + + @Override + 
public void recordValues( final GATKRead read, final SAMFileHeader header, final PerReadCovariateMatrix values, final boolean recordIndelValues) { + // keep a copy of the read's bases so they can be set back on the read at the end of this method + final byte[] originalBases = Arrays.copyOf(read.getBases(), read.getBases().length); + + final boolean negativeStrand = read.isReverseStrand(); + byte[] bases = read.getBases(); + if (negativeStrand) + bases = BaseUtils.simpleReverseComplement(bases); + + // don't record reads with N's + if (!BaseUtils.isAllRegularBases(bases)) + return; + + for (int i = 0; i < bases.length; i++) { + final Pair res = findTandemRepeatUnits(bases, i); + // to merge repeat unit and repeat length to get covariate value: + final String repeatID = getCovariateValueFromUnitAndLength(res.getLeft(), res.getRight()); + final int key = keyForRepeat(repeatID); + + final int readOffset = (negativeStrand ? bases.length - i - 1 : i); + values.addCovariate(key, key, key, readOffset); + } + + // put the original bases back in + read.setBases(originalBases); + } + + /** + * Finds the tandem repeat unit around the base at the given offset; the base in question is shown in parentheses: + * [A,C,G,A,C,G,(T),A,C,G,A,C,G] + * [T,T,T,A,C,T,(A),C,T,A,A,A,A] + * + * @param readBases the bases to search for repeats + * @param offset the index into readBases of the base in question + * @return a pair of byte array (the repeat unit) and integer (the number of repetitions). + */ // should really be a static method + public Pair findTandemRepeatUnits(byte[] readBases, int offset) { + int numRepetitions = 0; + byte[] bestBWRepeatUnit = new byte[]{readBases[offset]}; + for (int strSize = 1; strSize <= MAX_STR_UNIT_LENGTH; strSize++) { + // fix repeat unit length + //edge case: if candidate tandem repeat unit falls beyond edge of read, skip + if (offset+1-strSize < 0) + break; + + // + // Count the number of backwardRepeatUnits to the left of the offset. 
+ // Example: strSize = 3, then backwardRepeatUnit = GTG + // offset + // | + // [A, G, T, G, T, (G), *, *, *, *, *] ( input string ) + // [A, G, T, G, T, (G) ] ( search repeat units in this substring ) + // + // backward repeat unit *includes* the offset base (unlike the forward repeat unit) + final byte[] backwardRepeatUnit = Arrays.copyOfRange(readBases, offset - strSize + 1, offset + 1); + String backwardRepeatUnitStr = new String(backwardRepeatUnit, StandardCharsets.UTF_8); + + // "leadingRepeats = false" indicates that we look for the consecutive repeat units in the *back end* of the substring. + numRepetitions = GATKVariantContextUtils.findNumberOfRepetitions(backwardRepeatUnit, Arrays.copyOfRange(readBases, 0, offset + 1), false); + if (numRepetitions > 1) { + bestBWRepeatUnit = Arrays.copyOf(backwardRepeatUnit, backwardRepeatUnit.length); + // By exiting early, could we miss longer, more informative STRs? + // Also, this means that TTTTT will be represented as TT repeated .... 2 times? + // It needs to be T repeated 5 times. + break; + } + } + byte[] bestRepeatUnit = bestBWRepeatUnit; + int maxRepeatLength = numRepetitions; + + if (offset < readBases.length-1) { + byte[] bestFWRepeatUnit = new byte[]{readBases[offset+1]}; + int maxFW = 0; + for (int strLength = 1; strLength <= MAX_STR_UNIT_LENGTH; strLength++) { + // fix repeat unit length + // edge case: if candidate tandem repeat unit falls beyond edge of read, skip + if (offset+strLength+1 > readBases.length) + break; + + // get forward repeat unit and # repeats (offset + 1 .... so we don't include the base at offset...) 
+ byte[] forwardRepeatUnit = Arrays.copyOfRange(readBases, offset + 1, offset + strLength + 1); + String forwardRepeatUnitStr = new String(forwardRepeatUnit, StandardCharsets.UTF_8); + maxFW = GATKVariantContextUtils.findNumberOfRepetitions(forwardRepeatUnit, Arrays.copyOfRange(readBases, offset + 1, readBases.length), true); + if (maxFW > 1) { + bestFWRepeatUnit = Arrays.copyOf(forwardRepeatUnit, forwardRepeatUnit.length); + break; + } + } + // if FW repeat unit = BW repeat unit it means we're in the middle of a tandem repeat - add FW and BW components + if (Arrays.equals(bestFWRepeatUnit, bestBWRepeatUnit)) { + maxRepeatLength = numRepetitions + maxFW; + bestRepeatUnit = bestFWRepeatUnit; // arbitrary + } + else { + // tandem repeat starting forward from current offset. + // It could be the case that best BW unit was different from FW unit, but that BW still contains FW unit. + // For example, TTCTT(C) CCC - at (C) place, best BW unit is (TTC)2, best FW unit is (C)3. + // but correct representation at that place might be (C)4. 
+ // Hence, if the FW and BW units don't match, check if BW unit can still be a part of FW unit and add + // representations to total + numRepetitions = GATKVariantContextUtils.findNumberOfRepetitions(bestFWRepeatUnit, Arrays.copyOfRange(readBases, 0, offset + 1), false); + maxRepeatLength = maxFW + numRepetitions; + bestRepeatUnit = bestFWRepeatUnit; + + } + + } + + if(maxRepeatLength > MAX_REPEAT_LENGTH) { maxRepeatLength = MAX_REPEAT_LENGTH; } + return new MutablePair<>(bestRepeatUnit, maxRepeatLength); + + } + + @Override + public String formatKey(final int key) { + return repeatReverseLookupTable.get(key); + } + + protected String getCovariateValueFromUnitAndLength(final byte[] repeatFromUnitAndLength, final int repeatLength) { + return String.format("%d", repeatLength); + } + + + @Override + public int keyFromValue(final Object value) { + return keyForRepeat((String) value); + } + + private int keyForRepeat(final String repeatID) { + if ( ! repeatLookupTable.containsKey(repeatID) ) { + repeatLookupTable.put(repeatID, nextId); + repeatReverseLookupTable.put(nextId, repeatID); + nextId++; + } + return repeatLookupTable.get(repeatID); + } + + + /** + * Splits repeat unit and num repetitions from covariate value. 
+ * For example, if value is "ATG4" it returns (ATG,4) + * @param value Covariate value + * @return Split pair + * @throws IllegalStateException if value has no number after the repeat unit, or the number is not positive + * @throws NumberFormatException if the text after the repeat unit is not an integer + */ + public static Pair getRUandNRfromCovariate(final String value) { + + // encode once, with an explicit charset, instead of re-encoding the string on every loop iteration + final byte[] valueBytes = value.getBytes(StandardCharsets.UTF_8); + int k = 0; + for ( k=0; k < value.length(); k++ ) { + if (!BaseUtils.isRegularBase(valueBytes[k])) + break; + } + // check for a missing number first: parsing the empty suffix would throw NumberFormatException + // before the intended IllegalStateException could be raised + if (k == value.length()) + throw new IllegalStateException("Covariate is not of form (Repeat Unit) + Integer"); + + final Integer nr = Integer.valueOf(value.substring(k)); // will throw NumberFormatException if format illegal + if (nr <= 0) + throw new IllegalStateException("Covariate is not of form (Repeat Unit) + Integer"); + + return new MutablePair<>(value.substring(0,k), nr); + } + + @Override + public int maximumKeyValue() { + // NOTE: the covariate value is currently the repeat length alone (see getCovariateValueFromUnitAndLength), + // capped at MAX_REPEAT_LENGTH, so the repeat unit does not contribute to the number of possible values + return (1+MAX_REPEAT_LENGTH); + } +} \ No newline at end of file diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/RequiredCovariate.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/RequiredCovariate.java new file mode 100644 index 00000000000..6634007c66f --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/RequiredCovariate.java @@ -0,0 +1,7 @@ +package org.broadinstitute.hellbender.utils.recalibration.covariates; + +/** + * See {@link CustomCovariate} for the classification of covariates. 
+ */ +public interface RequiredCovariate extends Covariate { +} diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/StandardCovariate.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/StandardCovariate.java new file mode 100644 index 00000000000..3f178e5de22 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/StandardCovariate.java @@ -0,0 +1,7 @@ +package org.broadinstitute.hellbender.utils.recalibration.covariates; + +/** + * See {@link CustomCovariate} for the classification of covariates. + */ +public interface StandardCovariate extends Covariate { +} diff --git a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/StandardCovariateList.java b/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/StandardCovariateList.java deleted file mode 100644 index 5884a8bb559..00000000000 --- a/src/main/java/org/broadinstitute/hellbender/utils/recalibration/covariates/StandardCovariateList.java +++ /dev/null @@ -1,150 +0,0 @@ -package org.broadinstitute.hellbender.utils.recalibration.covariates; - -import htsjdk.samtools.SAMFileHeader; -import org.broadinstitute.hellbender.utils.recalibration.RecalibrationArgumentCollection; -import org.broadinstitute.hellbender.utils.read.GATKRead; - -import java.io.Serializable; -import java.util.*; -import java.util.stream.Collectors; - -/** - * Represents the list of standard BQSR covariates. - * - * Note: the first two covariates ({@link ReadGroupCovariate} and {@link QualityScoreCovariate}) - * are special in the way that they are represented in the BQSR recalibration table. - * - * The remaining covariates are called "additional covariates". 
- */ -public final class StandardCovariateList implements Iterable, Serializable { - private static final long serialVersionUID = 1L; - private final ReadGroupCovariate readGroupCovariate; - private final QualityScoreCovariate qualityScoreCovariate; - private final List additionalCovariates; - private final List allCovariates; - - private final Map, Integer> indexByClass; - - public static final int READ_GROUP_COVARIATE_DEFAULT_INDEX = 0; - public static final int BASE_QUALITY_COVARIATE_DEFAULT_INDEX = 1; - public static final int CONTEXT_COVARIATE_DEFAULT_INDEX = 2; - public static final int CYCLE_COVARIATE_DEFAULT_INDEX = 3; - public static final int NUM_REQUIRED_COVARITES = 2; - - /** - * Creates a new list of standard BQSR covariates and initializes each covariate. - */ - public StandardCovariateList(final RecalibrationArgumentCollection rac, final List allReadGroups) { - readGroupCovariate = new ReadGroupCovariate(allReadGroups); - qualityScoreCovariate = new QualityScoreCovariate(rac); - final ContextCovariate contextCovariate = new ContextCovariate(rac); - final CycleCovariate cycleCovariate = new CycleCovariate(rac); - - additionalCovariates = Collections.unmodifiableList(Arrays.asList(contextCovariate, cycleCovariate)); - allCovariates = Collections.unmodifiableList(Arrays.asList(readGroupCovariate, qualityScoreCovariate, contextCovariate, cycleCovariate)); - //precompute for faster lookup (shows up on profile) - indexByClass = new LinkedHashMap<>(); - for(int i = 0; i < allCovariates.size(); i++){ - indexByClass.put(allCovariates.get(i).getClass(), i); - } - } - - /** - * Creates a new list of standard BQSR covariates and initializes each covariate. - */ - public StandardCovariateList(final RecalibrationArgumentCollection rac, final SAMFileHeader header){ - this(rac, ReadGroupCovariate.getReadGroupIDs(header)); - } - - /** - * Returns 2. 
ReadGroupCovariate and QualityScoreCovariate are special - */ - public int numberOfSpecialCovariates() { - return 2; - } - - /** - * Returns the list of simple class names of standard covariates. The returned list is unmodifiable. - * For example CycleCovariate. - */ - public List getStandardCovariateClassNames() { - return Collections.unmodifiableList(allCovariates.stream().map(cov -> cov.getClass().getSimpleName()).collect(Collectors.toList())); - } - - /** - * Returns the size of the list of standard covariates. - */ - public int size(){ - return allCovariates.size(); - } - - /** - * Returns a new iterator over all covariates in this list. - * Note: the list is unmodifiable and the iterator does not support modifying the list. - */ - @Override - public Iterator iterator() { - return allCovariates.iterator(); - } - - public ReadGroupCovariate getReadGroupCovariate() { - return readGroupCovariate; - } - - public QualityScoreCovariate getQualityScoreCovariate() { - return qualityScoreCovariate; - } - - /** - * returns an unmodifiable view of the additional covariates stored in this list. - */ - public Iterable getAdditionalCovariates() { - return additionalCovariates; - } - - /** - * Return a human-readable string representing the used covariates - * - * @return a non-null comma-separated string - */ - public String covariateNames() { - return String.join(",", getStandardCovariateClassNames()); - } - - /** - * Get the covariate by the index. - * @throws IndexOutOfBoundsException if the index is out of range - * (index < 0 || index >= size()) - */ - public Covariate get(final int covIndex) { - return allCovariates.get(covIndex); - } - - /** - * Returns the index of the covariate by class name or -1 if not found. - */ - public int indexByClass(final Class clazz){ - return indexByClass.getOrDefault(clazz, -1); - } - - /** - * For each covariate compute the values for all positions in this read and - * record the values in the provided storage object. 
- */ - public void populatePerReadCovariateMatrix(final GATKRead read, final SAMFileHeader header, final PerReadCovariateMatrix perReadCovariateMatrix, final boolean recordIndelValues) { - for (int i = 0, n = allCovariates.size(); i < n; i++) { - final Covariate cov = allCovariates.get(i); - perReadCovariateMatrix.setCovariateIndex(i); // TODO: avoid this pattern. "cov" should already know which index it belongs to. - cov.recordValues(read, header, perReadCovariateMatrix, recordIndelValues); - } - } - - /** - * Retrieves a covariate by the parsed name {@link Covariate#parseNameForReport()} or null - * if no covariate with that name exists in the list. - */ - public Covariate getCovariateByParsedName(final String covName) { - return allCovariates.stream().filter(cov -> cov.parseNameForReport().equals(covName)).findFirst().orElse(null); - } - -} diff --git a/src/main/java/org/broadinstitute/hellbender/utils/variant/GATKVariantContextUtils.java b/src/main/java/org/broadinstitute/hellbender/utils/variant/GATKVariantContextUtils.java index 2df684087e4..2123ca9029a 100644 --- a/src/main/java/org/broadinstitute/hellbender/utils/variant/GATKVariantContextUtils.java +++ b/src/main/java/org/broadinstitute/hellbender/utils/variant/GATKVariantContextUtils.java @@ -934,10 +934,10 @@ public static int findRepeatedSubstring(byte[] bases) { } /** - * Finds number of repetitions a string consists of. + * Finds the number of repetitions a string consists of. * For example, for string ATAT and repeat unit AT, number of repetitions = 2 - * @param repeatUnit Non-empty substring represented by byte array - * @param testString String to test (represented by byte array), may be empty + * @param repeatUnit The repeat unit to count e.g. AT in ATATAT + * @param targetString The sequence in which we count the number of repeatUnit. May be empty. 
* @param leadingRepeats Look for leading (at the beginning of string) or trailing (at end of string) repetitions * For example: * GATAT has 0 leading repeats of AT but 2 trailing repeats of AT @@ -946,14 +946,14 @@ public static int findRepeatedSubstring(byte[] bases) { * * @return Number of repetitions (0 if testString is not a concatenation of n repeatUnit's, including the case of empty testString) */ - public static int findNumberOfRepetitions(byte[] repeatUnit, byte[] testString, boolean leadingRepeats) { - Utils.nonNull(repeatUnit, "repeatUnit"); - Utils.nonNull(testString, "testString"); + public static int findNumberOfRepetitions(byte[] repeatUnit, byte[] targetString, boolean leadingRepeats) { + Utils.nonNull(repeatUnit, "repeatUnit cannot be null."); + Utils.nonNull(targetString, "testString cannot be null."); Utils.validateArg(repeatUnit.length != 0, "empty repeatUnit"); - if (testString.length == 0){ + if (targetString.length == 0){ return 0; } - return findNumberOfRepetitions(repeatUnit, 0, repeatUnit.length, testString, 0, testString.length, leadingRepeats); + return findNumberOfRepetitions(repeatUnit, 0, repeatUnit.length, targetString, 0, targetString.length, leadingRepeats); } /** diff --git a/src/test/java/org/broadinstitute/hellbender/tools/spark/BaseRecalibratorSparkIntegrationTest.java b/src/test/java/org/broadinstitute/hellbender/tools/spark/BaseRecalibratorSparkIntegrationTest.java index bebdd29018b..08244caf28b 100644 --- a/src/test/java/org/broadinstitute/hellbender/tools/spark/BaseRecalibratorSparkIntegrationTest.java +++ b/src/test/java/org/broadinstitute/hellbender/tools/spark/BaseRecalibratorSparkIntegrationTest.java @@ -15,6 +15,7 @@ import java.io.File; import java.io.IOException; +import java.io.Serializable; import java.util.Arrays; public final class BaseRecalibratorSparkIntegrationTest extends CommandLineProgramTest { diff --git a/src/test/java/org/broadinstitute/hellbender/tools/walkers/bqsr/ApplyBQSRIntegrationTest.java 
b/src/test/java/org/broadinstitute/hellbender/tools/walkers/bqsr/ApplyBQSRIntegrationTest.java index 8bdd83593e7..02c9093e0c0 100644 --- a/src/test/java/org/broadinstitute/hellbender/tools/walkers/bqsr/ApplyBQSRIntegrationTest.java +++ b/src/test/java/org/broadinstitute/hellbender/tools/walkers/bqsr/ApplyBQSRIntegrationTest.java @@ -17,11 +17,19 @@ import org.broadinstitute.hellbender.testutils.ArgumentsBuilder; import org.broadinstitute.hellbender.tools.ApplyBQSRArgumentCollection; import org.broadinstitute.hellbender.tools.ApplyBQSRUniqueArgumentCollection; +import org.broadinstitute.hellbender.transformers.BQSRReadTransformer; +import org.broadinstitute.hellbender.utils.collections.NestedIntegerArray; import org.broadinstitute.hellbender.utils.gcs.BucketUtils; import org.broadinstitute.hellbender.testutils.IntegrationTestSpec; import org.broadinstitute.hellbender.testutils.SamAssertionUtils; import org.broadinstitute.hellbender.utils.read.GATKRead; import org.broadinstitute.hellbender.utils.read.ReadUtils; +import org.broadinstitute.hellbender.utils.recalibration.RecalDatum; +import org.broadinstitute.hellbender.utils.recalibration.RecalibrationArgumentCollection; +import org.broadinstitute.hellbender.utils.recalibration.RecalibrationReport; +import org.broadinstitute.hellbender.utils.recalibration.RecalibrationTables; +import org.broadinstitute.hellbender.utils.recalibration.covariates.BQSRCovariateList; +import org.broadinstitute.hellbender.utils.recalibration.covariates.QualityScoreCovariate; import org.testng.Assert; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; @@ -30,6 +38,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.stream.Stream; @@ -368,4 +377,136 @@ public void testAddingPG() throws IOException { //output has a GATK ApplyBQSR in headers 
Assert.assertNotNull(SamReaderFactory.makeDefault().open(outFile).getFileHeader().getProgramRecord("GATK ApplyBQSR")); } + + @Test + public void testCustomCovariates() { + final File inputBam = new File(resourceDir + WGS_B37_CH20_1M_1M1K_BAM); + final File knownSites = new File(resourceDir + DBSNP_138_B37_CH20_1M_1M1K_VCF); + + //*** BaseRecalibrator with custom covariates ***// + final File recalTableOutput = createTempFile("recal_table", ".txt"); + final ArgumentsBuilder recalibratorArgs = getArgumentsForRecalibration(inputBam, recalTableOutput, + knownSites, Arrays.asList("RepeatLengthCovariate")); + runCommandLine(recalibratorArgs, BaseRecalibrator.class.getSimpleName()); + + //*** Apply BQSR with customCovariates ***// + final File recalibratedBam = createTempFile("customCovariateTest", ".bam"); + final ArgumentsBuilder argsForApplyBQSR = getArgumentsForApplyBQSR(inputBam, recalibratedBam, recalTableOutput); + runCommandLine(argsForApplyBQSR, ApplyBQSR.class.getSimpleName()); + + //*** Baseline BaseRecalibrator ***// + final File baselineRecalTableOutput = createTempFile("baseline_recal_table", ".txt"); + final ArgumentsBuilder recalibratorArgsBaseline = getArgumentsForRecalibration(inputBam, baselineRecalTableOutput, + knownSites, Collections.emptyList()); + runCommandLine(recalibratorArgsBaseline, BaseRecalibrator.class.getSimpleName()); + + //*** Baseline ApplyBQSR ***// + final File baselineRecalibratedBam = createTempFile("baseline_customCovariateTest", ".bam"); + final ArgumentsBuilder baselineArgsForApplyBQSR = getArgumentsForApplyBQSR(inputBam, baselineRecalibratedBam, baselineRecalTableOutput); + runCommandLine(baselineArgsForApplyBQSR, ApplyBQSR.class.getSimpleName()); + + //*** Validation ***// + // Check that the output for the other covariates is the same as when the custom covariates are not used. 
+ final RecalibrationReport evalRecalReport = new RecalibrationReport(recalTableOutput); + final RecalibrationReport baselineRecalReport = new RecalibrationReport(baselineRecalTableOutput); + + // worry about porting this to recalibrator test class later + final RecalibrationTables evalTables = evalRecalReport.getRecalibrationTables(); + final NestedIntegerArray evalReadGroupTable = evalTables.getReadGroupTable(); + final NestedIntegerArray evalQualityScoreTable = evalTables.getQualityScoreTable(); + final NestedIntegerArray evalContextTable = evalTables.getTable(BQSRCovariateList.CONTEXT_COVARIATE_DEFAULT_INDEX); + final NestedIntegerArray evalCycleTable = evalTables.getTable(BQSRCovariateList.CYCLE_COVARIATE_DEFAULT_INDEX); + final NestedIntegerArray evalRepeatLengthTable = evalTables.getTable(BQSRCovariateList.CYCLE_COVARIATE_DEFAULT_INDEX + 1); + + + final RecalibrationTables baselineTables = baselineRecalReport.getRecalibrationTables(); + final NestedIntegerArray baselineReadGroupTable = baselineTables.getReadGroupTable(); + final NestedIntegerArray baselineQualityScoreTable = baselineTables.getQualityScoreTable(); + // tsato: The dimension for context is something like 1000, this has got to be so wasteful. + final NestedIntegerArray baselineContextTable = baselineTables.getTable(BQSRCovariateList.CONTEXT_COVARIATE_DEFAULT_INDEX); + final NestedIntegerArray baselineCycleTable = baselineTables.getTable(BQSRCovariateList.CYCLE_COVARIATE_DEFAULT_INDEX); + int d = 3; // I forget --- are cycle and context tables separate? But printed as the same in the file? + + // Read groups shouldn't have changed... 
+ int numReadGroups = new ReadsPathDataSource(inputBam.toPath()).getHeader().getReadGroups().size(); + long readGroupCount = 0; + for (final NestedIntegerArray.Leaf leaf : baselineReadGroupTable.getAllLeaves()) { + int[] keys = leaf.keys; + RecalDatum baselineReadGroupDatum = leaf.value; + RecalDatum evalReadGroupDatum = evalReadGroupTable.get2Keys(keys[0], BQSRReadTransformer.BASE_SUBSTITUTION_INDEX); // Ah, this is probably the right thing to do. 0 indexes the event --- get snp + Assert.assertEquals(evalReadGroupDatum, baselineReadGroupDatum); + readGroupCount += baselineReadGroupDatum.getNumObservations(); + } + + long reportedQualityCount = 0; + for (final NestedIntegerArray.Leaf leaf : baselineQualityScoreTable.getAllLeaves()) { + int[] keys = leaf.keys; + RecalDatum baselineQualityScoreDatum = leaf.value; + RecalDatum evalQualityScoreDatum = evalQualityScoreTable.get3Keys(keys[0], keys[1], BQSRReadTransformer.BASE_SUBSTITUTION_INDEX); // Ah, this is probably the right thing to do. 
0 indexes the event --- get snp + Assert.assertEquals(evalQualityScoreDatum, baselineQualityScoreDatum); + reportedQualityCount += baselineQualityScoreDatum.getNumObservations(); + } + + + + // Check that the context covariate did not change after adding a custom covariate + long contextCount = 0L; + for (final NestedIntegerArray.Leaf leaf : baselineContextTable.getAllLeaves()) { + int[] keys = leaf.keys; + final RecalDatum baselineContextDatum = leaf.value; + final RecalDatum evalContextDatum = evalContextTable.get4Keys(keys[0], keys[1], keys[2], BQSRReadTransformer.BASE_SUBSTITUTION_INDEX); + Assert.assertEquals(evalContextDatum, baselineContextDatum); + contextCount += baselineContextDatum.getNumObservations(); + } + + // Ditto cycle covariate + long cycleCount = 0L; + for (final NestedIntegerArray.Leaf leaf : baselineCycleTable.getAllLeaves()) { + int[] keys = leaf.keys; + final RecalDatum baselineCycleDatum = leaf.value; + final RecalDatum evalCycleDatum = evalCycleTable.get4Keys(keys[0], keys[1], keys[2], BQSRReadTransformer.BASE_SUBSTITUTION_INDEX); + Assert.assertEquals(evalCycleDatum, baselineCycleDatum); + cycleCount += baselineCycleDatum.getNumObservations(); + } + + // Some basic checks on the repeat length covariate + long repeatLengthCount = 0; + for (final NestedIntegerArray.Leaf leaf : evalRepeatLengthTable.getAllLeaves()) { + RecalDatum datum = leaf.value; + repeatLengthCount += datum.getNumObservations(); + } + + // TODO: contextCount isn't the same as the rest. Investigate. 
+ Assert.assertEquals(repeatLengthCount, readGroupCount); + } + + private ArgumentsBuilder getArgumentsForRecalibration(final File inputBam, final File outputBam, + final File knownSites, final List customCovariates){ + final ArgumentsBuilder result = new ArgumentsBuilder(); + result.addInput(inputBam); + result.addOutput(outputBam); + result.add(BaseRecalibrator.KNOWN_SITES_ARG_FULL_NAME, knownSites); + result.addReference(GCS_b37_CHR20_21_REFERENCE); + result.add(ApplyBQSRArgumentCollection.USE_ORIGINAL_QUALITIES_LONG_NAME, true); + for (String customCovariate : customCovariates){ + result.add(RecalibrationArgumentCollection.COVARIATES_LONG_NAME, customCovariate); + } + + return result; + } + + private ArgumentsBuilder getArgumentsForApplyBQSR(final File inputBam, final File outputBam, final File recalTable){ + final ArgumentsBuilder result = new ArgumentsBuilder(); + result.addInput(inputBam); + result.add(StandardArgumentDefinitions.BQSR_TABLE_LONG_NAME, recalTable); + result.addOutput(outputBam); + result.add(ApplyBQSRArgumentCollection.ALLOW_MISSING_READ_GROUPS_LONG_NAME, true); + // per the warp pipeline + result.add(ApplyBQSRUniqueArgumentCollection.STATIC_QUANTIZED_QUALS_LONG_NAME, 10); + result.add(ApplyBQSRUniqueArgumentCollection.STATIC_QUANTIZED_QUALS_LONG_NAME, 20); + result.add(ApplyBQSRUniqueArgumentCollection.STATIC_QUANTIZED_QUALS_LONG_NAME, 30); + result.add(ApplyBQSRUniqueArgumentCollection.STATIC_QUANTIZED_QUALS_LONG_NAME, 40); + result.add(ApplyBQSRArgumentCollection.USE_ORIGINAL_QUALITIES_LONG_NAME, true); + return result; + } } diff --git a/src/test/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationReportUnitTest.java b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationReportUnitTest.java index 6d683828c18..0e0dcdf0687 100644 --- a/src/test/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationReportUnitTest.java +++ 
b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationReportUnitTest.java @@ -142,10 +142,19 @@ private static RecalDatum createRandomRecalDatum(int maxObservations, int maxErr return new RecalDatum((long)nObservations, (double)nErrors, (byte)qual); } - @Test(expectedExceptions = UserException.class) - public void testUnsupportedCovariates(){ - File file = new File(toolsTestDir + "unsupported-covariates.table.gz"); - new RecalibrationReport(file); + @Test + public void testNonStandardCovariates(){ + File file = new File(toolsTestDir + "nonstandard-covariates.table.gz"); + final RecalibrationReport report = new RecalibrationReport(file); + final BQSRCovariateList covariates = report.getCovariates(); + + Assert.assertEquals(covariates.size(), 5, "There should be 5 covariates in the report"); + final List covariateNames = covariates.getCovariateClassNames(); + Assert.assertTrue(covariateNames.contains("ReadGroupCovariate"), "ReadGroupCovariate should be present but wasn't"); + Assert.assertTrue(covariateNames.contains("QualityScoreCovariate"), "QualityScoreCovariate should be present but wasn't"); + Assert.assertTrue(covariateNames.contains("ContextCovariate"), "ContextCovariate should be present but wasn't"); + Assert.assertTrue(covariateNames.contains("CycleCovariate"), "CycleCovariate should be present but wasn't"); + Assert.assertTrue(covariateNames.contains("RepeatLengthCovariate"), "RepeatLengthCovariate should be present but wasn't"); } @Test @@ -165,7 +174,7 @@ public void testOutput() { quantizationInfo.noQuantization(); final String readGroupID = "id"; - final StandardCovariateList covariateList = new StandardCovariateList(RAC, Collections.singletonList(readGroupID)); + final BQSRCovariateList covariateList = new BQSRCovariateList(RAC, Collections.singletonList(readGroupID)); final SAMReadGroupRecord rg = new SAMReadGroupRecord(readGroupID); rg.setPlatform("illumina"); diff --git 
a/src/test/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationTablesUnitTest.java b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationTablesUnitTest.java index fcdb8245121..b4373055771 100644 --- a/src/test/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationTablesUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/RecalibrationTablesUnitTest.java @@ -2,7 +2,7 @@ import org.broadinstitute.hellbender.utils.collections.NestedIntegerArray; import org.broadinstitute.hellbender.utils.recalibration.covariates.Covariate; -import org.broadinstitute.hellbender.utils.recalibration.covariates.StandardCovariateList; +import org.broadinstitute.hellbender.utils.recalibration.covariates.BQSRCovariateList; import org.broadinstitute.hellbender.GATKBaseTest; import org.testng.Assert; import org.testng.annotations.BeforeMethod; @@ -15,7 +15,7 @@ public final class RecalibrationTablesUnitTest extends GATKBaseTest { private RecalibrationTables tables; - private StandardCovariateList covariates; + private BQSRCovariateList covariates; private int numReadGroups = 6; final byte qualByte = 1; final List combineStates = Arrays.asList(0, 1, 2); @@ -23,7 +23,7 @@ public final class RecalibrationTablesUnitTest extends GATKBaseTest { @BeforeMethod private void makeTables() { final List readGroups= IntStream.range(1, numReadGroups).mapToObj(i -> "readgroup"+i).collect(Collectors.toList()); - covariates = new StandardCovariateList(new RecalibrationArgumentCollection(), readGroups); + covariates = new BQSRCovariateList(new RecalibrationArgumentCollection(), readGroups); tables = new RecalibrationTables(covariates, numReadGroups); fillTable(tables); } diff --git a/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/StandardCovariateListUnitTest.java b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/BQSRCovariateListUnitTest.java similarity index 74% rename 
from src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/StandardCovariateListUnitTest.java rename to src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/BQSRCovariateListUnitTest.java index bf11718082e..1965047ae76 100644 --- a/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/StandardCovariateListUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/BQSRCovariateListUnitTest.java @@ -10,45 +10,43 @@ import java.util.Collections; import java.util.List; -import java.util.stream.Collectors; -public final class StandardCovariateListUnitTest extends GATKBaseTest { +public final class BQSRCovariateListUnitTest extends GATKBaseTest { - public StandardCovariateList makeCovariateList() { - return new StandardCovariateList(new RecalibrationArgumentCollection(), Collections.singletonList("readGroup")); + public BQSRCovariateList makeCovariateList() { + return new BQSRCovariateList(new RecalibrationArgumentCollection(), Collections.singletonList("readGroup")); } @Test public void testSize() { - StandardCovariateList scl = makeCovariateList(); + BQSRCovariateList scl = makeCovariateList(); Assert.assertEquals(scl.size(), 4); } @Test public void testCovariateNames() { - StandardCovariateList scl = makeCovariateList(); + BQSRCovariateList scl = makeCovariateList(); Assert.assertEquals(scl.covariateNames(), "ReadGroupCovariate,QualityScoreCovariate,ContextCovariate,CycleCovariate"); } @Test public void testIterator() { - StandardCovariateList scl = makeCovariateList(); + BQSRCovariateList scl = makeCovariateList(); Assert.assertEquals(Utils.stream(scl).count(), 4); } @Test public void testGetCovariates() { - StandardCovariateList scl = makeCovariateList(); + BQSRCovariateList scl = makeCovariateList(); Assert.assertEquals(scl.getReadGroupCovariate().parseNameForReport(), "ReadGroup"); Assert.assertEquals(scl.getQualityScoreCovariate().parseNameForReport(), 
"QualityScore"); - final List additionalCovars = Utils.stream(scl.getAdditionalCovariates()).collect(Collectors.toList()); - Assert.assertEquals(additionalCovars.get(0).parseNameForReport(), "Context"); - Assert.assertEquals(additionalCovars.get(1).parseNameForReport(), "Cycle"); + Assert.assertEquals(scl.getAdditionalCovariates().get(0).parseNameForReport(), "Context"); + Assert.assertEquals(scl.getAdditionalCovariates().get(1).parseNameForReport(), "Cycle"); } @Test public void testGetCovariatesByIndex() { - StandardCovariateList scl = makeCovariateList(); + BQSRCovariateList scl = makeCovariateList(); Assert.assertEquals(scl.get(0).parseNameForReport(), "ReadGroup"); Assert.assertEquals(scl.get(1).parseNameForReport(), "QualityScore"); Assert.assertEquals(scl.get(2).parseNameForReport(), "Context"); @@ -57,13 +55,13 @@ public void testGetCovariatesByIndex() { @Test(expectedExceptions = IndexOutOfBoundsException.class) public void testGetCovariatesByIndexInvalid() { - StandardCovariateList scl = makeCovariateList(); + BQSRCovariateList scl = makeCovariateList(); scl.get(4); } @Test public void testGetCovariatesByIndexClass() { - StandardCovariateList scl = makeCovariateList(); + BQSRCovariateList scl = makeCovariateList(); Assert.assertEquals(scl.indexByClass(ReadGroupCovariate.class), 0); Assert.assertEquals(scl.indexByClass(QualityScoreCovariate.class), 1); Assert.assertEquals(scl.indexByClass(ContextCovariate.class), 2); @@ -73,6 +71,11 @@ public void testGetCovariatesByIndexClass() { Assert.assertEquals(scl.indexByClass(new Covariate() { private static final long serialVersionUID = 1L; + @Override + public void initialize( RecalibrationArgumentCollection RAC, List readGroups ) { + + } + @Override public void recordValues(GATKRead read, SAMFileHeader header, PerReadCovariateMatrix values, boolean recordIndels) { @@ -97,7 +100,7 @@ public int maximumKeyValue() { @Test public void testGetCovariatesByParsedName() { - StandardCovariateList scl = 
makeCovariateList(); + BQSRCovariateList scl = makeCovariateList(); final String[] parsedNames = {"ReadGroup", "QualityScore", "Context", "Cycle"}; for (String parsedName : parsedNames) { Assert.assertEquals(scl.getCovariateByParsedName(parsedName).parseNameForReport(), parsedName); @@ -107,7 +110,7 @@ public void testGetCovariatesByParsedName() { @Test public void testCovariateInitialize() { - StandardCovariateList scl = makeCovariateList(); + BQSRCovariateList scl = makeCovariateList(); //this just tests non blowing up. } } diff --git a/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ContextCovariateUnitTest.java b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ContextCovariateUnitTest.java index d7dcdc0de7f..df2c1657f14 100644 --- a/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ContextCovariateUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ContextCovariateUnitTest.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.Collections; import java.util.Random; public final class ContextCovariateUnitTest extends GATKBaseTest { @@ -27,9 +28,9 @@ public final class ContextCovariateUnitTest extends GATKBaseTest { @BeforeClass public void init() { RAC = new RecalibrationArgumentCollection(); - covariate = new ContextCovariate(RAC); Utils.resetRandomGenerator(); - + covariate = new ContextCovariate(); + covariate.initialize(RAC, Collections.emptyList()); } @Test diff --git a/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/CycleCovariateUnitTest.java b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/CycleCovariateUnitTest.java index 53b6af2abfd..0950240529c 100644 --- a/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/CycleCovariateUnitTest.java +++ 
b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/CycleCovariateUnitTest.java @@ -13,6 +13,8 @@ import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import java.util.Arrays; + import static java.lang.Math.abs; public final class CycleCovariateUnitTest extends GATKBaseTest { @@ -23,10 +25,12 @@ public final class CycleCovariateUnitTest extends GATKBaseTest { @BeforeClass public void init() { - RAC = new RecalibrationArgumentCollection(); - covariate = new CycleCovariate(RAC); illuminaReadGroup = new SAMReadGroupRecord("MY.ID"); illuminaReadGroup.setPlatform("illumina"); + + RAC = new RecalibrationArgumentCollection(); + covariate = new CycleCovariate(); + covariate.initialize(RAC, Arrays.asList(illuminaReadGroup.getId())); } @Test diff --git a/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/PerReadCovariateMatrixUnitTest.java b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/PerReadCovariateMatrixUnitTest.java index 4223813d692..829229a5a0c 100644 --- a/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/PerReadCovariateMatrixUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/PerReadCovariateMatrixUnitTest.java @@ -22,12 +22,16 @@ public void testCovariateGeneration() { final RecalibrationArgumentCollection RAC = new RecalibrationArgumentCollection(); final String[] readGroups = {"RG1", "RG2", "RGbla"}; - ReadGroupCovariate rgCov = new ReadGroupCovariate(Arrays.asList(readGroups)); - QualityScoreCovariate qsCov = new QualityScoreCovariate(RAC); - ContextCovariate coCov = new ContextCovariate(RAC); - CycleCovariate cyCov = new CycleCovariate(RAC); - - StandardCovariateList covariates = new StandardCovariateList(RAC, Arrays.asList(readGroups)); + ReadGroupCovariate rgCov = new ReadGroupCovariate(); + rgCov.initialize(RAC, Arrays.asList(readGroups)); + QualityScoreCovariate qsCov = new 
QualityScoreCovariate(); + qsCov.initialize(RAC, Arrays.asList(readGroups)); + ContextCovariate coCov = new ContextCovariate(); + coCov.initialize(RAC, Arrays.asList(readGroups)); + CycleCovariate cyCov = new CycleCovariate(); + cyCov.initialize(RAC, Arrays.asList(readGroups)); + + BQSRCovariateList covariates = new BQSRCovariateList(RAC, Arrays.asList(readGroups)); final int NUM_READS = 100; final Random rnd = Utils.getRandomGenerator(); diff --git a/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ReadGroupCovariateUnitTest.java b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ReadGroupCovariateUnitTest.java index dd933dedc25..acb58bb09bb 100644 --- a/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ReadGroupCovariateUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/ReadGroupCovariateUnitTest.java @@ -2,8 +2,10 @@ import htsjdk.samtools.SAMFileHeader; import htsjdk.samtools.SAMReadGroupRecord; +import org.broadinstitute.hellbender.GATKBaseTest; import org.broadinstitute.hellbender.utils.read.ArtificialReadUtils; import org.broadinstitute.hellbender.utils.read.GATKRead; +import org.broadinstitute.hellbender.utils.recalibration.RecalibrationArgumentCollection; import org.testng.Assert; import org.testng.annotations.Test; @@ -11,13 +13,14 @@ import java.util.List; import java.util.stream.Collectors; -public final class ReadGroupCovariateUnitTest { +public final class ReadGroupCovariateUnitTest extends GATKBaseTest { @Test public void testSingleRecord() { final String id = "MY.ID"; final String expected = "SAMPLE.1"; - final ReadGroupCovariate covariate = new ReadGroupCovariate(Arrays.asList(expected)); + final ReadGroupCovariate covariate = new ReadGroupCovariate(); + covariate.initialize(new RecalibrationArgumentCollection(), Arrays.asList(expected)); SAMReadGroupRecord rg = new SAMReadGroupRecord(id); rg.setPlatformUnit(expected); 
runTest(rg, expected, covariate); @@ -27,7 +30,8 @@ public void testSingleRecord() { public void testMaxValue() { final String id = "MY.ID"; final String expected = "SAMPLE.1"; - final ReadGroupCovariate covariate = new ReadGroupCovariate(Arrays.asList(expected)); + final ReadGroupCovariate covariate = new ReadGroupCovariate(); + covariate.initialize(new RecalibrationArgumentCollection(), Arrays.asList(expected)); SAMReadGroupRecord rg = new SAMReadGroupRecord(id); rg.setPlatformUnit(expected); Assert.assertEquals(covariate.maximumKeyValue(), 0);//there's just 1 read group, so 0 is the max value @@ -37,7 +41,8 @@ public void testMaxValue() { public void testReadGroupNames() { final String id = "MY.ID"; final String expected = "SAMPLE.1"; - final ReadGroupCovariate covariate = new ReadGroupCovariate(Arrays.asList(expected)); + final ReadGroupCovariate covariate = new ReadGroupCovariate(); + covariate.initialize(new RecalibrationArgumentCollection(), Arrays.asList(expected)); final SAMFileHeader headerWithGroups = ArtificialReadUtils.createArtificialSamHeaderWithGroups(1, 0, 100, 2); final List rgs = Arrays.asList("rg1", "rg2"); Assert.assertEquals(ReadGroupCovariate.getReadGroupIDs(headerWithGroups), headerWithGroups.getReadGroups().stream().map(rg -> ReadGroupCovariate.getReadGroupIdentifier(rg)).collect(Collectors.toList())); @@ -47,13 +52,16 @@ public void testReadGroupNames() { public void testMissingKey() { final String id = "MY.ID"; final String expected = "SAMPLE.1"; - final ReadGroupCovariate covariate = new ReadGroupCovariate(Arrays.asList(expected)); + final ReadGroupCovariate covariate = new ReadGroupCovariate(); + covariate.initialize(new RecalibrationArgumentCollection(), Arrays.asList(expected)); final String s = covariate.formatKey(1); } @Test() public void testMissingReadGroup() { - final ReadGroupCovariate covariate = new ReadGroupCovariate(Arrays.asList("SAMPLE.1")); + final String expected = "SAMPLE.1"; + final ReadGroupCovariate covariate = new 
ReadGroupCovariate(); + covariate.initialize(new RecalibrationArgumentCollection(), Arrays.asList(expected)); final int badKey = covariate.keyFromValue("bad_read_name"); Assert.assertEquals(badKey, ReadGroupCovariate.MISSING_READ_GROUP_KEY); } @@ -61,7 +69,8 @@ public void testMissingReadGroup() { @Test public void testMissingPlatformUnit() { final String expected = "MY.7"; - final ReadGroupCovariate covariate = new ReadGroupCovariate(Arrays.asList(expected)); + final ReadGroupCovariate covariate = new ReadGroupCovariate(); + covariate.initialize(new RecalibrationArgumentCollection(), Arrays.asList(expected)); SAMReadGroupRecord rg = new SAMReadGroupRecord(expected); runTest(rg, expected, covariate); } diff --git a/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/RepeatLengthCovariateUnitTest.java b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/RepeatLengthCovariateUnitTest.java new file mode 100644 index 00000000000..05dd2ac577b --- /dev/null +++ b/src/test/java/org/broadinstitute/hellbender/utils/recalibration/covariates/RepeatLengthCovariateUnitTest.java @@ -0,0 +1,66 @@ +package org.broadinstitute.hellbender.utils.recalibration.covariates; + +import org.apache.commons.lang3.tuple.Pair; +import org.broadinstitute.hellbender.GATKBaseTest; +import org.broadinstitute.hellbender.utils.recalibration.RecalibrationArgumentCollection; +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class RepeatLengthCovariateUnitTest extends GATKBaseTest { + + @DataProvider + public static Object[][] FindTandemRepeatUnitsTestData() { + return new Object[][]{ + // Parentheses added for readability (removed in the test). 
+ // Current implementation: the repeat unit is determined such that the character at the offset + // will be the *last* character in the repeat unit in the case of two letter repeats. + // In other words the repeat unit is determined up to cyclic permutations. + + // { readBaseString, offset, expectedRepeatUnit, expectedRepeatLength } + {"A_CA_C(A)_CA_CA_C", 4, "CA", 4}, + {"(A)_CA_CA_CA_CA_C", 0, "CA", 4}, + {"(A)_CTA_CTA_CTA_CTA_CT", 0, "CTA", 4}, + {"A_CTA_CT(A)_CTA_CTA_CT", 6, "CTA", 4}, + {"AC_TAC_TA(C)_TAC_TAC_T", 7, "TAC", 4}, + {"ACT_ACT_AC(T)_ACT_ACT", 8, "ACT", 5}, + {"A_CTA_GCT_AC(T)_ACT_ACT", 9, "ACT", 3}, + {"(A)AAAAAAAAA", 0, "A", 10}, + {"AAAA(A)AAAAA", 4, "A", 10}, + {"A_C_(A)AAAAA", 2, "A", 6}, + {"AAAAAAAAA(A)", 9, "A", 10}, + {"A(C)_TACT_TACT_ACTA_CTAC", 1, "TACT", 2}, + // this (ACT)*2 is missed. When max repeat count is 1, the base at offset+1 is chosen as the repeat unit. + {"AC(T)_ACT_TAC_TAC_TAC_TAC", 2, "A", 1}, + {"TTTTTT(C)_AAA_AAA_AAA", 6, "A", 9}, // Note that the repeat unit is "A", not "AAA" + {"TTT(T)TTC_AAA_AAA_AAA", 3, "T", 6}, + {"A_CTA_CT(G)_ACT_TAC_TAC_TAC_TAC", 6, "A", 1}, // "TAC" downstream is not recognized. + {"CT_ACT_ACT_A(C)T_ACT_ACT", 9, "TAC", 5}, // "ACT" is not detected, but "TAC" is, which is the same repeat up to cyclic permutation. 
+ {"ATTT_(A)TTT_ATTT_CTTT", 4, "T", 3} + }; + } + + @Test(dataProvider="FindTandemRepeatUnitsTestData") + public void testFindTandemRepeatUnits(final String readBaseString, final int offset, + final String expectedRepeatUnit, final int expectedRepeatLength){ + // remove the characters that were put in to make the input string easier to read + final byte[] readBases = readBaseString.replaceAll("[()_]", "").getBytes(); + RepeatLengthCovariate repeatLengthCovariate = new RepeatLengthCovariate(); + repeatLengthCovariate.initialize(new RecalibrationArgumentCollection(), Arrays.asList("yo")); + + Pair ans = repeatLengthCovariate.findTandemRepeatUnits(readBases, offset); + byte[] repeatUnit = ans.getLeft(); + int repeatLength = ans.getRight(); + + // for debugging + String repeatUnitStr = new String(repeatUnit, StandardCharsets.UTF_8); + + Assert.assertEquals(repeatUnitStr, expectedRepeatUnit); + Assert.assertEquals(repeatLength, expectedRepeatLength); + } +} diff --git a/src/test/resources/org/broadinstitute/hellbender/tools/unsupported-covariates.table.gz b/src/test/resources/org/broadinstitute/hellbender/tools/nonstandard-covariates.table.gz similarity index 100% rename from src/test/resources/org/broadinstitute/hellbender/tools/unsupported-covariates.table.gz rename to src/test/resources/org/broadinstitute/hellbender/tools/nonstandard-covariates.table.gz