Skip to content

Commit 7244cf9

Browse files
committed
Removes expit normalization from OnnxDetectionPredictor
While a custom post-processor can recover the original values by applying logit to the normalized output, it makes more sense to leave normalization to the post-processor. This way the post-processor receives the raw values from the model, which are easier to reason about, instead of being forced to run `logit` on everything to get them back. Additionally, this change should result in fewer expit calls, as expit is now only needed when computing the box prediction scores. THIS IS A BACKWARDS COMPATIBILITY BREAKAGE! If a user has written a custom detection post-processor, it will start receiving different values in the buffer than before: raw model outputs instead of expit-normalized ones.
1 parent 550178f commit 7244cf9

File tree

6 files changed

+57
-14
lines changed

6 files changed

+57
-14
lines changed

pdfocr-onnxtr/src/main/java/com/itextpdf/pdfocr/onnxtr/detection/IDetectionPostProcessor.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ This file is part of the iText (R) project.
2929
import java.util.List;
3030

3131
/**
32-
* Interface for post-processors, which convert normalized, but still raw output
33-
* of an ML model and returns rotated boxes with the detected objects.
32+
* Interface for post-processors, which convert the raw output of an ML model
33+
* into rotated boxes with the detected objects.
3434
*
3535
* <p>
3636
* Output point arrays should represent a rectangle and contain 4 points. Order
@@ -48,7 +48,7 @@ public interface IDetectionPostProcessor {
4848
* detected objects.
4949
*
5050
* @param input input image, which was used to produce the inputs to the ML model
51-
* @param output normalized output of the ML model
51+
* @param output output of the ML model
5252
*
5353
* @return a list of detected objects. See interface documentation for more information
5454
*/

