-
Notifications
You must be signed in to change notification settings - Fork 4.5k
[RunInference] Add content-aware dynamic batching via element_size_fn… #37428
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2133,5 +2133,63 @@ def request(self, batch, model, inference_args=None): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_handler.run_inference([1], FakeModel()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class FakeModelHandlerForSizing(base.ModelHandler[int, int, FakeModel]): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """A ModelHandler used to test element sizing behavior.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_batch_size: int = 10, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_batch_weight: Optional[int] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| element_size_fn=None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._max_batch_size = max_batch_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._max_batch_weight = max_batch_weight | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._element_size_fn = element_size_fn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def load_model(self) -> FakeModel: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return FakeModel() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def run_inference(self, batch, model, inference_args=None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [model.predict(x) for x in batch] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def batch_elements_kwargs(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kwargs = {'max_batch_size': self._max_batch_size} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self._max_batch_weight is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kwargs['max_batch_weight'] = self._max_batch_weight | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self._element_size_fn: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kwargs['element_size_fn'] = self._element_size_fn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return kwargs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+2138
to
+2159
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is probably a good suggestion |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class RunInferenceSizeTest(unittest.TestCase): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Tests for ModelHandler.batch_elements_kwargs with element_size_fn.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_kwargs_are_passed_correctly(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Adds element_size_fn without clobbering existing kwargs.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def size_fn(x): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return 10 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sized_handler = FakeModelHandlerForSizing( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_batch_size=20, max_batch_weight=100, element_size_fn=size_fn) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kwargs = sized_handler.batch_elements_kwargs() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.assertEqual(kwargs['max_batch_size'], 20) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.assertEqual(kwargs['max_batch_weight'], 100) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.assertIn('element_size_fn', kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.assertEqual(kwargs['element_size_fn'](1), 10) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_sizing_with_edge_cases(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Allows extreme values from element_size_fn.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| zero_size_fn = lambda x: 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sized_handler = FakeModelHandlerForSizing( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_batch_size=1, element_size_fn=zero_size_fn) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kwargs = sized_handler.batch_elements_kwargs() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.assertEqual(kwargs['element_size_fn'](999), 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| large_size_fn = lambda x: 1000000 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sized_handler = FakeModelHandlerForSizing( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_batch_size=1, element_size_fn=large_size_fn) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kwargs = sized_handler.batch_elements_kwargs() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.assertEqual(kwargs['element_size_fn'](1), 1000000) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if __name__ == '__main__': | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| unittest.main() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.