refactor: checkpoints #15
Conversation
src/exchange/exchange.py
```diff
@@ -38,10 +39,10 @@ class Exchange:
     provider: Provider
     model: str
     system: str
-    moderator: Moderator = field(default=ContextTruncate())
+    moderator: Moderator = field(default=PassiveModerator())
```
any reason for changing this to Passive? Truncate is probably a better user experience
Ah, didn't catch this in my proofread - meant to change it back. Though, thoughts on making it `ContextSummarize()` since that is better at maintaining context?
Yea eventually. I think we still need feedback on it to see how well it's working for people
Ok, just set it back to `ContextTruncate` and we can revisit later.
```diff
@@ -51,8 +52,10 @@ def replace(self, **kwargs: Dict[str, Any]) -> "Exchange":
     """Make a copy of the exchange, replacing any passed arguments"""
     if kwargs.get("messages") is None:
         kwargs["messages"] = deepcopy(self.messages)
-    if kwargs.get("checkpoints") is None:
-        kwargs["checkpoints"] = deepcopy(self.checkpoints)
+    if kwargs.get("checkpoint_data") is None:
```
This breaks backwards compatibility, specifically breaking some uses in Goose: https://github.com/search?q=repo%3Asquare%2Fgoose%20checkpoints%3D&type=code. We'll need to update those references and anywhere else they might be (custom plugin repos).
Good catch. Will update once this PR gets merged in. We'll need a major version bump to indicate the breaking changes, even though they are minor breaking changes.
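A minimal sketch of the call-site migration this rename implies: downstream code that passed `checkpoints=` to `.replace()` must switch to `checkpoint_data=`. The `CheckpointData` name is from this PR, but the standalone `replace` function below is purely illustrative, not the repo's implementation.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class CheckpointData:
    # stand-in for the real class introduced by this PR
    checkpoints: List[int] = field(default_factory=list)


def replace(**kwargs):
    # new-style API accepts checkpoint_data; the old checkpoints kwarg
    # is rejected to surface the breaking change loudly
    if "checkpoints" in kwargs:
        raise TypeError("'checkpoints' was renamed to 'checkpoint_data'")
    return kwargs.get("checkpoint_data", CheckpointData())


# old call: replace(checkpoints=[])                    -> raises TypeError
# new call: replace(checkpoint_data=CheckpointData())  -> works
```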
```python
first_checkpoint_start_index = first_checkpoint.start_index - self.checkpoint_data.message_index_offset

# check if the first message is part of the first checkpoint
if first_checkpoint_start_index == 0:
```
Is it so that if this is not the case, we've goofed somewhere and are not in a good state? Is there a way we could get here (I don't think so), but maybe we add a check for that to catch it
We can totally reach both `first_checkpoint_start_index == 0` and `first_checkpoint_start_index != 0`. Here's an example:

note: `[]` denotes the list of messages, `()` denotes how checkpoints are grouped

if we start off with `[(1,2,3),(4),(5)]`:

`pop_first_message()`
> `first_checkpoint_start_index == 0`
> and the resulting array: `[2,3,(4),(5)]`

`pop_first_message()` again
> `first_checkpoint_start_index == 2`
> and the resulting array: `[3,(4),(5)]`
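The walkthrough above can be sketched as code. This is a hypothetical, self-contained model of the offset bookkeeping; the names (`Checkpoint`, `start_index`, `message_index_offset`) follow the PR, but this `MiniExchange` class is an illustration, not the library implementation.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Checkpoint:
    start_index: int  # absolute index of the first message in the span
    end_index: int    # absolute index of the last message in the span


@dataclass
class MiniExchange:
    messages: List[str]
    checkpoints: List[Checkpoint]
    message_index_offset: int = 0

    def pop_first_message(self) -> str:
        first = self.checkpoints[0]
        # map the stored absolute index back into the current list
        relative_start = first.start_index - self.message_index_offset
        if relative_start == 0:
            # the first message is part of the first checkpoint, so that
            # checkpoint no longer covers a complete span: drop it
            self.checkpoints.pop(0)
        self.message_index_offset += 1
        return self.messages.pop(0)


# [(1,2,3),(4),(5)] from the example above
ex = MiniExchange(
    messages=["1", "2", "3", "4", "5"],
    checkpoints=[Checkpoint(0, 2), Checkpoint(3, 3), Checkpoint(4, 4)],
)
ex.pop_first_message()  # start_index 0 - offset 0 == 0: checkpoint (1,2,3) dropped
ex.pop_first_message()  # first checkpoint now maps to 3 - 1 == 2, so it stays
```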
```python
self.checkpoint_data.message_index_offset = new_index

@property
def is_allowed_to_call_llm(self) -> bool:
```
This is kind of LLM dependent? You can call gpt4o without any messages for example
Hmmm, maybe this belongs in the Provider class 🤔. I've removed its usages now and marked it with a TODO so we can revisit it at a later point.
```python
# currently this is the point at which we start to summarize, so
# once we get to this token size the token count will exceed this
# by a little bit
MAX_TOKENS = 70000
```
Same comment as for Truncate. We should start summarizing at a higher token count, but also having a knob to summarize a fraction (or offset) of the messages
should it also be per provider? claude has a much higher context window for example
yea, definitely per provider. We need a way to set these in the profiles (on a backlog) so users can set these values directly instead of using the defaults.
Agreed, we need to add two new things for providers:

- `max_tokens` (variable)
- `is_allowed_to_call_llm` (method)

I've annotated both of these areas in the code with TODOs.
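The two provider-level additions proposed above could look something like this sketch. The names `max_tokens` and `is_allowed_to_call_llm` come from the discussion; the `ProviderLimits` classes and the specific context-window numbers are illustrative assumptions, not the repo's actual API.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ProviderLimits:
    # per-provider context window (illustrative default)
    max_tokens: int = 128_000

    def is_allowed_to_call_llm(self, messages: List[str]) -> bool:
        # default: require at least one message; providers whose models
        # accept an empty message list could override this
        return len(messages) > 0


@dataclass
class AnthropicLimits(ProviderLimits):
    # Claude offers a much larger context window (assumed figure)
    max_tokens: int = 200_000
```

With limits living on the provider, moderators could read `provider.limits.max_tokens` instead of hard-coding `MAX_TOKENS = 70000`.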
```python
def pop_checkpoint(
```
Nice that this gets cleaned up!
```python
model=self.model,
messages=messages_to_summarize,
# checkpoint_data=CheckpointData(),
# TODO: figure out why the summarizer exchange has checkpoint data that is not empty
```
What is it set to? Is this because `messages_to_summarize` is a subset of the messages? Haven't looked deeply at this
The idea here was to create a brand new instance of `CheckpointData`. This is helpful because when we then call `.reply()` to get the summary, we have two checkpoints: (0..n-1) & (n), where n is the summary message. Just implemented the TODO captured in this code block in the most recent change to this PR.
makes sense. Was this pushed up? I don't see the changes described.
Yup, changes are pushed up. You should see it's changed from

```python
...
model=self.model,
messages=messages_to_summarize,
# checkpoint_data=CheckpointData(),
# TODO: figure out why the summarizer exchange has checkpoint data that is not empty
```

to

```python
...
model=self.model,
messages=messages_to_summarize,
checkpoint_data=CheckpointData(),
```
src/exchange/moderators/truncate.py
```python
# currently this is the point at which we start to truncate, so
# once we get to this token size the token count will exceed this
# by a little bit
MAX_TOKENS = 70000
```
This should be larger. 128K is the standard context length, and we have a max limit on tool usage of 16000 (`max_output_tokens = 16000`, line 20 of `exchange/src/exchange/exchange.py` at ed3a6ab).
Awesome, thanks for explaining that! I'll implement this and drop some comments in the code for posterity.
I am thinking of adding some extra padding so that if someone misconfigures their accelerator model, we can handle that gracefully. Will be bringing down the max tokens from 112000 to 100000.
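The budget arithmetic from the two comments above, as a sketch: the 128K context figure and the 16000 output-token limit come from the discussion, and the extra padding down to 100000 is the graceful-degradation margin the author describes.

```python
CONTEXT_WINDOW = 128_000    # standard context length cited above
MAX_OUTPUT_TOKENS = 16_000  # max_output_tokens limit in exchange.py

# tokens left for input messages after reserving the output budget
available_for_input = CONTEXT_WINDOW - MAX_OUTPUT_TOKENS  # 112000

# padded down further to absorb a misconfigured accelerator model
MAX_TOKENS = 100_000
```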
```diff
 # calculate the system prompt tokens (includes functions etc...)
 _system_token_exchange = exchange.replace(
     messages=[],
-    checkpoints=[],
+    checkpoint_data=CheckpointData(),
     moderator=PassiveModerator(),
     model=self.model if self.model else exchange.model,
```
This should prob be the same model as the initial exchange so that token content is consistent with what we're expecting. Between models on openai this is ok, but if we use claude while the original is gpt-4, there will be discrepancies (don't expect them to be major though)
Good call, I just added a check for this on the `.replace()` method - any calls to `.replace()` that specify `model` in the kwargs should also specify `checkpoint_data`.

I'm not convinced that we should override to using the exchange's model here, as it defeats the purpose of specifying an accelerator model in the profile being used. Instead, we expect folks will configure profiles in a meaningful way, and our scripting should add enough token count padding to account for some user errors.
I just meant for this specific case to calculate the system message token count, but I guess this affects everything. +1 on users using self-consistent profiles
> Good call, I just added a check for this on the `.replace()` method - any calls to `.replace()` that specify `model` in the kwargs should also specify `checkpoint_data`.

I actually ended up removing this, as it causes too many new errors in Goose. This is indicative of us already falling into patterns of mutating the Exchange in potentially unsafe ways.
drive-by: I might add to the integration test an edge case, so that we don't lose the system prompt or other things that could interfere with tool calls. This is in addition to the unit tests you have here, as it helps us know end-to-end works.
hey @codefromthecrypt - great callout! we are actually planning on revamping a bunch of our integration tests this coming week. Going to punt on including them in this PR specifically though!
* main:
  - fix typos found by PyCharm (#21)
  - added retry when sending httpx request to LLM provider apis (#20)
  - chore: version bump to `0.8.2` (#19)
  - fix: don't always use ollama provider (#18)
  - fix: export `metadata.plugins` export should have a valid value (#17)
  - Create an entry-point for `ai-exchange` (#16)
  - chore: Run tests for python >=3.10 (#14)
  - Update pypi_release.yaml (#13)
  - ollama provider (#7)
  - chore: gitignore generated lockfile (#8)
just a last minute bikeshed 😎
```python
@define
class CheckpointData:
```
completely style, but as this is an aggregate, words that imply plural can help vs something like Data. At first, esp as there is a function about "message" in here, I thought there may be data, like messages, in here..

```diff
-class CheckpointData:
+class Checkpoints:
```
I totally agree here - the naming is weird. I'm hoping to completely remove the `CheckpointData` class and move its variables into the `Exchange` class instead.

Avoiding the rename right now as it would lead to `self.checkpoints.checkpoints` access patterns on the containing `Exchange`.
```python
# pop from the left or right sides. This offset allows us to map the checkpoint indices
# to the correct message index, even if we have popped messages from the left side of
# the exchange in the past. we reset this offset to 0 when we empty the checkpoint data.
message_index_offset: int = field(default=0)
```
I'm curious why this isn't on the exchange instead. If it were, then I feel this object is more tidy, an aggregation of checkpoints... Also, it would be a lot easier to understand its relationship to messages (as they are defined on the exchange). Food for thought and I'm sure I'm missing something.
I totally agree! The `Exchange` class should be a clean representation of messages and checkpoints that stay in sync within the same instance. However, our current implementation of the `Exchange` class is a frozen one - this means that variables such as `message_index_offset` and `total_token_count` cannot be updated without hacky workarounds. Will be having a discussion with the team to hopefully change this.
This PR introduces a large rewrite of checkpoints, ensuring they are always in sync with messages on the exchange.
Why?
As the exchange grows in size, it is increasingly important to trim its beginning (as these messages are less relevant). We also need to ensure the exchange has an accurate count of the tokens it has used so far, so we can avoid exceeding the model's context window.
This PR makes managing checkpoints and messages much easier, introducing the following methods on the exchange:

- `.pop_first_message()`
- `.pop_last_message()`
- `.pop_first_checkpoint()`
- `.pop_last_checkpoint()`
I have also included a large number of tests that cover the different ways that messages and checkpoints could fall out of sync. Now that these are covered, we hope not to run into any regressions in this behavior.
Note to reviewers
There are a couple of sections marked with `TODO` comments. These have been intentionally left in place, as they are not essential for backwards compatibility. We can instead focus on implementing them over the coming weeks.