pdfocr-onnxtr/src/main/java/com/itextpdf/pdfocr/onnxtr/detection/OnnxDetectionPostProcessor.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,13 @@ public class OnnxDetectionPostProcessor implements IDetectionPostProcessor {
5656
* Threshold value used, when binarizing a monochromatic image. If pixel
5757
* value is greater or equal to the threshold, it is mapped to 1, otherwise
5858
* it is mapped to 0.
59+
*
60+
* <p>
61+
* This is not the original binarization threshold, provided by the user,
62+
* but the value of logit on it. This is so we don't need to run expit
63+
* over the model output for binarization.
5964
*/
60-
private final float binarizationThreshold;
65+
private final float binarizationThresholdLogit;
6166
/**
6267
* Score threshold for a detected box. If score is lower than this value,
6368
* the box gets discarded.
@@ -73,7 +78,7 @@ public class OnnxDetectionPostProcessor implements IDetectionPostProcessor {
7378
* the box gets discarded
7479
*/
7580
public OnnxDetectionPostProcessor(float binarizationThreshold, float scoreThreshold) {
76-
this.binarizationThreshold = binarizationThreshold;
81+
this.binarizationThresholdLogit = (float) MathUtil.logit(MathUtil.clamp(binarizationThreshold, 0., 1.));
7782
this.scoreThreshold = scoreThreshold;
7883
}
7984

@@ -96,7 +101,7 @@ public List<Point[]> process(BufferedImage input, FloatBufferMdArray output) {
96101
// or use a smaller mask with only the contour. Though based on profiling, it doesn't look
97102
// like it is that bad, when it is only once per input image.
98103
try (final Mat scoreMask = new Mat(height, width, CvType.CV_8U, new Scalar(0));
99-
final MatVector contours = findTextContours(output, binarizationThreshold)) {
104+
final MatVector contours = findTextContours(output, binarizationThresholdLogit)) {
100105
final long contourCount = contours.size();
101106
for (long contourIdx = 0; contourIdx < contourCount; ++contourIdx) {
102107
try (final Mat contour = contours.get(contourIdx);
@@ -159,7 +164,8 @@ private static float getPredictionScore(
159164
/*
160165
* Algorithm here is pretty simple. We go over all the points, painted
161166
* by the contour shape, and calculate the mean prediction score
162-
* value over the original normalized output array.
167+
* value over the original output array, values of which we normalize
168+
* via expit.
163169
*/
164170
final FloatBufferMdArray hwMdArray = predictions.getSubArray(0);
165171
final int height = hwMdArray.getDimension(0);
@@ -180,7 +186,7 @@ private static float getPredictionScore(
180186
if (maskIndexer.get(y, x) != 1) {
181187
continue;
182188
}
183-
final float prediction = predictionsRow.getScalar(x);
189+
final float prediction = MathUtil.expit(predictionsRow.getScalar(x));
184190
if (prediction > 0) {
185191
sum += prediction;
186192
++nonZeroCount;

pdfocr-onnxtr/src/main/java/com/itextpdf/pdfocr/onnxtr/detection/OnnxDetectionPredictor.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,12 +188,6 @@ protected FloatBufferMdArray toInputBuffer(List<BufferedImage> batch) {
188188
@Override
189189
protected List<List<Point[]>> fromOutputBuffer(List<BufferedImage> inputBatch, FloatBufferMdArray outputBatch) {
190190
final IDetectionPostProcessor postProcessor = properties.getPostProcessor();
191-
// Normalizing pixel values via a sigmoid expit function
192-
final float[] outputBuffer = outputBatch.getData().array();
193-
int offset = outputBatch.getArrayOffset();
194-
for (int i = offset; i < offset + outputBatch.getArraySize(); ++i) {
195-
outputBuffer[i] = MathUtil.expit(outputBuffer[i]);
196-
}
197191
final List<List<Point[]>> batchTextBoxes = new ArrayList<>(inputBatch.size());
198192
for (int i = 0; i < inputBatch.size(); ++i) {
199193
final BufferedImage image = inputBatch.get(i);

pdfocr-onnxtr/src/main/java/com/itextpdf/pdfocr/onnxtr/exceptions/PdfOcrOnnxTrExceptionMessageConstant.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ public final class PdfOcrOnnxTrExceptionMessageConstant {
5757
public static final String UNEXPECTED_SHAPE_SIZE = "Shape should be a {0}-element array (BCHW).";
5858
public static final String UNEXPECTED_STD_CHANNEL_COUNT = "Std should be a {0}-element array.";
5959
public static final String VALUES_SHOULD_BE_A_NON_EMPTY_ARRAY = "Values should be a non-empty array.";
60+
public static final String X_SHOULD_BE_IN_0_1_RANGE = "X should be in [0; 1] range.";
6061

6162
private PdfOcrOnnxTrExceptionMessageConstant() {
6263
// Private constructor will prevent the instantiation of this class directly.

pdfocr-onnxtr/src/main/java/com/itextpdf/pdfocr/onnxtr/util/MathUtil.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,39 @@ public static float expit(float x) {
114114
return (float) (1 / (1 + Math.exp(-x)));
115115
}
116116

117+
/**
118+
* Computes the logit function, which is the inverse of expit, for the given input.
119+
*
120+
* @param x the input value
121+
*
122+
* @return the logit of the input value
123+
*/
124+
public static double logit(double x) {
125+
if (0 < x && x < 1) {
126+
return Math.log(x / (1.0 - x));
127+
}
128+
if (x == 0F) {
129+
return Float.NEGATIVE_INFINITY;
130+
}
131+
if (x == 1F) {
132+
return Float.POSITIVE_INFINITY;
133+
}
134+
throw new IllegalArgumentException(
135+
PdfOcrOnnxTrExceptionMessageConstant.X_SHOULD_BE_IN_0_1_RANGE
136+
);
137+
}
138+
139+
/**
140+
* Computes the logit function, which is the inverse of expit, for the given input.
141+
*
142+
* @param x the input value
143+
*
144+
* @return the logit of the input value
145+
*/
146+
public static float logit(float x) {
147+
return (float) logit((double) x);
148+
}
149+
117150
/**
118151
* Computes the Euclidean modulo (non-negative remainder) of {@code x} modulo {@code y}.
119152
*

pdfocr-onnxtr/src/test/java/com/itextpdf/pdfocr/onnxtr/util/MathUtilTest.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,15 @@ public void clampWithValidArgs() {
6363
Assertions.assertEquals(1.9, MathUtil.clamp(2.0, 1.1, 1.9));
6464
}
6565

66+
@Test
67+
public void logitTest() {
68+
Assertions.assertThrows(IllegalArgumentException.class, () -> MathUtil.logit(-0.1F));
69+
Assertions.assertEquals(Float.NEGATIVE_INFINITY, MathUtil.logit(0F));
70+
Assertions.assertEquals(0F, MathUtil.logit(0.5F));
71+
Assertions.assertEquals(Float.POSITIVE_INFINITY, MathUtil.logit(1F));
72+
Assertions.assertThrows(IllegalArgumentException.class, () -> MathUtil.logit(1.1F));
73+
}
74+
6675
@Test
6776
public void levenshteinDistanceTest(){
6877
Assertions.assertEquals(5, MathUtil.calculateLevenshteinDistance("kitten", "meat"));

0 commit comments

Comments
 (0)