51 changes: 51 additions & 0 deletions jsonformer/logits_processors.py
@@ -82,3 +82,54 @@ def __call__(self, _, scores):
        scores[~mask] = -float("inf")

        return scores

class IntegerStoppingCriteria(StoppingCriteria):
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        prompt_length: int,
        max_digits: int = 15,
    ):
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length
        self.max_digits = max_digits

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> bool:
        decoded = self.tokenizer.decode(
            input_ids[0][self.prompt_length :], skip_special_tokens=True
        )

        if len(decoded.strip()) > self.max_digits:
            return True

        if (
            len(decoded) > 1
            and any(c.isdigit() for c in decoded)
            and decoded[-1] in [" ", "\n"]
Owner: need to check all chars to account for the case where the last sampled token was something like 1<space>h

Author: Shouldn't be necessary, since OutputIntegersTokens only allows tokens consisting of digits with optional leading and trailing whitespace.

Author: Was just doing some more testing with this and am now running into some bugs when these features are used with llama-7b and dolly-v2-12b.

        ):
            return True

        return False

class OutputIntegersTokens(LogitsWarper):
    def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str):
        self.tokenizer = tokenizer
        self.tokenized_prompt = tokenizer(prompt, return_tensors="pt")
        vocab_size = len(tokenizer)
        self.allowed_mask = torch.zeros(vocab_size, dtype=torch.bool)

        for _, token_id in tokenizer.get_vocab().items():
            token_str = tokenizer.decode(token_id).strip()

            if token_str == "" or all(c.isdigit() for c in token_str):
                self.allowed_mask[token_id] = True

    def __call__(self, _, scores):
        mask = self.allowed_mask.expand_as(scores)
        scores[~mask] = -float("inf")

        return scores
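As an aside on the review thread above: a minimal sketch of the stricter check the owner proposes, assuming the intent is simply to stop once any non-digit, non-whitespace character appears alongside a digit. The helper name is hypothetical and nothing in this block is part of the diff.

```python
# Hypothetical sketch, not part of this PR: the stricter stop condition the
# owner suggests for IntegerStoppingCriteria.__call__, scanning every generated
# character so that a sampled token like "1 h" stops generation even though the
# final character is neither a space nor a newline.
def should_stop(decoded: str) -> bool:
    has_digit = any(c.isdigit() for c in decoded)
    has_non_integer_char = any(not (c.isdigit() or c.isspace()) for c in decoded)
    return has_digit and has_non_integer_char
```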
81 changes: 75 additions & 6 deletions jsonformer/main.py
@@ -1,13 +1,16 @@
from typing import List, Union, Dict, Any
from typing import List, Set, Union, Dict, Any

from jsonformer.logits_processors import (
    NumberStoppingCriteria,
    OutputNumbersTokens,
    IntegerStoppingCriteria,
    OutputIntegersTokens,
    StringStoppingCriteria,
)
from termcolor import cprint
from transformers import PreTrainedModel, PreTrainedTokenizer
import json
import torch

GENERATION_MARKER = "|GENERATION|"

@@ -34,6 +37,7 @@ def __init__(
        self.prompt = prompt

        self.number_logit_processor = OutputNumbersTokens(self.tokenizer, self.prompt)
        self.integer_logit_processor = OutputIntegersTokens(self.tokenizer, self.prompt)

        self.generation_marker = "|GENERATION|"
        self.debug_on = debug
@@ -82,6 +86,36 @@ def generate_number(self, temperature: Union[float, None] = None, iterations=0):

        return self.generate_number(temperature=self.temperature * 1.3)

    def generate_integer(self, temperature: Union[float, None] = None, iterations=0):
        prompt = self.get_prompt()
        self.debug("[generate_integer]", prompt, is_prompt=True)
        input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(
            self.model.device
        )
        response = self.model.generate(
            input_tokens,
            max_new_tokens=self.max_number_tokens,
            num_return_sequences=1,
            logits_processor=[self.integer_logit_processor],
            stopping_criteria=[
                IntegerStoppingCriteria(self.tokenizer, len(input_tokens[0]))
            ],
            temperature=temperature or self.temperature,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        response = self.tokenizer.decode(response[0], skip_special_tokens=True)

        response = response[len(prompt) :]
        response = response.strip()
        self.debug("[generate_integer]", response)
        try:
            return int(response)
        except ValueError:
            if iterations > 3:
                raise ValueError("Failed to generate a valid integer")

            return self.generate_integer(temperature=self.temperature * 1.3, iterations=iterations + 1)

    def generate_boolean(self) -> bool:
        prompt = self.get_prompt()
        self.debug("[generate_boolean]", prompt, is_prompt=True)
@@ -90,11 +124,8 @@ def generate_boolean(self) -> bool:
        output = self.model.forward(input_tensor.to(self.model.device))
        logits = output.logits[0, -1]

        # todo: this assumes that "true" and "false" are both tokenized to a single token
        # this is probably not true for all tokenizers
        # this can be fixed by looking at only the first token of both "true" and "false"
        true_token_id = self.tokenizer.convert_tokens_to_ids("true")
        false_token_id = self.tokenizer.convert_tokens_to_ids("false")
        true_token_id = self.tokenizer.encode("true", return_tensors="pt")[0, 0]
        false_token_id = self.tokenizer.encode("false", return_tensors="pt")[0, 0]

        result = logits[true_token_id] > logits[false_token_id]

@@ -139,6 +170,32 @@ def generate_string(self) -> str:

        return response.split('"')[0].strip()

    def generate_enum(self, enum_values: Set[str]) -> str:
        prompt = self.get_prompt()
        self.debug("[generate_enum]", prompt, is_prompt=True)
        prompt_tokens = self.tokenizer.encode(prompt, return_tensors="pt")

        highest_probability = 0.0
        best_option = None
        for option in enum_values:
            option_tokens = self.tokenizer.encode(f'"{option}"', return_tensors="pt")
            n_option_tokens = option_tokens.shape[1]
            prompt_option_tokens = torch.concat([prompt_tokens, option_tokens], dim=1)

            with torch.no_grad():
                logits = self.model.forward(prompt_option_tokens.to(self.model.device)).logits[0, -n_option_tokens-1:-1]
                probabilities = torch.softmax(logits, dim=1)
                option_token_probabilities = probabilities[torch.arange(probabilities.shape[0]), option_tokens]
                option_probability = torch.prod(option_token_probabilities).item()

            if option_probability > highest_probability:
                best_option = option
                highest_probability = option_probability

        self.debug("[generate_enum]", best_option)

        return best_option

    def generate_object(
        self, properties: Dict[str, Any], obj: Dict[str, Any]
    ) -> Dict[str, Any]:
@@ -160,6 +217,12 @@ def generate_value(
            else:
                obj.append(self.generation_marker)
            return self.generate_number()
        elif schema_type == "integer":
            if key:
                obj[key] = self.generation_marker
            else:
                obj.append(self.generation_marker)
            return self.generate_integer()
        elif schema_type == "boolean":
            if key:
                obj[key] = self.generation_marker
@@ -172,6 +235,12 @@ def generate_value(
            else:
                obj.append(self.generation_marker)
            return self.generate_string()
        elif schema_type == "enum":
            if key:
                obj[key] = self.generation_marker
            else:
                obj.append(self.generation_marker)
            return self.generate_enum(set(schema["values"]))
        elif schema_type == "array":
            new_array = []
            obj[key] = new_array
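For context, and not part of the diff: a short usage sketch of how the new "integer" and "enum" schema types could be exercised end to end. The model name is a placeholder, and the enum schema shape ({"type": "enum", "values": [...]}) follows the schema["values"] lookup in generate_value above.

```python
# Illustrative only: the model choice is a placeholder and the schema mirrors
# the new "integer" and "enum" branches added to generate_value in this PR.
from transformers import AutoModelForCausalLM, AutoTokenizer
from jsonformer import Jsonformer

model_name = "databricks/dolly-v2-3b"  # placeholder; any causal LM supported by jsonformer
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

schema = {
    "type": "object",
    "properties": {
        "age": {"type": "integer"},                           # handled by generate_integer
        "size": {"type": "enum", "values": ["S", "M", "L"]},  # handled by generate_enum
    },
}

jsonformer = Jsonformer(model, tokenizer, schema, "Generate a person's details:")
print(jsonformer())
```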