Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 54 additions & 17 deletions libs/ibm/langchain_ibm/utilities/sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,34 +56,68 @@ def convert_column_data(field_metadata: dict) -> str:
native_type = field_metadata_type.get("native_type")
nullable = field_metadata_type.get("nullable")

return f"{name} {native_type}{'' if nullable else ' NOT NULL'},"
return f"{name} {native_type}{'' if nullable else ' NOT NULL'}"

create_table_template = """
CREATE TABLE {schema}.{table_name} (
\t{column_definitions}{primary_key}
\t{column_definitions}{primary_key}{foreign_key}
\t)"""

primary_key: dict = next(
filter(
lambda el: el.get("name") == "primary_key",
table_info.get("extended_metadata", [{}]),
),
{},
)
key_columns = primary_key.get("value", {}).get("key_columns", [])
extended_metadata = table_info.get("extended_metadata", [{}])

def _retrieve_field_data(field_name: str) -> dict:
return next(
filter(
lambda el: el.get("name") == field_name,
extended_metadata,
),
{},
)

# Primary Key
primary_key: dict = _retrieve_field_data("primary_key")
if primary_key:
key_columns = ", ".join(primary_key.get("value", {}).get("key_columns", []))
primary_key_text = (
f",\n\tCONSTRAINT {primary_key['name']} PRIMARY KEY ({key_columns})"
)
else:
primary_key_text = ""

# Foreign keys
foreign_keys: dict = _retrieve_field_data("foreign_keys")
if foreign_keys:
foreign_keys_text = ""
foreign_key_text_template = (
"CONSTRAINT {fk_name} FOREIGN KEY ({col_name}) "
"REFERENCES {external_table_name}({external_col_name})"
)
for foreign_key in foreign_keys.get("value", []):
foreign_keys_text += ",\n\t"
join_condition = foreign_key["join_condition"].split("=")
foreign_keys_text += foreign_key_text_template.format(
fk_name=foreign_key["name"],
col_name=join_condition[0].strip().split(".")[-1],
external_table_name=join_condition[1].strip().split(".")[1],
external_col_name=join_condition[1].strip().split(".")[2],
)
else:
foreign_keys_text = ""

return create_table_template.format(
schema=schema,
table_name=table_name,
column_definitions="\n\t".join(
[
convert_column_data(field_metadata=field_metadata)
for field_metadata in table_info["fields"]
# Do not add comma for the last column
convert_column_data(field_metadata=field_metadata) + ","
if index < len(table_info["fields"])
else convert_column_data(field_metadata=field_metadata)
for index, field_metadata in enumerate(table_info["fields"], start=1)
]
),
primary_key=f"\n\tPRIMARY KEY ({', '.join(key_columns)})"
if primary_key
else "",
primary_key=primary_key_text,
foreign_key=foreign_keys_text,
)


Expand Down Expand Up @@ -300,7 +334,10 @@ def _check_with_username(

self._meta_all_tables = {
table_name: flight_sql_client.get_table_info(
table_name=table_name, schema=self.schema
table_name=table_name,
schema=self.schema,
extended_metadata=True,
interaction_properties=True,
)
for table_name in self._all_tables
if table_name in (self._include_tables or self._all_tables)
Expand Down Expand Up @@ -385,7 +422,7 @@ def get_table_info(self, table_names: Optional[Iterable[str]] = None) -> str:
schema=self.schema,
table_name=table_name,
n=self._sample_rows_in_table_info,
).to_string()
).to_string(index=False)
for table_name in table_names
]
)
Expand Down
12 changes: 6 additions & 6 deletions libs/ibm/tests/unit_tests/utilities/test_sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_pretty_print_table_info(
\tid INT NOT NULL,
\tname VARCHAR(255),
\tage INT,
\tPRIMARY KEY (id)
\tCONSTRAINT primary_key PRIMARY KEY (id)
\t)"""
assert pretty_print_table_info(schema, table_name, table_info) == expected_output

Expand All @@ -173,7 +173,7 @@ def test_pretty_print_table_info_with_nullable_columns() -> None:
CREATE TABLE another_schema.another_table (
\temail VARCHAR(255),
\tcreated_at TIMESTAMP,
\tPRIMARY KEY (email)
\tCONSTRAINT primary_key PRIMARY KEY (email)
\t)"""
assert pretty_print_table_info(schema, table_name, table_info) == expected_output

Expand All @@ -193,7 +193,7 @@ def test_pretty_print_table_info_without_primary_key() -> None:
expected_output = """
CREATE TABLE no_pk_schema.no_pk_table (
\tvalue1 INT NOT NULL,
\tvalue2 VARCHAR(255),
\tvalue2 VARCHAR(255)
\t)"""
assert pretty_print_table_info(schema, table_name, table_info) == expected_output

Expand Down Expand Up @@ -425,13 +425,13 @@ def test_initialize_watsonx_sql_database_get_table_info(
\tid INT NOT NULL,
\tname VARCHAR(255),
\tage INT,
\tPRIMARY KEY (id)
\tCONSTRAINT primary_key PRIMARY KEY (id)
\t)

First 3 rows of table table1:

id name age
0 1 test 35"""
id name age
1 test 35"""
print(wx_sql_database.get_table_info(["table1"]))
assert wx_sql_database.get_table_info(["table1"]) == expected_output

Expand Down