feat: refactor compute_input_stats across all backends to eliminate code duplication #4932
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR addresses the significant code duplication in
compute_input_statsmethods across descriptor implementations in all backends (dpmodel, PyTorch, and Paddle). The issue was that nearly identical logic (~40 lines) was repeated in every descriptor class, with only minor backend-specific differences in tensor assignment.Problem
Almost all descriptor classes implemented
compute_input_statsin exactly the same way:This pattern was repeated across ~50+ files, making maintenance difficult and error-prone.
Solution
Created backend-specific mixin classes that extract the common logic:
deepmd.dpmodel.common.ComputeInputStatsMixin- Array API compatible implementationdeepmd.pt.common.ComputeInputStatsMixin- PyTorch-specific implementationdeepmd.pd.common.ComputeInputStatsMixin- Paddle-specific implementationEach mixin provides:
compute_input_stats()method with shared logic_set_stat_mean_and_stddev()method for backend-specific tensor assignmentget_stats()methodUsage
Descriptor classes now inherit from the mixin and implement only the backend-specific part:
Benefits
Files Changed
Updated descriptor classes:
se_e2_a.py,repformers.pyse_a.py,repformers.pyrepformers.pyNew common modules:
deepmd/dpmodel/common.py(extended)deepmd/pt/common.py(new)deepmd/pd/common.py(new)This refactoring follows the DRY principle and makes the codebase significantly more maintainable while preserving all existing functionality.
Fixes #4732.
💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click here to start the survey.