Skip to content

Commit c25c3ed

Browse files
q10facebook-github-bot
authored andcommitted
Fix memory overflow issue that shows up with when OutType is uint8_t and input_stride > output_stride (pytorch#5169)
Summary: X-link: facebookresearch/FBGEMM#2170 X-link: facebookresearch/FBGEMM#2166 In edge cases, the weight input_stride and output_stride could be different. In this case, out will move based on out_stride but copy data with size input_stride. If input_stride is larger than output_stride, the memcpy could overflow. this fix guarantees the safe memcpy, and is backward compatible for cases that input_stride <= output_stride. ==== The context is S592128, where the predictor crashes every so often bc of bad input being provided to memcpy, which causes memory overflow. The issues show up when nobag=true, OutType=uint8_t, and input_stride > output_stride. From what I can tell, if input_stride and output_stride are not provided to us (i.e. a -1 is passed), we compute it ourselves. But it is possible for users to have users pass in custom-sized tables to us, where input_stride is greater than output_stride. The code already has an assert beforehand to check for this case, but apparently, assert() can be disabled when -DNDEBUG is provided during compilation, which is probably what is being passed in when building the package and why we see the crash. Emma provided the initial patch, which is to copy just the min(input_stride, output_stride) amount of buffer per row. It provides backward compatibility with existing behavior, and just adjusts it so that the corner case doesnt crash the predictor. I have added the unit test to the patch, based on the existing EmbeddingSpMDMTest.basicTest. I initially tried to fit it into EmbeddingSpMDMTest.basicTest, but figured it would be better to have a standalone test bc it is testing a corner case, not adding more coordinates to the test matrix. Reviewed By: emlin, spcyppt Differential Revision: D87734295
1 parent 96409fa commit c25c3ed

File tree

3 files changed

+226
-4
lines changed

3 files changed

+226
-4
lines changed

src/RefImplementations.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,9 +1216,7 @@ bool EmbeddingSpMDM_ref(
12161216
if (output_stride == -1) {
12171217
output_stride = block_size;
12181218
}
1219-
if constexpr (isOutput8bit) {
1220-
assert(input_stride == output_stride);
1221-
}
1219+
12221220
vector<float> buf(block_size);
12231221

12241222
if constexpr (isWeight8bit) {
@@ -1241,8 +1239,15 @@ bool EmbeddingSpMDM_ref(
12411239
return false;
12421240
}
12431241
if constexpr (isOutput8bit) {
1242+
// In edge cases, input_stride can be larger than output_stride, and
1243+
// the assert macro will be a no-op if -DNDEBUG compilation flag is
1244+
// added (e.g. prod package). To prevent memory overflow and also be
1245+
// backward- compatible, we get the min value of input_stride and
1246+
// output_stride, and only copy the overlap part of data.
1247+
const auto copy_width = std::min(output_stride, input_stride);
12441248
const InType* input_row_ptr = input + input_stride * idx;
1245-
memcpy(out, input_row_ptr, sizeof(InType) * input_stride);
1249+
memcpy(out, input_row_ptr, sizeof(InType) * copy_width);
1250+
12461251
} else {
12471252
memset(buf.data(), 0, sizeof(float) * block_size);
12481253
const float* scale_bias = reinterpret_cast<const float*>(

test/EmbeddingSpMDMTest.cc

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,222 @@ TEST_P(EmbeddingSpMDMTest, basicTest) {
433433
} // end for input
434434
}
435435

436+
TEST_P(EmbeddingSpMDMTest, noBagUint8Test) {
437+
vector<vector<int>> inputs(GetInputs_());
438+
439+
random_device r;
440+
default_random_engine generator(r());
441+
uniform_int_distribution<> bool_dist(0, 1);
442+
443+
bool isIndex64b = bool_dist(generator);
444+
bool isOffset64b = bool_dist(generator);
445+
bool use_offsets = bool_dist(generator);
446+
auto [prefetch, weight_choice, corner_case, _tmp1, _tmp2] = GetParam();
447+
448+
// Fix the input and output types to be uint8_t
449+
const auto in_type = QINT8, out_type = QINT8;
450+
451+
// Skip corner cases for this focused test
452+
if (corner_case != NONE) {
453+
return;
454+
}
455+
456+
for (auto input : inputs) {
457+
int batch_size = input[0];
458+
int num_rows = input[1];
459+
int embedding_dim = input[2];
460+
461+
// Set output_stride < input_stride to test the bug fix
462+
// This is the edge case where memory overflow could occur
463+
int output_stride = embedding_dim;
464+
int input_stride = embedding_dim * 2 + 3;
465+
466+
// Create embedding table with uint8_t data
467+
vector<uint8_t> embedding_table_qint8(num_rows * input_stride);
468+
uniform_int_distribution<int> uint8_dist(0, 255);
469+
for (int i = 0; i < num_rows; ++i) {
470+
for (int j = 0; j < embedding_dim; ++j) {
471+
embedding_table_qint8[i * input_stride + j] =
472+
static_cast<uint8_t>(uint8_dist(generator));
473+
}
474+
}
475+
476+
// For no_bag case, each index produces one output row
477+
// So output_size = number of indices = batch_size for simplicity
478+
int output_size = batch_size;
479+
480+
vector<int64_t> indices;
481+
vector<int32_t> indices_32;
482+
for (int i = 0; i < output_size; ++i) {
483+
uniform_int_distribution<> row_dist(0, num_rows - 1);
484+
int64_t idx = row_dist(generator);
485+
indices.push_back(idx);
486+
indices_32.push_back(static_cast<int32_t>(idx));
487+
}
488+
489+
// For no_bag, we still need offsets/lengths but they're just 0,1,2,3...
490+
vector<int64_t> offsets, lengths;
491+
vector<int32_t> offsets_32, lengths_32;
492+
if (use_offsets) {
493+
for (int i = 0; i <= output_size; ++i) {
494+
offsets.push_back(i);
495+
offsets_32.push_back(i);
496+
}
497+
} else {
498+
for (int i = 0; i < output_size; ++i) {
499+
lengths.push_back(1);
500+
lengths_32.push_back(1);
501+
}
502+
}
503+
504+
const int64_t* offsets_or_lengths =
505+
(use_offsets ? offsets : lengths).data();
506+
const int32_t* offsets_or_lengths_32 =
507+
(use_offsets ? offsets_32 : lengths_32).data();
508+
509+
int output_size_wo_sentries = output_size * output_stride;
510+
constexpr int num_sentries = 10;
511+
512+
vector<uint8_t> output_ref(output_size_wo_sentries + num_sentries);
513+
vector<uint8_t> output(output_size_wo_sentries + num_sentries);
514+
515+
// Initialize sentries
516+
const uint8_t sentry_value = 0xFF;
517+
for (size_t i = output_size_wo_sentries; i < output.size(); ++i) {
518+
output_ref[i] = sentry_value;
519+
output[i] = sentry_value;
520+
}
521+
522+
bool success = false, success_ref = false;
523+
524+
#define TEST_NOBAG_BASE( \
525+
table, \
526+
indices, \
527+
offsets_or_lengths, \
528+
output_ref, \
529+
output, \
530+
InType, \
531+
IndexType, \
532+
OffsetType, \
533+
OutType) \
534+
success_ref = EmbeddingSpMDM_ref( \
535+
embedding_dim, \
536+
output_size, \
537+
output_size, \
538+
num_rows, \
539+
table.data(), \
540+
indices.data(), \
541+
offsets_or_lengths, \
542+
nullptr, /* weights */ \
543+
false, /* normalize_by_lengths */ \
544+
output_ref.data(), \
545+
false, /* is_wt_positional */ \
546+
use_offsets, \
547+
output_stride, \
548+
input_stride, \
549+
true, /* scale_bias_last */ \
550+
true, /* no_bag */ \
551+
false, /* is_output_bfloat16 */ \
552+
false /* isBf16 */); \
553+
\
554+
auto kernel = GenerateEmbeddingSpMDMWithStrides< \
555+
InType, \
556+
IndexType, \
557+
OffsetType, \
558+
OutType>( \
559+
embedding_dim, \
560+
false, /* has_weight */ \
561+
false, /* normalize_by_lengths */ \
562+
prefetch, \
563+
false, /* is_wt_positional */ \
564+
use_offsets, \
565+
output_stride, \
566+
input_stride, \
567+
true, /* scale_bias_last */ \
568+
true, /* no_bag */ \
569+
false, /* is_bf16_out */ \
570+
false /* is_bf16_in */); \
571+
success = kernel( \
572+
output_size, \
573+
output_size, \
574+
num_rows, \
575+
table.data(), \
576+
indices.data(), \
577+
offsets_or_lengths, \
578+
nullptr /* weights */, \
579+
output.data());
580+
581+
#define TEST_NOBAG_OFFSET_TYPE(table, indices, InType, IndexType) \
582+
if (isOffset64b) { \
583+
TEST_NOBAG_BASE( \
584+
table, \
585+
indices, \
586+
offsets_or_lengths, \
587+
output_ref, \
588+
output, \
589+
InType, \
590+
IndexType, \
591+
int64_t, \
592+
uint8_t); \
593+
} else { \
594+
TEST_NOBAG_BASE( \
595+
table, \
596+
indices, \
597+
offsets_or_lengths_32, \
598+
output_ref, \
599+
output, \
600+
InType, \
601+
IndexType, \
602+
int32_t, \
603+
uint8_t); \
604+
}
605+
606+
#define TEST_NOBAG_INDEX_TYPE(table, InType) \
607+
if (isIndex64b) { \
608+
TEST_NOBAG_OFFSET_TYPE(table, indices, InType, int64_t); \
609+
} else { \
610+
TEST_NOBAG_OFFSET_TYPE(table, indices_32, InType, int32_t); \
611+
}
612+
613+
TEST_NOBAG_INDEX_TYPE(embedding_table_qint8, uint8_t);
614+
615+
#undef TEST_NOBAG_INDEX_TYPE
616+
#undef TEST_NOBAG_OFFSET_TYPE
617+
#undef TEST_NOBAG_BASE
618+
619+
// Check correctness
620+
EXPECT_EQ(success, success_ref)
621+
<< "Reference and JIT impl did not both succeed";
622+
EXPECT_TRUE(success) << "Both implementations should succeed";
623+
624+
if (success) {
625+
// Verify the output data
626+
for (int i = 0; i < output_size; ++i) {
627+
for (int j = 0; j < embedding_dim; ++j) {
628+
int offset = i * output_stride + j;
629+
EXPECT_EQ(output[offset], output_ref[offset])
630+
<< "results differ at (" << i << ", " << j
631+
<< ") reference: " << static_cast<int>(output_ref[offset])
632+
<< ", FBGEMM: " << static_cast<int>(output[offset])
633+
<< " emb dim: " << embedding_dim;
634+
}
635+
}
636+
637+
// Verify sentries weren't overwritten (tests for buffer overflow)
638+
for (int offset = output_size_wo_sentries;
639+
offset < output_size_wo_sentries + num_sentries;
640+
++offset) {
641+
EXPECT_EQ(output[offset], sentry_value)
642+
<< "Sentry value corrupted at offset " << offset
643+
<< " - potential buffer overflow! Got: "
644+
<< static_cast<int>(output[offset])
645+
<< " expected: " << static_cast<int>(sentry_value);
646+
EXPECT_EQ(output_ref[offset], sentry_value);
647+
}
648+
}
649+
} // end for input
650+
}
651+
436652
TEST_P(rowwiseSparseEmbeddingSpMDMTest, rowwiseSparseTest) {
437653
vector<vector<int>> inputs(GetInputs_());
438654

test/EmbeddingSpMDMTestUtils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ enum EmbeddingSpMDMDtypeChoice {
3030
FLOAT,
3131
FLOAT16,
3232
BFLOAT16,
33+
QINT8,
3334
};
3435

3536
using EmbeddingSpMDMInputDtypeChoice = EmbeddingSpMDMDtypeChoice;

0 commit comments

Comments
 (0)