diff --git a/eval/eval_widgets.txt b/eval/eval_widgets.txt index cccf34c..f6a545e 100644 --- a/eval/eval_widgets.txt +++ b/eval/eval_widgets.txt @@ -34,3 +34,4 @@ display_uniswap fetch_nfts_owned_by_address_or_domain fetch_nfts_owned_by_user fetch_link_suggestion +generate_js_code \ No newline at end of file diff --git a/knowledge_base/widgets.yaml b/knowledge_base/widgets.yaml index acb20f3..4daa64d 100644 --- a/knowledge_base/widgets.yaml +++ b/knowledge_base/widgets.yaml @@ -1,22 +1,22 @@ -- _name_: display_transfer - description: Transfer a token from a user's wallet to another address - parameters: - properties: - address: - description: Transfer recipient address. - type: string - amount: - description: Quantity to transfer. - type: string - token: - description: Symbol of the token being transferred. - type: string - required: - - token - - amount - - address - type: object - return_value_description: '' +# - _name_: display_transfer +# description: Transfer a token from a user's wallet to another address +# parameters: +# properties: +# address: +# description: Transfer recipient address. +# type: string +# amount: +# description: Quantity to transfer. +# type: string +# token: +# description: Symbol of the token being transferred. +# type: string +# required: +# - token +# - amount +# - address +# type: object +# return_value_description: '' - _name_: fetch_nft_buy_asset description: Buy an NFT asset of a collection on the OpenSea marketplace, given its network, address, and token ID. Don't use this if we don't have the collection @@ -854,4 +854,15 @@ - amount - vault type: object - return_value_description: '' \ No newline at end of file + return_value_description: '' +- _name_: generate_js_code + description: generate code for the user based on the query; this widget should always be chosen for anything relating to code within a query + parameters: + properties: + query: + description: a standalone question representing the user's intent to perform an action via code + type: string + required: + - query + type: object + return_value_description: formatted javascript code diff --git a/tools/__init__.py b/tools/__init__.py index 426db12..94b9191 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -6,4 +6,5 @@ from . import index_app_info from . import index_api_tool from . import app_usage_guide -from . import index_link_suggestion \ No newline at end of file +from . import index_link_suggestion +from . import generate_js_code \ No newline at end of file diff --git a/tools/generate_js_code.py b/tools/generate_js_code.py new file mode 100644 index 0000000..a5ce511 --- /dev/null +++ b/tools/generate_js_code.py @@ -0,0 +1,71 @@ +from langchain.prompts import PromptTemplate +from langchain.chains import LLMChain +from pydantic import Extra + +import registry +import streaming +from .base import BaseTool, BASE_TOOL_DESCRIPTION_TEMPLATE + + +TEMPLATE = '''You are an expert Web3 developer well versed in using JS to interact with the ecosystem, you will help the user perform actions based on their request by generating functional JS code + +# INSTRUCTIONS +- Assume user wallet already connected to browser so never ask for a private key, Infura project ID, or any credentials +- Print out transaction hash if applicable +- Always use ethers.js +- Assume there is an ethers.js provider and signer available and can be provided to the function or code +- The code should return a function or a promise that can be called to perform the action +- Your final output should be a formatted JS code function with comments, which can be run on a frontend; don't include anything else for now (ie: messages that precede the code, etc.) + +--- +User: {question} +Assistant:''' + + +@registry.register_class +class GenerateJSCodeTool(BaseTool): + """Tool for generating code to perform a user request.""" + + _chain: LLMChain + + class Config: + """Configuration for this pydantic object.""" + extra = Extra.allow + + def __init__( + self, + *args, + **kwargs + ) -> None: + prompt = PromptTemplate( + input_variables=["question"], + template=TEMPLATE, + ) + new_token_handler = kwargs.get('new_token_handler') + chain = streaming.get_streaming_chain(prompt, new_token_handler) + description = BASE_TOOL_DESCRIPTION_TEMPLATE.format( + tool_description="generate code based on the user query", + input_description="a standalone query where user wants to generate code to perform an action", + output_description="a JSON object with a code field, which contains a formatted JS code function with comments, which can be run on a frontend; don't include anything else for now (ie: messages that precede the code, etc.)" + ) + + super().__init__( + *args, + _chain=chain, + description=description, + **kwargs + ) + + def _run(self, query: str) -> str: + example = { + "question": query, + "stop": "User", + } + result = self._chain.run(example) + print('result in generate_js_code', result) + + return result + + async def _arun(self, query: str) -> str: + raise NotImplementedError( + f"{self.__class__.__name__} does not support async") diff --git a/tools/index_widget.py b/tools/index_widget.py index 6274088..5881cfb 100644 --- a/tools/index_widget.py +++ b/tools/index_widget.py @@ -92,7 +92,8 @@ def injection_handler(token): # we have found the response_prefix, trim everything before that timing.log('first_widget_response_token') response_state = 1 - response_buffer = response_buffer[response_buffer.index(response_prefix) + len(response_prefix):] + response_buffer = response_buffer[response_buffer.index( + response_prefix) + len(response_prefix):] if response_state == 1: # we are going to output the response incrementally, evaluating any fetch commands while WIDGET_START in response_buffer and self._evaluate_widgets: @@ -102,14 +103,17 @@ def injection_handler(token): if isinstance(response_buffer, Callable): # handle delegated streaming def handler(token): nonlocal new_token_handler - timing.log('first_visible_widget_response_token') + timing.log( + 'first_visible_widget_response_token') return new_token_handler(token) response_buffer(handler) response_buffer = "" return - elif isinstance(response_buffer, Generator): # handle stream of widgets + # handle stream of widgets + elif isinstance(response_buffer, Generator): for item in response_buffer: - timing.log('first_visible_widget_response_token') + timing.log( + 'first_visible_widget_response_token') new_token_handler(str(item) + "\n") response_buffer = "" return @@ -118,7 +122,8 @@ def handler(token): # NB: for better frontend parsing of nested widgets, we need an invariant that # there are no two independent widgets on the same line, otherwise we can't # detect the closing tag properly when there is nesting. - response_buffer = response_buffer.replace(WIDGET_END, WIDGET_END + '\n') + response_buffer = response_buffer.replace( + WIDGET_END, WIDGET_END + '\n') break else: # keep waiting @@ -139,7 +144,8 @@ def handler(token): # we have found a line-break in the response, switch to the terminal state to mask subsequent output response_state = 2 - chain = streaming.get_streaming_chain(prompt, injection_handler, model_name=model_name) + chain = streaming.get_streaming_chain( + prompt, injection_handler, model_name=model_name) super().__init__( *args, _chain=chain, @@ -165,7 +171,7 @@ def _run(self, query: str) -> str: def iterative_evaluate(phrase: str) -> Union[str, Generator, Callable]: while True: # before we had streaming, we could use this - #eval_phrase = RE_COMMAND.sub(replace_match, phrase) + # eval_phrase = RE_COMMAND.sub(replace_match, phrase) # now, iterate manually to find any streamable components eval_phrase = "" last_matched_char = 0 @@ -249,7 +255,9 @@ def replace_match(m: re.Match) -> Union[str, Generator, Callable]: # elif command == 'fetch-scraped-sites': # return fetch_scraped_sites(*params) elif command == 'fetch-link-suggestion': - return fetch_link_suggestion(*params) + return fetch_link_suggestion(*params) + elif command == 'generate-js-code': + return generate_js_code(*params) elif command == aave.AaveSupplyContractWorkflow.WORKFLOW_TYPE: return str(exec_aave_operation(*params, operation='supply')) elif command == aave.AaveBorrowContractWorkflow.WORKFLOW_TYPE: @@ -277,6 +285,7 @@ def replace_match(m: re.Match) -> Union[str, Generator, Callable]: # assert 0, 'unrecognized command: %s' % m.group(0) return m.group(0) + @error_wrap def fetch_price(basetoken: str, quotetoken: str = "usd") -> str: # TODO @@ -302,7 +311,8 @@ def fetch_price(basetoken: str, quotetoken: str = "usd") -> str: else: return f"Quote currency {quotetoken} not supported" - coingecko_api_url = coingecko_api_url_prefix + f"?ids={basetoken_id}&vs_currencies={quotetoken_id}" + coingecko_api_url = coingecko_api_url_prefix + \ + f"?ids={basetoken_id}&vs_currencies={quotetoken_id}" response = requests.get(coingecko_api_url) response.raise_for_status() return f"The price of {basetoken_name} is {list(list(response.json().values())[0].values())[0]} {quotetoken}" @@ -311,7 +321,8 @@ def fetch_price(basetoken: str, quotetoken: str = "usd") -> str: @error_wrap def fetch_balance(token: str, wallet_address: str) -> str: if not wallet_address or wallet_address == 'None': - raise FetchError(f"Please specify the wallet address to check the token balance of.") + raise FetchError( + f"Please specify the wallet address to check the token balance of.") web3 = context.get_web3_provider() chain_id = context.get_wallet_chain_id() balance = get_token_balance(web3, chain_id, token, wallet_address) @@ -365,7 +376,7 @@ def fn(token_handler): # _streaming=True, # name="ScrapedSitesIndexAnswer", # content_description="", # not used -# index=scraped_sites_index, +# index=scraped_sites_index, # top_k=3, # source_key="url", # ) @@ -373,6 +384,7 @@ def fn(token_handler): # tool._run(query) # return fn + @error_wrap def fetch_link_suggestion(query: str) -> Callable: def fn(token_handler): @@ -382,7 +394,7 @@ def fn(token_handler): _streaming=True, name="LinkSuggestionIndexAnswer", content_description="", # not used - index=dapps_index, + index=dapps_index, top_k=3, source_key="url", ) @@ -391,6 +403,20 @@ def fn(token_handler): return fn +@error_wrap +def generate_js_code(query: str) -> Callable: + tool = dict( + name="GenerateJSCodeTool", + type="tools.generate_js_code.GenerateJSCodeTool", + _streaming=True, + ) + tool = streaming.get_streaming_tools([tool], None)[0] + print('running tool') + code =tool._run(query) + print('code', code) + return str(CodeContainer(code=code)) + + class ListContainer(ContainerMixin, list): def message_prefix(self) -> str: num = len(self) @@ -461,6 +487,27 @@ def container_params(self) -> Dict: ) +@dataclass +class CodeContainer(ContainerMixin): + code: str + message_prefix_str: str = "" + message_suffix_str: str = "" + + def message_prefix(self) -> str: + return self.message_prefix_str + + def message_suffix(self) -> str: + return self.message_suffix_str + + def container_name(self) -> str: + return 'display-code-container' + + def container_params(self) -> Dict: + return dict( + code=self.code, + ) + + @error_wrap def fetch_nft_search(search_str: str) -> Generator: yield StreamingListContainer(operation="create", prefix="Searching", is_thinking=True) @@ -551,10 +598,12 @@ def fetch_nfts_owned_by_user(network: str = None) -> str: parsed_network = network return str(center.fetch_nfts_owned_by_address_or_domain(parsed_network, wallet_address)) + @error_wrap def fetch_nft_buy(network: str, address: str, token_id: str) -> str: wallet_address = context.get_wallet_address() - nft_fulfillment_container = center.fetch_nft_buy(network, wallet_address, address, token_id) + nft_fulfillment_container = center.fetch_nft_buy( + network, wallet_address, address, token_id) return str(nft_fulfillment_container) @@ -570,10 +619,12 @@ def fetch_yields(token, network, count) -> str: ] if network == '*': - headers = [TableHeader(field_name="network", display_name="Network")] + headers + headers = [TableHeader(field_name="network", + display_name="Network")] + headers if token == '*': - headers = [TableHeader(field_name="token", display_name="Token")] + headers + headers = [TableHeader( + field_name="token", display_name="Token")] + headers table_container = TableContainer(headers=headers, rows=yields) return str(table_container) @@ -633,10 +684,11 @@ def container_name(self) -> str: def container_params(self) -> Dict: return dataclass_to_container_params(self) + @error_wrap @ensure_wallet_connected -def set_ens_text(domain: str, key: str, value: str) ->TxPayloadForSending: - wallet_chain_id = 1 # TODO: get from context +def set_ens_text(domain: str, key: str, value: str) -> TxPayloadForSending: + wallet_chain_id = 1 # TODO: get from context wallet_address = context.get_wallet_address() user_chat_message_id = context.get_user_chat_message_id() @@ -646,13 +698,15 @@ def set_ens_text(domain: str, key: str, value: str) ->TxPayloadForSending: 'value': value, } - result = ens.ENSSetTextWorkflow(wallet_chain_id, wallet_address, user_chat_message_id, params).run() + result = ens.ENSSetTextWorkflow( + wallet_chain_id, wallet_address, user_chat_message_id, params).run() return TxPayloadForSending.from_workflow_result(result) + @error_wrap @ensure_wallet_connected -def set_ens_primary_name(domain: str) ->TxPayloadForSending: - wallet_chain_id = 1 # TODO: get from context +def set_ens_primary_name(domain: str) -> TxPayloadForSending: + wallet_chain_id = 1 # TODO: get from context wallet_address = context.get_wallet_address() user_chat_message_id = context.get_user_chat_message_id() @@ -660,13 +714,15 @@ def set_ens_primary_name(domain: str) ->TxPayloadForSending: 'domain': domain, } - result = ens.ENSSetPrimaryNameWorkflow(wallet_chain_id, wallet_address, user_chat_message_id, params).run() + result = ens.ENSSetPrimaryNameWorkflow( + wallet_chain_id, wallet_address, user_chat_message_id, params).run() return TxPayloadForSending.from_workflow_result(result) + @error_wrap @ensure_wallet_connected -def set_ens_avatar_nft(domain: str, nftContractAddress: str, nftId: str, collectionName: str) ->TxPayloadForSending: - wallet_chain_id = 1 # TODO: get from context +def set_ens_avatar_nft(domain: str, nftContractAddress: str, nftId: str, collectionName: str) -> TxPayloadForSending: + wallet_chain_id = 1 # TODO: get from context wallet_address = context.get_wallet_address() user_chat_message_id = context.get_user_chat_message_id() @@ -677,5 +733,6 @@ def set_ens_avatar_nft(domain: str, nftContractAddress: str, nftId: str, collect 'collectionName': collectionName } - result = ens.ENSSetAvatarNFTWorkflow(wallet_chain_id, wallet_address, user_chat_message_id, params).run() + result = ens.ENSSetAvatarNFTWorkflow( + wallet_chain_id, wallet_address, user_chat_message_id, params).run() return TxPayloadForSending.from_workflow_result(result) diff --git a/utils/constants.py b/utils/constants.py index 3c9fa71..dc37c82 100644 --- a/utils/constants.py +++ b/utils/constants.py @@ -53,7 +53,7 @@ WIDGET_INFO_TOKEN_LIMIT = 4000 # Widget Index -WIDGET_INDEX_NAME = "WidgetV25" +WIDGET_INDEX_NAME = "WidgetV26" def get_widget_index_name(): if env.is_local():