Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions l2p/domain_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,32 +289,28 @@ def extract_pddl_action(
llm_response, syntax_validator.unsupported_keywords
)
)
elif e == "invalid_param_types":
elif e == "invalid_param_types" and types:
validation_info = syntax_validator.validate_params(
action["params"], types
)
elif e == "invalid_predicate_name":
elif e == "invalid_predicate_name" and types:
validation_info = (
syntax_validator.validate_types_predicates(
new_predicates, types
)
)
elif e == "invalid_predicate_format":
elif e == "invalid_predicate_format" and types:
validation_info = (
syntax_validator.validate_format_predicates(
predicates, types
)
)
elif e == "invalid_predicate_usage":
elif e == "invalid_predicate_usage" and types:
validation_info = (
syntax_validator.validate_usage_predicates(
llm_response, predicates, types
)
)
else:
raise NotImplementedError(
f"Validation type '{e}' is not implemented."
)

if not validation_info[0]:
return action, new_predicates, llm_response, validation_info
Expand Down Expand Up @@ -759,7 +755,7 @@ def get_predicates(self):
def generate_domain(
self,
domain: str,
types: str,
types: str | None,
predicates: str,
actions: list[Action],
requirements: list[str],
Expand All @@ -783,7 +779,8 @@ def generate_domain(
indent(string=f"(:requirements\n {' '.join(requirements)})", level=1)
+ "\n\n"
)
desc += f" (:types \n{indent(string=types, level=2)}\n )\n\n"
if types: # Only add types if not None or empty string
desc += f" (:types \n{indent(string=types, level=2)}\n )\n\n"
desc += f" (:predicates \n{indent(string=predicates, level=2)}\n )"
desc += self.action_descs(actions)
desc += "\n)"
Expand Down
2 changes: 2 additions & 0 deletions l2p/llm_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ def valid_models(self) -> list[str]:
"gpt-4-32k",
"gpt-4o",
"gpt-4o-mini",
"o1",
"o3-mini"
]


Expand Down
30 changes: 26 additions & 4 deletions l2p/utils/pddl_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,27 @@ def __init__(self, error_types=None, unsupported_keywords=None):
# PARAMETER CHECKS

def validate_params(
self, parameters: OrderedDict, types: dict[str, str]
self, parameters: OrderedDict, types: dict[str, str] | None
) -> tuple[bool, str]:
"""Checks whether a PDDL action parameter contains types found in object types."""

types = types or {}

# If no types are defined, inform the user and check for parameter types
if not types:
for param_name, param_type in parameters.items():
if param_type is not None and param_type != "":
feedback_msg = (
f'The parameter `{param_name}` has an object type `{param_type}` '
'while no types are defined. Please remove the object type from this parameter.'
)
return False, feedback_msg

# if all parameter names do not contain a type
return True, "PASS: All parameters are valid."

for param_name in parameters:
param_type = parameters[param_name]
# Otherwise, check that parameter types are valid in the given types
for param_name, param_type in parameters.items():

if not any(param_type in t for t in types.keys()):
feedback_msg = f'There is an invalid object type `{param_type}` for the parameter {param_name} not found in the types {types.keys()}. If you need to use a new type, you can emulate it with an "is_{{type}} ?o - object" precondition. Please revise the PDDL model to fix this error.'
Expand All @@ -54,9 +69,16 @@ def validate_params(
# PREDICATE CHECKS

def validate_types_predicates(
self, predicates: list[dict], types: dict[str, str]
self, predicates: list[dict], types: dict[str, str] | None
) -> tuple[bool, str]:
"""Check if predicate name is found within any type definitions"""

# Handle the case where types is None or empty
types = types or {}

if not types:
feedback_msg = "PASS: All predicate names are unique to object type names"
return True, feedback_msg

invalid_predicates = list()
for pred in predicates:
Expand Down