1212##
1313
1414def prepare_url (prov , url = None , host = None , port = None ):
15- if host is None :
16- host = 'localhost'
17- if port is None :
18- port = 8000
19- if url is None :
20- url = prov ['url' ].format (host = host , port = port )
21- return url
15+ host = prov .get ('host' ) if host is None else host
16+ port = prov .get ('port' ) if port is None else port
17+ url = prov .get ('url' ) if url is None else url
18+ return url .format (host = host , port = port )
2219
2320def prepare_auth (prov , api_key = None ):
2421 if (auth_func := prov .get ('authorize' )) is not None :
25- if api_key is None and (api_key := os .environ .get (key_env := prov ['api_key_env' ])) is None :
26- raise Exception ('Cannot find API key in {key_env }' )
22+ if (api_key := os .environ .get (prov ['api_key_env' ])) is None :
23+ raise Exception ('Cannot find API key in {api_key_env }' )
2724 headers_auth = auth_func (api_key )
2825 else :
2926 headers_auth = {}
@@ -36,7 +33,7 @@ def prepare_model(prov, model=None):
3633
3734def prepare_request (
3835 query , provider = 'local' , system = None , prefill = None , prediction = None , history = None ,
39- url = None , port = None , api_key = None , model = None , max_tokens = DEFAULT_MAX_TOKENS , ** kwargs
36+ url = None , host = None , port = None , api_key = None , model = None , max_tokens = DEFAULT_MAX_TOKENS , ** kwargs
4037):
4138 # external provider
4239 prov = get_provider (provider )
@@ -45,7 +42,7 @@ def prepare_request(
4542 max_tokens_name = prov .get ('max_tokens_name' , 'max_tokens' )
4643
4744 # get full url
48- url = prepare_url (prov , url = url , port = port )
45+ url = prepare_url (prov , url = url , host = host , port = port )
4946
5047 # get authorization headers
5148 headers_auth = prepare_auth (prov , api_key = api_key )
@@ -127,9 +124,9 @@ async def reply_async(query, provider='local', history=None, prefill=None, **kwa
127124## stream requests
128125##
129126
130- def parse_stream_data (chunk ):
131- if chunk .startswith (b'data: ' ):
132- text = chunk [6 :]
127+ def parse_stream_data (line ):
128+ if line .startswith (b'data: ' ):
129+ text = line [6 :]
133130 if text != b'[DONE]' and len (text ) > 0 :
134131 return text
135132
@@ -172,7 +169,9 @@ def stream(query, provider='local', history=None, prefill=None, **kwargs):
172169 for line in response .iter_lines ():
173170 if (data := parse_stream_data (line )) is not None :
174171 parsed = json .loads (data )
175- yield extractor (parsed )
172+ text = extractor (parsed )
173+ if text is not None :
174+ yield text
176175
177176async def stream_async (query , provider = 'local' , history = None , prefill = None , ** kwargs ):
178177 # get provider
@@ -206,7 +205,9 @@ async def stream_async(query, provider='local', history=None, prefill=None, **kw
206205 async for line in lines :
207206 if (data := parse_stream_data (line )) is not None :
208207 parsed = json .loads (data )
209- yield extractor (parsed )
208+ text = extractor (parsed )
209+ if text is not None :
210+ yield text
210211
211212##
212213## embeddings
0 commit comments