Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

feat: connect reward function #61

Merged
merged 2 commits into from
Apr 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions packages/miner-cloudflare/package.json

This file was deleted.

3 changes: 2 additions & 1 deletion packages/service-discovery/api/miner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ export async function POST(request: Request) {
}

export async function GET(request: Request) {
console.log("getting answer");
const url = new URL(request.url);
const params = new URLSearchParams(url.search);

Expand All @@ -89,7 +90,7 @@ export async function GET(request: Request) {
}

const apiOnly = params.has("api-only")
? Boolean(params.get("api-only"))
? JSON.parse(params.get("api-only")!)
: false;

const model = params.get("model");
Expand Down
8 changes: 0 additions & 8 deletions packages/validator/package.json

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import json
import time
import argparse
import os
import aiohttp
import bittensor as bt
from flask.cli import load_dotenv
from protocol import StreamPrompting
import requests
from fastapi import FastAPI
from pydantic import BaseModel

from stream_miner import StreamMiner
from flask import Flask, current_app, jsonify, request, make_response

load_dotenv()


class Miner(StreamMiner):
def config(self) -> "bt.Config":

parser = argparse.ArgumentParser(description="Streaming Miner Configs")
self.add_args(parser)
return bt.config(parser)
Expand All @@ -34,23 +35,31 @@ async def prompt(self, messages, model) -> StreamPrompting:
return json_resp['result']['response']


app = Flask(__name__)
app.miner = Miner()
app = FastAPI()
miner = Miner()


class ChatRequest(BaseModel):
messages: list
model: str


@app.route("/", methods=['POST'])
async def chat():
data = request.get_json()
miner = current_app.miner
messages = data['messages']
model = data['model']
@app.get("/")
def index():
return "ok"

Comment on lines +47 to +49
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adds healthcheck capacities


@app.post("/chat")
async def chat(request: ChatRequest):
messages = request.messages
model = request.model

response = await miner.prompt(messages=messages, model=model)
messages.append({"role": "system", "content": response})
return jsonify(messages)

return messages

# The main function parses the configuration and runs the validator.
if __name__ == "__main__":

app.run(host='0.0.0.0', port=9000)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=os.getenv('PORT', 9000))
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@ flask
pydantic
python-dotenv
simplejson
requests-async
requests-async
gunicorn
aiohttp
fastapi
File renamed without changes.
40 changes: 24 additions & 16 deletions packages/validator/main.py → subnet/validator/main.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import os
import bittensor as bt

import torch
from validator import BaseValidatorNeuron
from flask import Flask, current_app, jsonify, request, make_response
from fastapi import FastAPI, Request
import aiohttp
import random

import bittensor as bt
from reward import get_reward
from uids import get_random_uids
from requests_async import get
from reward import calculate_total_message_length, get_reward
from typing import TypedDict, Union, List
from urllib.parse import urljoin, urlencode
from dotenv import load_dotenv


class Miner(TypedDict):
Expand All @@ -25,8 +22,8 @@ class Miner(TypedDict):
class Validator(BaseValidatorNeuron):
def __init__(self, config=None):
super(Validator, self).__init__(config=config)

bt.logging.info("load_state()")
load_dotenv()
self.load_state()

async def get_miner_with_model(self, model_name) -> Union[Miner, dict]:
Expand All @@ -41,7 +38,7 @@ async def get_miner_with_model(self, model_name) -> Union[Miner, dict]:
If the response data is not a list, it returns the data as is.
"""

api_only = self.subtensor_connected
api_only = self.subtensor_connected == False
service_map_url = os.getenv('SERVICE_MESH_URL')
secret = os.getenv('SECRET_KEY')
# for now miners are allow listed manually and given a secret key to identify
Expand All @@ -64,25 +61,36 @@ async def get_miner_with_model(self, model_name) -> Union[Miner, dict]:
return data


app = Flask(__name__)
app.validator = Validator()
app = FastAPI()
validator = Validator()


@app.post("/chat")
async def chat():
data = request.get_json()
validator = current_app.validator
async def chat(request: Request):
data = await request.json()

model = data['model']
miner = await validator.get_miner_with_model(model_name=model)
miner_uid = miner['netuid']
prompt_len = calculate_total_message_length(data)

async with aiohttp.ClientSession() as session:
url = miner['address']
async with session.post(url, json=data) as resp:
response = await resp.json()
return jsonify(response)
completion_len = len(response[-1])

reward = get_reward(
model=model, completion_len=completion_len, prompt_len=prompt_len)
print(f'reward for prompt: {reward}')
if (validator.subtensor_connected):
validator.update_scores(
torch.FloatTensor([reward]), [int(miner_uid)])

return response


# The main function parses the configuration and runs the validator.
if __name__ == "__main__":
app.run(host='0.0.0.0', port=8000)
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ python-dotenv
simplejson
starlette
uvicorn
requests-async
requests-async
19 changes: 7 additions & 12 deletions packages/validator/reward.py → subnet/validator/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@
}


def calculate_total_message_length(data):
total_length = 0
for message in data["messages"]:
total_length += len(message["content"])
return total_length


def get_reward(model, completion_len, prompt_len) -> float:
print(model, completion_len, prompt_len)
# Define the maximum and minimum completion lengths in characters
Expand Down Expand Up @@ -48,18 +55,6 @@ def get_reward(model, completion_len, prompt_len) -> float:
return reward


def reward(query: int, response: int) -> float:
"""
Reward the miner response to the dummy request. This method returns a reward
value for the miner, which is used to update the miner's score.

Returns:
- float: The reward value for the miner.
"""

return 1.0 if response == query * 2 else 0


def get_rewards(
self,
query: int,
Expand Down
File renamed without changes.
File renamed without changes.
Loading