Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,31 @@ print(generated_data)
}
```

### Example Enum
```python
color = {
"type": "object",
"properties": {
"color": {
"type": "enum",
"values": [
"black",
"red",
"white",
"green",
"blue"
]
}
}
}
```

```python
{
  "color": "blue"
}
```

## Features

- Bulletproof JSON generation: Jsonformer ensures that the generated JSON is always syntactically correct and conforms to the specified schema.
Expand Down
71 changes: 71 additions & 0 deletions jsonformer/logits_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,30 @@ def __call__(
return True

return False


class EnumStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as the text produced after the prompt
    decodes to one of the allowed enum values."""

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        prompt_length: int,
        enums,
    ):
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length
        self.enums = enums

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> bool:
        # Only consider tokens generated beyond the prompt.
        generated_ids = input_ids[0][self.prompt_length :]
        generated_text = self.tokenizer.decode(
            generated_ids, skip_special_tokens=True
        )
        # True (stop) once the generated text is exactly one of the enums.
        return generated_text in self.enums


class OutputNumbersTokens(LogitsWarper):
def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str):
Expand All @@ -82,3 +106,50 @@ def __call__(self, _, scores):
scores[~mask] = -float("inf")

return scores


class OutputEnumTokens(LogitsWarper):
    """Constrain generation to the token sequences of a fixed set of enum values.

    A prefix tree (trie) of the tokenized enum values is built once; on each
    call, only tokens that continue some enum value from the current position
    keep their scores, everything else is masked to -inf.

    NOTE: the trie is consumed as generation descends it, so an instance is
    single-use — create a fresh one per ``generate()`` call.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, enums):
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer)
        self.tree = self.build_tree(enums)
        # The first call has no previously-chosen enum token to descend on.
        self.is_first_call = True

    def create_mask(self, allowed_tokens):
        """Return a vocab-sized boolean mask that is True for allowed token ids."""
        allowed_mask = torch.zeros(self.vocab_size, dtype=torch.bool)
        # Index the allowed ids directly instead of scanning the whole vocab.
        allowed_mask[list(allowed_tokens)] = True
        return allowed_mask

    def build_tree(self, enums):
        """Build a trie mapping token id -> sub-trie over the tokenized enums."""
        tree = {}
        for enum in enums:
            encoded_enum = self.tokenizer.encode(enum, add_special_tokens=False)
            node = tree
            for code in encoded_enum:
                node = node.setdefault(code, {})
        return tree

    def __call__(self, input_ids, scores):
        if self.is_first_call:
            self.is_first_call = False
        else:
            # Descend the trie along the token chosen on the previous step.
            self.tree = self.tree[int(input_ids[0][-1])]

        allowed_tokens = self.tree.keys()

        if not allowed_tokens:
            # Every enum value is complete at this point; the stopping
            # criteria should have ended generation before this call.
            raise ValueError(
                "No allowed tokens remain; generation should have stopped"
            )

        allowed_mask = self.create_mask(allowed_tokens)
        mask = allowed_mask.expand_as(scores)
        scores[~mask] = -float("inf")
        return scores

50 changes: 50 additions & 0 deletions jsonformer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
NumberStoppingCriteria,
OutputNumbersTokens,
StringStoppingCriteria,
EnumStoppingCriteria,
OutputEnumTokens
)
from termcolor import cprint
from transformers import PreTrainedModel, PreTrainedTokenizer
Expand Down Expand Up @@ -138,6 +140,47 @@ def generate_string(self) -> str:
return response

return response.split('"')[0].strip()

def generate_enum(self, values) -> str:
    """Generate one of the given enum ``values`` with the model.

    String values are matched quoted (as they appear in JSON); non-string
    values are matched by their string form. Returns the chosen value:
    unquoted for strings, converted back to int/float for numbers.
    """
    prompt = self.get_prompt()
    self.debug("[generate_enum]", prompt, is_prompt=True)
    input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(
        self.model.device
    )
    # Strings appear quoted in JSON output; numbers appear as-is.
    values = [
        f'"{value}"' if isinstance(value, str) else str(value)
        for value in values
    ]

    # No enum value needs more tokens than the longest one.
    max_value_tokens = max(
        len(self.tokenizer.encode(value, add_special_tokens=False))
        for value in values
    )

    response = self.model.generate(
        input_tokens,
        max_new_tokens=max_value_tokens,
        num_return_sequences=1,
        temperature=self.temperature,
        logits_processor=[OutputEnumTokens(self.tokenizer, values)],
        stopping_criteria=[
            EnumStoppingCriteria(self.tokenizer, len(input_tokens[0]), values)
        ],
        pad_token_id=self.tokenizer.eos_token_id,
    )

    # Some models output the prompt as part of the response.
    # This removes the prompt from the response if it is present.
    if (
        len(response[0]) >= len(input_tokens[0])
        and (response[0][: len(input_tokens[0])] == input_tokens).all()
    ):
        response = response[0][len(input_tokens[0]) :]
    # Only unwrap a batch dimension; without the ndim guard a stripped
    # single-token response (1-D, length 1) would collapse to a 0-D scalar.
    if response.ndim > 1 and response.shape[0] == 1:
        response = response[0]

    response = self.tokenizer.decode(response, skip_special_tokens=True)

    self.debug("[generate_enum]", "|" + response + "|")

    # Quoted response -> the enum value was a string. The length guard
    # avoids an IndexError on an empty or single-character decode.
    if len(response) >= 2 and response[0] == response[-1] == '"':
        return response[1:-1]

    # Otherwise the value was numeric.
    if "." in response:
        return float(response)
    return int(response)

def generate_object(
self, properties: Dict[str, Any], obj: Dict[str, Any]
Expand All @@ -146,6 +189,7 @@ def generate_object(
self.debug("[generate_object] generating value for", key)
obj[key] = self.generate_value(schema, obj, key)
return obj


def generate_value(
self,
Expand Down Expand Up @@ -183,6 +227,12 @@ def generate_value(
else:
obj.append(new_obj)
return self.generate_object(schema["properties"], new_obj)
elif schema_type == "enum":
if key:
obj[key] = self.generation_marker
else:
obj.append(self.generation_marker)
return self.generate_enum(schema["values"])
else:
raise ValueError(f"Unsupported schema type: {schema_type}")

Expand Down