
Commit 511c99d

Author: Tianyu Gao
Committed: Oct 10, 2022

Merge branch 'main' of github.com:princeton-nlp/SimCSE

2 parents: 8f3ef2f + d868602

File tree

5 files changed: +62 −7 lines
 

.github/workflows/stale.yml (+29)

@@ -0,0 +1,29 @@
+# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
+#
+# You can adjust the behavior by modifying this file.
+# For more information, see:
+# https://github.com/actions/stale
+name: Mark stale issues and pull requests
+
+on:
+  schedule:
+    - cron: '18 9 * * *'
+
+jobs:
+  stale:
+
+    runs-on: ubuntu-latest
+    permissions:
+      issues: write
+      pull-requests: write
+
+    steps:
+    - uses: actions/stale@v5
+      with:
+        repo-token: ${{ secrets.GITHUB_TOKEN }}
+        stale-issue-message: 'Stale issue message'
+        stale-pr-message: 'Stale pull request message'
+        stale-issue-label: 'no-issue-activity'
+        stale-pr-label: 'no-pr-activity'
+        days-before-stale: 30
+        days-before-close: 5

requirements.txt (+6 −6)

@@ -1,9 +1,9 @@
 transformers==4.2.1
-scipy==1.5.4
-datasets==1.2.1
-pandas==1.1.5
-scikit-learn==0.24.0
-prettytable==2.1.0
+scipy
+datasets
+pandas
+scikit-learn
+prettytable
 gradio
 torch
-setuptools==49.3.0
+setuptools
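Since most pins were dropped, pip will now resolve whatever versions are current. A generic sketch (not part of the repo) for recording the versions that actually get installed, using only the package names listed in the diff above:

# Generic sketch (not part of the repo): print the resolved versions of the
# packages listed in requirements.txt, now that most version pins are gone.
from importlib.metadata import version

for pkg in ["transformers", "scipy", "datasets", "pandas",
            "scikit-learn", "prettytable", "gradio", "torch", "setuptools"]:
    print(pkg, version(pkg))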

simcse/tool.py (+26 −1)

@@ -176,6 +176,31 @@ def build_index(self, sentences_or_file_path: Union[str, List[str]],
         self.is_faiss_index = False
         self.index["index"] = index
         logger.info("Finished")
+
+    def add_to_index(self, sentences_or_file_path: Union[str, List[str]],
+                device: str = None,
+                batch_size: int = 64):
+
+        # if the input sentence is a string, we assume it's the path of file that stores various sentences
+        if isinstance(sentences_or_file_path, str):
+            sentences = []
+            with open(sentences_or_file_path, "r") as f:
+                logging.info("Loading sentences from %s ..." % (sentences_or_file_path))
+                for line in tqdm(f):
+                    sentences.append(line.rstrip())
+            sentences_or_file_path = sentences
+
+        logger.info("Encoding embeddings for sentences...")
+        embeddings = self.encode(sentences_or_file_path, device=device, batch_size=batch_size, normalize_to_unit=True, return_numpy=True)
+
+        if self.is_faiss_index:
+            self.index["index"].add(embeddings.astype(np.float32))
+        else:
+            self.index["index"] = np.concatenate((self.index["index"], embeddings))
+        self.index["sentences"] += sentences_or_file_path
+        logger.info("Finished")
+
+
 
     def search(self, queries: Union[str, List[str]],
                device: str = None,
@@ -186,7 +211,7 @@ def search(self, queries: Union[str, List[str]],
         if isinstance(queries, list):
             combined_results = []
             for query in queries:
-                results = self.search(query, device)
+                results = self.search(query, device, threshold, top_k)
                 combined_results.append(results)
             return combined_results
 
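A minimal usage sketch of the new add_to_index method and the corrected search call. This is not part of the commit; it assumes the SimCSE class exported from simcse/tool.py and that a published checkpoint name such as princeton-nlp/sup-simcse-bert-base-uncased is available.

# Hypothetical usage sketch (not part of this commit), based on the method
# signatures visible in the diff above.
from simcse import SimCSE

model = SimCSE("princeton-nlp/sup-simcse-bert-base-uncased")

# Build an initial index, then grow it in place with the new add_to_index.
model.build_index(["An animal is biting a person's finger."])
model.add_to_index(["A man is playing a guitar.",
                    "A woman is reading a book."], batch_size=64)

# With the one-line fix in search, threshold and top_k are now forwarded
# when a list of queries is passed, instead of falling back to the defaults.
results = model.search(["Someone plays guitar."], threshold=0.6, top_k=3)
print(results)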

slides/emnlp2021_slides.pdf (13.8 MB)

Binary file not shown.

train.py (+1)

@@ -500,6 +500,7 @@ def mask_tokens(
         """
         Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
         """
+        inputs = inputs.clone()
         labels = inputs.clone()
         # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
         probability_matrix = torch.full(labels.shape, self.mlm_probability)
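A short standalone illustration, not from the commit, of why the added inputs = inputs.clone() matters: without the clone, inputs aliases the caller's tensor, so the in-place [MASK]/random replacements later in mask_tokens would also overwrite the original batch. The token ids below are arbitrary placeholders.

# Standalone sketch (hypothetical values); 103 stands in for a [MASK] token id.
import torch

batch = torch.tensor([101, 2054, 2003, 102])

aliased = batch              # no clone: both names share the same storage
aliased[1] = 103             # in-place masking, as mask_tokens does later
print(batch)                 # tensor([ 101,  103, 2003,  102]) -- caller's batch changed

batch = torch.tensor([101, 2054, 2003, 102])
copied = batch.clone()       # with the clone, masking is isolated
copied[1] = 103
print(batch)                 # tensor([ 101, 2054, 2003,  102]) -- caller's batch intact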

0 commit comments