Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mallm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__='v1.0.4'
__version__ = 'v1.0.4'
3 changes: 2 additions & 1 deletion mallm/models/Chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def _call( # type: ignore
log_prob_sum = 0.0
for message in chat_completion:
message_str = message.choices[0].delta.content
log_prob_sum += message.choices[0].logprobs.content[0].logprob
if message.choices[0].logprobs:
log_prob_sum += message.choices[0].logprobs.content[0].logprob
if message_str and message_str not in self.stop_tokens:
collected_messages.append(message_str)
log_prob_sum = log_prob_sum / len(collected_messages)
Expand Down
11 changes: 7 additions & 4 deletions mallm/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,18 @@ def __init__(self, config: Config) -> None:
self.llm = Chat(
client=OpenAI(
base_url=self.config.endpoint_url, api_key=self.config.api_key
)
),
model=self.config.model_name
)

self.judge_llm = None
if self.config.judge_endpoint_url:
self.judge_llm = Chat(
client=OpenAI(
base_url=self.config.judge_endpoint_url, api_key=self.config.judge_api_key
)
client=OpenAI(
base_url=self.config.judge_endpoint_url,
api_key=self.config.judge_api_key,
),
model=self.config.judge_model_name,
)

if config.response_generator not in RESPONSE_GENERATORS:
Expand Down
10 changes: 1 addition & 9 deletions mallm/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Config:
judge_intervention: Optional[str] = None
judge_metric: Optional[str] = None
judge_endpoint_url: Optional[str] = None
judge_model_name: Optional[str] = None
judge_api_key: str = "-"
judge_always_intervene: bool = False

Expand Down Expand Up @@ -117,15 +118,6 @@ def check_config(self) -> None:
if self.endpoint_url.endswith("/"):
logger.warning("Removing trailing / from the endpoint url.")
self.endpoint_url = self.endpoint_url[:-1]
try:
logger.info("Testing availability of the endpoint...")
page = requests.head(self.endpoint_url.replace("/v1", ""))
logger.info("Status: " + str(page.status_code))
assert page.status_code == 200
except Exception as e:
logger.error("HTTP Error: Could not connect to the provided endpoint url.")
logger.error(e)
sys.exit(1)
if self.concurrent_api_requests > 250:
logger.warning(
"concurrent_api_requests is very large. Please make sure the API endpoint you are using can handle that many simultaneous requests."
Expand Down
4 changes: 3 additions & 1 deletion mallm/utils/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from mallm.discourse_policy.report import DiscourseReport
from mallm.models.discussion.CriticalResponseGenerator import CriticalResponseGenerator
from mallm.models.discussion.FreeTextResponseGenerator import FreeTextResponseGenerator
from mallm.models.discussion.ReasoningResponseGenerator import ReasoningResponseGenerator
from mallm.models.discussion.ReasoningResponseGenerator import (
ReasoningResponseGenerator,
)
from mallm.models.discussion.ResponseGenerator import ResponseGenerator
from mallm.models.discussion.SimpleResponseGenerator import SimpleResponseGenerator
from mallm.models.discussion.SplitFreeTextResponseGenerator import (
Expand Down