[Fix] Retriever tokenization function in atlas.py needs correction #20

@silencio94

Description

@silencio94

When the code runs, the maximum passage length becomes the smaller of the two variables, self.opt.text_maxlength and gpu_embedder_batch_size. By default, gpu_embedder_batch_size is set to 512, so if you run the code without modifying the default options, most BERT-style dual encoders will work without issues (see line 74).

However, if you reduce gpu_embedder_batch_size to conserve GPU memory, passages are silently truncated to that batch size in tokens, so unexpected results can occur without any warning.

atlas/src/atlas.py

Lines 61 to 89 in f8bec5c

@torch.no_grad()
def build_index(self, index, passages, gpu_embedder_batch_size, logger=None):
    n_batch = math.ceil(len(passages) / gpu_embedder_batch_size)
    retrieverfp16 = self._get_fp16_retriever_copy()
    total = 0
    for i in range(n_batch):
        batch = passages[i * gpu_embedder_batch_size : (i + 1) * gpu_embedder_batch_size]
        batch = [self.opt.retriever_format.format(**example) for example in batch]
        batch_enc = self.retriever_tokenizer(
            batch,
            padding="longest",
            return_tensors="pt",
            max_length=min(self.opt.text_maxlength, gpu_embedder_batch_size),
            truncation=True,
        )
        embeddings = retrieverfp16(**_to_cuda(batch_enc), is_passages=True)
        index.embeddings[:, total : total + len(embeddings)] = embeddings.T
        total += len(embeddings)
        if i % 500 == 0 and i > 0:
            logger.info(f"Number of passages encoded: {total}")
    dist_utils.barrier()
    logger.info(f"{total} passages encoded on process: {dist_utils.get_rank()}")
    if not index.is_index_trained():
        logger.info(f"Building faiss indices")
        index.train_index()

So, it is recommended to modify line 74 as follows (as is done in other parts of the codebase):

min(self.opt.text_maxlength, BERT_MAX_SEQ_LENGTH),
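To illustrate the difference, here is a minimal sketch of the two max_length expressions in isolation (the constant values are illustrative defaults, not taken from the repository):

```python
# Illustrative defaults: 512 matches both the default gpu_embedder_batch_size
# and the typical BERT maximum sequence length.
text_maxlength = 512          # stands in for self.opt.text_maxlength
BERT_MAX_SEQ_LENGTH = 512     # hard sequence-length limit of BERT-style encoders

def buggy_max_length(gpu_embedder_batch_size):
    # Original line 74: the batch size leaks into the sequence-length cap.
    return min(text_maxlength, gpu_embedder_batch_size)

def fixed_max_length(gpu_embedder_batch_size):
    # Proposed fix: cap by the model's maximum sequence length instead,
    # so batch size only controls memory use, not truncation.
    return min(text_maxlength, BERT_MAX_SEQ_LENGTH)

print(buggy_max_length(512))  # 512 -- default batch size, no visible problem
print(buggy_max_length(64))   # 64  -- passages silently truncated to 64 tokens
print(fixed_max_length(64))   # 512 -- truncation no longer tied to batch size
```

With the fix, lowering gpu_embedder_batch_size to fit GPU memory no longer changes how aggressively passages are truncated.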
