Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion libs/community/langchain_community/chat_models/snowflake.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -144,6 +145,22 @@ class ChatSnowflakeCortex(BaseChatModel):
"""Automatically inferred from env var `SNOWFLAKE_WAREHOUSE` if not provided."""
snowflake_role: Optional[str] = Field(default=None, alias="role")
"""Automatically inferred from env var `SNOWFLAKE_ROLE` if not provided."""
snowflake_authenticator: Optional[str] = Field(default=None, alias="authenticator")
"""
The authentication method for connecting to Snowflake.
Set to 'OAUTH' (case-insensitive) to use OAuth authentication with a token.
If not provided, password-based authentication is used by default.
This value can be passed as an argument or set via the `SNOWFLAKE_AUTHENTICATOR` environment variable.
If set to 'OAUTH', you must also provide a valid OAuth token via the `token` argument or the `SNOWFLAKE_ACCESS_TOKEN` environment variable.
Refer to Snowflake documentation for other supported authenticators if needed.
"""
snowflake_token: Optional[str] = Field(default=None, alias="token")
"""
The OAuth access token for authentication when `authenticator` is set to 'OAUTH'.
Required if using OAuth authentication (`authenticator='OAUTH'`).
This value can be passed as an argument or set via the `SNOWFLAKE_ACCESS_TOKEN` environment variable.
If not using OAuth, this field is ignored. Ensure your token is valid and has the necessary permissions for Snowflake Cortex access.
"""

def bind_tools(
self,
Expand Down Expand Up @@ -206,18 +223,34 @@ def validate_environment(cls, values: Dict) -> Dict:
values["snowflake_role"] = get_from_dict_or_env(
values, "snowflake_role", "SNOWFLAKE_ROLE"
)
if os.getenv("SNOWFLAKE_AUTHENTICATOR") == "OAUTH" or os.getenv("snowflake_authenticator"):
values["snowflake_authenticator"] = get_from_dict_or_env(
values, "snowflake_authenticator", "SNOWFLAKE_AUTHENTICATOR"
)
values["snowflake_token"] = get_from_dict_or_env(
values, "snowflake_token", "SNOWFLAKE_ACCESS_TOKEN"
)

authenticator = values.get("snowflake_authenticator")
password = values.get("snowflake_password")
token = values.get("snowflake_token")

connection_params = {
"account": values["snowflake_account"],
"user": values["snowflake_username"],
"password": values["snowflake_password"].get_secret_value(),
"database": values["snowflake_database"],
"schema": values["snowflake_schema"],
"warehouse": values["snowflake_warehouse"],
"role": values["snowflake_role"],
"client_session_keep_alive": "True",
}

if authenticator and str(authenticator).lower() == "oauth":
connection_params["authenticator"] = authenticator
connection_params["token"] = token
else:
connection_params["password"] = password.get_secret_value() if isinstance(password, SecretStr) else password

try:
values["session"] = Session.builder.configs(connection_params).create()
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
SNOWFLAKE_SCHEMA="YOUR_SNOWFLAKE_SCHEMA",
SNOWFLAKE_WAREHOUSE="YOUR_SNOWFLAKE_WAREHOUSE"
SNOWFLAKE_ROLE="YOUR_SNOWFLAKE_ROLE",
SNOWFLAKE_AUTHENTICATOR="OAUTH",
SNOWFLAKE_TOKEN="YOUR_SNOWFLAKE_TOKEN",
"""

import pytest
Expand Down