Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

feat: connect reward function #61

Merged
merged 2 commits into from
Apr 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions packages/miner-cloudflare/package.json

This file was deleted.

2 changes: 1 addition & 1 deletion packages/service-discovery/api/miner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ export async function GET(request: Request) {
}

const apiOnly = params.has("api-only")
? Boolean(params.get("api-only"))
? JSON.parse(params.get("api-only")!)
: false;

const model = params.get("model");
Expand Down
8 changes: 0 additions & 8 deletions packages/validator/package.json

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import json
import time
import argparse
import os
import aiohttp
import bittensor as bt
from flask.cli import load_dotenv
from protocol import StreamPrompting
import requests
from fastapi import FastAPI
from pydantic import BaseModel

from stream_miner import StreamMiner
from flask import Flask, current_app, jsonify, request, make_response

load_dotenv()


class Miner(StreamMiner):
def config(self) -> "bt.Config":

parser = argparse.ArgumentParser(description="Streaming Miner Configs")
self.add_args(parser)
return bt.config(parser)
Expand All @@ -34,23 +35,31 @@ async def prompt(self, messages, model) -> StreamPrompting:
return json_resp['result']['response']


app = Flask(__name__)
app.miner = Miner()
app = FastAPI()
miner = Miner()


class ChatRequest(BaseModel):
messages: list
model: str


@app.route("/", methods=['POST'])
async def chat():
data = request.get_json()
miner = current_app.miner
messages = data['messages']
model = data['model']
@app.get("/")
def index():
return "ok"

Comment on lines +47 to +49
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adds healthcheck capacities


@app.post("/chat")
async def chat(request: ChatRequest):
messages = request.messages
model = request.model

response = await miner.prompt(messages=messages, model=model)
messages.append({"role": "system", "content": response})
return jsonify(messages)

return messages

# The main function parses the configuration and runs the validator.
if __name__ == "__main__":

app.run(host='0.0.0.0', port=9000)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=os.getenv('PORT', 9000))
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@ flask
pydantic
python-dotenv
simplejson
requests-async
requests-async
gunicorn
aiohttp
fastapi
File renamed without changes.
40 changes: 24 additions & 16 deletions packages/validator/main.py → subnet/validator/main.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import os
import bittensor as bt

import torch
from validator import BaseValidatorNeuron
from flask import Flask, current_app, jsonify, request, make_response
from fastapi import FastAPI, Request
import aiohttp
import random

import bittensor as bt
from reward import get_reward
from uids import get_random_uids
from requests_async import get
from reward import calculate_total_message_length, get_reward
from typing import TypedDict, Union, List
from urllib.parse import urljoin, urlencode
from dotenv import load_dotenv


class Miner(TypedDict):
Expand All @@ -25,8 +22,8 @@ class Miner(TypedDict):
class Validator(BaseValidatorNeuron):
def __init__(self, config=None):
super(Validator, self).__init__(config=config)

bt.logging.info("load_state()")
load_dotenv()
self.load_state()

async def get_miner_with_model(self, model_name) -> Union[Miner, dict]:
Expand All @@ -41,7 +38,7 @@ async def get_miner_with_model(self, model_name) -> Union[Miner, dict]:
If the response data is not a list, it returns the data as is.
"""

api_only = self.subtensor_connected
api_only = self.subtensor_connected == False
service_map_url = os.getenv('SERVICE_MESH_URL')
secret = os.getenv('SECRET_KEY')
# for now miners are allow listed manually and given a secret key to identify
Expand All @@ -64,25 +61,36 @@ async def get_miner_with_model(self, model_name) -> Union[Miner, dict]:
return data


app = Flask(__name__)
app.validator = Validator()
app = FastAPI()
validator = Validator()


@app.post("/chat")
async def chat():
data = request.get_json()
validator = current_app.validator
async def chat(request: Request):
data = await request.json()

model = data['model']
miner = await validator.get_miner_with_model(model_name=model)
miner_uid = miner['netuid']
prompt_len = calculate_total_message_length(data)

async with aiohttp.ClientSession() as session:
url = miner['address']
async with session.post(url, json=data) as resp:
response = await resp.json()
return jsonify(response)
completion_len = len(response[-1])

reward = get_reward(
model=model, completion_len=completion_len, prompt_len=prompt_len)
print(f'reward for prompt: {reward}')
if (validator.subtensor_connected):
validator.update_scores(
torch.FloatTensor([reward]), [int(miner_uid)])

return response


# The main function parses the configuration and runs the validator.
if __name__ == "__main__":
app.run(host='0.0.0.0', port=8000)
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ python-dotenv
simplejson
starlette
uvicorn
requests-async
requests-async
19 changes: 7 additions & 12 deletions packages/validator/reward.py → subnet/validator/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@
}


def calculate_total_message_length(data):
total_length = 0
for message in data["messages"]:
total_length += len(message["content"])
return total_length


def get_reward(model, completion_len, prompt_len) -> float:
print(model, completion_len, prompt_len)
# Define the maximum and minimum completion lengths in characters
Expand Down Expand Up @@ -48,18 +55,6 @@ def get_reward(model, completion_len, prompt_len) -> float:
return reward


def reward(query: int, response: int) -> float:
"""
Reward the miner response to the dummy request. This method returns a reward
value for the miner, which is used to update the miner's score.

Returns:
- float: The reward value for the miner.
"""

return 1.0 if response == query * 2 else 0


def get_rewards(
self,
query: int,
Expand Down
File renamed without changes.
File renamed without changes.
Loading