diff --git a/epochStart/bootstrap/disabled/disabledAccountsAdapter.go b/epochStart/bootstrap/disabled/disabledAccountsAdapter.go index ce928b21fca..1606b50af69 100644 --- a/epochStart/bootstrap/disabled/disabledAccountsAdapter.go +++ b/epochStart/bootstrap/disabled/disabledAccountsAdapter.go @@ -70,6 +70,11 @@ func (a *accountsAdapter) Commit() ([]byte, error) { return nil, nil } +// CommitInMemory - +func (a *accountsAdapter) CommitInMemory() ([]byte, error) { + return nil, nil +} + // JournalLen - func (a *accountsAdapter) JournalLen() int { return 0 diff --git a/process/block/baseProcess.go b/process/block/baseProcess.go index d336ae476dc..f373701727a 100644 --- a/process/block/baseProcess.go +++ b/process/block/baseProcess.go @@ -49,6 +49,15 @@ import ( const ( cleanupHeadersDelta = 5 + + // defaultSyncCommitInterval defines how many blocks to process before committing to disk during sync. + // Setting to 0 disables the optimization (commits every block). + // Higher values improve sync speed but increase memory usage and data loss risk on crash. + defaultSyncCommitInterval = uint64(10) + + // syncThresholdNonces defines how many nonces behind the network the node must be + // to be considered "syncing" and use the commit interval optimization. + syncThresholdNonces = uint64(50) ) var log = logger.GetOrCreate("process/block") @@ -146,6 +155,11 @@ type baseProcessor struct { executionManager process.ExecutionManager txExecutionOrderHandler common.TxExecutionOrderHandler aotSelector process.AOTTransactionSelector + + // Sync commit optimization fields + syncCommitInterval uint64 + blocksSinceLastCommit uint64 + mutSyncCommit sync.Mutex } type bootStorerDataArgs struct { @@ -236,6 +250,7 @@ func NewBaseProcessor(arguments ArgBaseProcessor) (*baseProcessor, error) { executionManager: arguments.ExecutionManager, txExecutionOrderHandler: arguments.TxExecutionOrderHandler, aotSelector: arguments.AOTSelector, + syncCommitInterval: defaultSyncCommitInterval, } err = base.OnExecutedBlock(genesisHdr, genesisHdr.GetRootHash()) @@ -2117,21 +2132,90 @@ func (bp *baseProcessor) RevertAccountsDBToSnapshot(accountsSnapshot map[state.A func (bp *baseProcessor) commitState(headerHandler data.HeaderHandler) error { startTime := time.Now() + inMemory := false defer func() { elapsedTime := time.Since(startTime) log.Debug("elapsed time to commit accounts state", "time [s]", elapsedTime, "header nonce", headerHandler.GetNonce(), + "in memory", inMemory, ) }() if headerHandler.IsStartOfEpochBlock() { + bp.resetSyncCommitCounter() return bp.commitInLastEpoch(headerHandler.GetEpoch()) } + // Check if we should use sync commit optimization + if bp.shouldUseSyncCommitOptimization(headerHandler) { + inMemory = true + return bp.commitInMemory() + } + return bp.commit() } +// shouldUseSyncCommitOptimization checks if the node is syncing and should use +// the in-memory commit optimization to improve sync speed. +func (bp *baseProcessor) shouldUseSyncCommitOptimization(headerHandler data.HeaderHandler) bool { + bp.mutSyncCommit.Lock() + defer bp.mutSyncCommit.Unlock() + + // Disabled if syncCommitInterval is 0 + if bp.syncCommitInterval == 0 { + return false + } + + // Check if node is syncing (far behind the network) + probableHighestNonce := bp.forkDetector.ProbableHighestNonce() + currentNonce := headerHandler.GetNonce() + noncesBehind := uint64(0) + if probableHighestNonce > currentNonce { + noncesBehind = probableHighestNonce - currentNonce + } + + // Not syncing - commit every block + if noncesBehind < syncThresholdNonces { + bp.blocksSinceLastCommit = 0 + return false + } + + // Syncing - use commit interval + bp.blocksSinceLastCommit++ + + // Time for a full commit + if bp.blocksSinceLastCommit >= bp.syncCommitInterval { + bp.blocksSinceLastCommit = 0 + log.Debug("sync commit optimization: performing full commit", + "nonces_behind", noncesBehind, + "interval", bp.syncCommitInterval) + return false + } + + log.Debug("sync commit optimization: using in-memory commit", + "nonces_behind", noncesBehind, + "blocks_since_commit", bp.blocksSinceLastCommit) + return true +} + +func (bp *baseProcessor) resetSyncCommitCounter() { + bp.mutSyncCommit.Lock() + bp.blocksSinceLastCommit = 0 + bp.mutSyncCommit.Unlock() +} + +func (bp *baseProcessor) commitInMemory() error { + for key := range bp.accountsDB { + _, err := bp.accountsDB[key].CommitInMemory() + if err != nil { + return err + } + } + + return nil +} + func (bp *baseProcessor) commitInLastEpoch(currentEpoch uint32) error { lastEpoch := uint32(0) if currentEpoch > 0 { diff --git a/process/block/baseProcess_test.go b/process/block/baseProcess_test.go index 7f3f8c59026..b157be19be5 100644 --- a/process/block/baseProcess_test.go +++ b/process/block/baseProcess_test.go @@ -5604,3 +5604,266 @@ func TestBaseProcessor_Close(t *testing.T) { require.NoError(t, bp.Close()) } + +// ------- Sync Commit Optimization Tests + +func TestBaseProcessor_ShouldUseSyncCommitOptimization_DisabledWhenIntervalZero(t *testing.T) { + t.Parallel() + + coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() + arguments := createArgBaseProcessor(coreComponents, dataComponents, bootstrapComponents, statusComponents) + arguments.ForkDetector = &mock.ForkDetectorMock{ + ProbableHighestNonceCalled: func() uint64 { + return 1000 // Far behind + }, + } + + sp, err := blproc.NewShardProcessor(blproc.ArgShardProcessor{ArgBaseProcessor: arguments}) + require.NoError(t, err) + + // Disable the optimization + sp.SetSyncCommitIntervalForTest(0) + + header := &block.Header{ + Nonce: 100, // Far behind network + } + + // Should return false because interval is 0 + result := sp.ShouldUseSyncCommitOptimization(header) + assert.False(t, result, "should return false when sync commit interval is 0") +} + +func TestBaseProcessor_ShouldUseSyncCommitOptimization_DisabledWhenNotSyncing(t *testing.T) { + t.Parallel() + + coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() + arguments := createArgBaseProcessor(coreComponents, dataComponents, bootstrapComponents, statusComponents) + arguments.ForkDetector = &mock.ForkDetectorMock{ + ProbableHighestNonceCalled: func() uint64 { + return 105 // Only 5 ahead + }, + } + + sp, err := blproc.NewShardProcessor(blproc.ArgShardProcessor{ArgBaseProcessor: arguments}) + require.NoError(t, err) + + sp.SetSyncCommitIntervalForTest(10) + + header := &block.Header{ + Nonce: 100, // Not far behind network (less than syncThresholdNonces) + } + + // Should return false because node is not syncing + result := sp.ShouldUseSyncCommitOptimization(header) + assert.False(t, result, "should return false when not syncing (nonces behind < threshold)") +} + +func TestBaseProcessor_ShouldUseSyncCommitOptimization_UsesInMemoryWhenSyncing(t *testing.T) { + t.Parallel() + + coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() + arguments := createArgBaseProcessor(coreComponents, dataComponents, bootstrapComponents, statusComponents) + arguments.ForkDetector = &mock.ForkDetectorMock{ + ProbableHighestNonceCalled: func() uint64 { + return 200 // Far ahead (100 nonces behind) + }, + } + + sp, err := blproc.NewShardProcessor(blproc.ArgShardProcessor{ArgBaseProcessor: arguments}) + require.NoError(t, err) + + sp.SetSyncCommitIntervalForTest(10) + + header := &block.Header{ + Nonce: 100, + } + + // First call should return true (in-memory commit) + result := sp.ShouldUseSyncCommitOptimization(header) + assert.True(t, result, "should return true for first block when syncing") + assert.Equal(t, uint64(1), sp.GetBlocksSinceLastCommit()) +} + +func TestBaseProcessor_ShouldUseSyncCommitOptimization_FullCommitAtInterval(t *testing.T) { + t.Parallel() + + coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() + arguments := createArgBaseProcessor(coreComponents, dataComponents, bootstrapComponents, statusComponents) + arguments.ForkDetector = &mock.ForkDetectorMock{ + ProbableHighestNonceCalled: func() uint64 { + return 200 // Far ahead + }, + } + + sp, err := blproc.NewShardProcessor(blproc.ArgShardProcessor{ArgBaseProcessor: arguments}) + require.NoError(t, err) + + sp.SetSyncCommitIntervalForTest(5) + + header := &block.Header{ + Nonce: 100, + } + + // Call 5 times - first 4 should return true (in-memory), 5th should return false (full commit) + for i := 0; i < 4; i++ { + result := sp.ShouldUseSyncCommitOptimization(header) + assert.True(t, result, "should return true for block %d", i+1) + } + + // 5th call should trigger full commit + result := sp.ShouldUseSyncCommitOptimization(header) + assert.False(t, result, "should return false at interval (full commit)") + assert.Equal(t, uint64(0), sp.GetBlocksSinceLastCommit(), "counter should be reset after full commit") +} + +func TestBaseProcessor_SetSyncCommitInterval(t *testing.T) { + t.Parallel() + + coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() + arguments := createArgBaseProcessor(coreComponents, dataComponents, bootstrapComponents, statusComponents) + + sp, err := blproc.NewShardProcessor(blproc.ArgShardProcessor{ArgBaseProcessor: arguments}) + require.NoError(t, err) + + // Check default value + assert.Equal(t, blproc.DefaultSyncCommitInterval, sp.GetSyncCommitInterval()) + + // Set new value + sp.SetSyncCommitIntervalForTest(20) + assert.Equal(t, uint64(20), sp.GetSyncCommitInterval()) + + // Set to 0 to disable + sp.SetSyncCommitIntervalForTest(0) + assert.Equal(t, uint64(0), sp.GetSyncCommitInterval()) +} + +func TestBaseProcessor_ResetSyncCommitCounter(t *testing.T) { + t.Parallel() + + coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() + arguments := createArgBaseProcessor(coreComponents, dataComponents, bootstrapComponents, statusComponents) + arguments.ForkDetector = &mock.ForkDetectorMock{ + ProbableHighestNonceCalled: func() uint64 { + return 200 // Far ahead + }, + } + + sp, err := blproc.NewShardProcessor(blproc.ArgShardProcessor{ArgBaseProcessor: arguments}) + require.NoError(t, err) + + sp.SetSyncCommitIntervalForTest(10) + + header := &block.Header{ + Nonce: 100, + } + + // Make a few calls to increment counter + for i := 0; i < 3; i++ { + _ = sp.ShouldUseSyncCommitOptimization(header) + } + assert.Equal(t, uint64(3), sp.GetBlocksSinceLastCommit()) + + // Reset counter + sp.ResetSyncCommitCounter() + assert.Equal(t, uint64(0), sp.GetBlocksSinceLastCommit()) +} + +func TestBaseProcessor_CommitInMemory(t *testing.T) { + t.Parallel() + + commitInMemoryCalled := false + commitCalled := false + + coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() + arguments := createArgBaseProcessor(coreComponents, dataComponents, bootstrapComponents, statusComponents) + + accountsDb := make(map[state.AccountsDbIdentifier]state.AccountsAdapter) + accountsDb[state.UserAccountsState] = &stateMock.AccountsStub{ + CommitInMemoryCalled: func() ([]byte, error) { + commitInMemoryCalled = true + return []byte("rootHash"), nil + }, + CommitCalled: func() ([]byte, error) { + commitCalled = true + return []byte("rootHash"), nil + }, + RecreateTrieIfNeededCalled: func(options common.RootHashHolder) error { + return nil + }, + } + arguments.AccountsDB = accountsDb + + sp, err := blproc.NewShardProcessor(blproc.ArgShardProcessor{ArgBaseProcessor: arguments}) + require.NoError(t, err) + + err = sp.CommitInMemoryForTest() + assert.NoError(t, err) + assert.True(t, commitInMemoryCalled, "CommitInMemory should be called") + assert.False(t, commitCalled, "Commit should not be called") +} + +func TestBaseProcessor_SyncCommitOptimization_Constants(t *testing.T) { + t.Parallel() + + // Verify constants are set to expected values + assert.Equal(t, uint64(50), blproc.SyncThresholdNonces) + assert.Equal(t, uint64(10), blproc.DefaultSyncCommitInterval) +} + +func TestBaseProcessor_SyncCommitOptimization_ConcurrentAccess(t *testing.T) { + t.Parallel() + + // This test verifies that concurrent access to syncCommitInterval is thread-safe. + // The fix moved mutex locking to the beginning of shouldUseSyncCommitOptimization + // to prevent data races when reading syncCommitInterval. + + coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() + arguments := createArgBaseProcessor(coreComponents, dataComponents, bootstrapComponents, statusComponents) + arguments.ForkDetector = &mock.ForkDetectorMock{ + ProbableHighestNonceCalled: func() uint64 { + return 200 // Far behind - will trigger sync optimization + }, + } + + sp, err := blproc.NewShardProcessor(blproc.ArgShardProcessor{ArgBaseProcessor: arguments}) + require.NoError(t, err) + + sp.SetSyncCommitIntervalForTest(10) + + header := &block.Header{ + Nonce: 100, + } + + // Run concurrent operations + var wg sync.WaitGroup + numGoroutines := 10 + numIterations := 100 + + // Concurrent readers calling shouldUseSyncCommitOptimization + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < numIterations; j++ { + _ = sp.ShouldUseSyncCommitOptimization(header) + } + }() + } + + // Concurrent writers calling SetSyncCommitIntervalForTest + for i := 0; i < numGoroutines/2; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < numIterations; j++ { + sp.SetSyncCommitIntervalForTest(uint64(10 + id)) + } + }(i) + } + + // Wait for all goroutines to complete + wg.Wait() + + // If we reach here without a race condition, the test passes + // Run with -race flag to verify: go test -race -run TestBaseProcessor_SyncCommitOptimization_ConcurrentAccess +} diff --git a/process/block/export_test.go b/process/block/export_test.go index 5cc656c9a40..a2f17387e7a 100644 --- a/process/block/export_test.go +++ b/process/block/export_test.go @@ -53,6 +53,12 @@ import ( // UsedShardHeadersInfo - type UsedShardHeadersInfo = usedShardHeadersInfo +// SyncThresholdNonces is exported for testing +const SyncThresholdNonces = syncThresholdNonces + +// DefaultSyncCommitInterval is exported for testing +const DefaultSyncCommitInterval = defaultSyncCommitInterval + // EpochStartDataWrapper - type EpochStartDataWrapper = epochStartDataWrapper @@ -1162,3 +1168,40 @@ func (bp *baseProcessor) ExcludeRevertedExecutionResultsForHeader( ) []data.BaseExecutionResultHandler { return bp.excludeRevertedExecutionResultsForHeader(header, pendingExecutionResults) } + +// ShouldUseSyncCommitOptimization - +func (bp *baseProcessor) ShouldUseSyncCommitOptimization(headerHandler data.HeaderHandler) bool { + return bp.shouldUseSyncCommitOptimization(headerHandler) +} + +// SetSyncCommitIntervalForTest sets the commit interval for sync optimization (test only). +// Set to 0 to disable the optimization (commit every block). +func (bp *baseProcessor) SetSyncCommitIntervalForTest(interval uint64) { + bp.mutSyncCommit.Lock() + bp.syncCommitInterval = interval + bp.mutSyncCommit.Unlock() +} + +// GetSyncCommitInterval - +func (bp *baseProcessor) GetSyncCommitInterval() uint64 { + bp.mutSyncCommit.Lock() + defer bp.mutSyncCommit.Unlock() + return bp.syncCommitInterval +} + +// GetBlocksSinceLastCommit - +func (bp *baseProcessor) GetBlocksSinceLastCommit() uint64 { + bp.mutSyncCommit.Lock() + defer bp.mutSyncCommit.Unlock() + return bp.blocksSinceLastCommit +} + +// ResetSyncCommitCounter - +func (bp *baseProcessor) ResetSyncCommitCounter() { + bp.resetSyncCommitCounter() +} + +// CommitInMemoryForTest - +func (bp *baseProcessor) CommitInMemoryForTest() error { + return bp.commitInMemory() +} diff --git a/process/transactionEvaluator/simulationAccountsDB.go b/process/transactionEvaluator/simulationAccountsDB.go index edad278c798..65eb23fd4c5 100644 --- a/process/transactionEvaluator/simulationAccountsDB.go +++ b/process/transactionEvaluator/simulationAccountsDB.go @@ -112,6 +112,11 @@ func (r *simulationAccountsDB) Commit() ([]byte, error) { return nil, nil } +// CommitInMemory won't do anything as write operations are disabled on this component +func (r *simulationAccountsDB) CommitInMemory() ([]byte, error) { + return nil, nil +} + // JournalLen will call the original accounts' function with the same name func (r *simulationAccountsDB) JournalLen() int { return r.originalAccounts.JournalLen() diff --git a/state/accountsDB.go b/state/accountsDB.go index 7d21e27669d..b82d3d008a9 100644 --- a/state/accountsDB.go +++ b/state/accountsDB.go @@ -915,6 +915,51 @@ func (adb *AccountsDB) Commit() ([]byte, error) { return adb.commit() } +// CommitInMemory computes root hashes and clears journal without persisting to disk. +// This is useful during sync to avoid expensive disk writes on every block. +// A full Commit() should be called periodically to persist the accumulated changes. +func (adb *AccountsDB) CommitInMemory() ([]byte, error) { + adb.mutOp.Lock() + defer func() { + adb.mutOp.Unlock() + adb.loadCodeMeasurements.resetAndPrint() + }() + + return adb.commitInMemory() +} + +func (adb *AccountsDB) commitInMemory() ([]byte, error) { + log.Trace("accountsDB.CommitInMemory started") + adb.entries = make([]JournalEntry, 0) + + // Compute root hashes for all data tries without persisting + // Note: dirty nodes stay in memory until full Commit() is called + dataTries := adb.dataTries.GetAll() + for i := 0; i < len(dataTries); i++ { + _, err := dataTries[i].RootHash() + if err != nil { + return nil, err + } + } + + // Compute root hash for main trie without persisting + newRoot, err := adb.mainTrie.RootHash() + if err != nil { + log.Trace("accountsDB.CommitInMemory ended", "error", err.Error()) + return nil, err + } + + adb.lastRootHash = newRoot + // Note: We intentionally do NOT clear obsoleteDataTrieHashes here. + // They will accumulate and be processed during the next full Commit(). + // We also skip markForEviction and stateAccessesCollector.CommitCollectedAccesses + // as these will be handled during the periodic full commit. + + log.Trace("accountsDB.CommitInMemory ended", "root hash", newRoot) + + return newRoot, nil +} + func (adb *AccountsDB) commit() ([]byte, error) { log.Trace("accountsDB.Commit started") adb.entries = make([]JournalEntry, 0) diff --git a/state/accountsDBApi.go b/state/accountsDBApi.go index 76ee0d506d8..fd49b9f6e2e 100644 --- a/state/accountsDBApi.go +++ b/state/accountsDBApi.go @@ -133,6 +133,11 @@ func (accountsDB *accountsDBApi) Commit() ([]byte, error) { return nil, ErrOperationNotPermitted } +// CommitInMemory is not a permitted operation in this implementation and thus, will return an error +func (accountsDB *accountsDBApi) CommitInMemory() ([]byte, error) { + return nil, ErrOperationNotPermitted +} + // JournalLen will always return 0 func (accountsDB *accountsDBApi) JournalLen() int { return 0 diff --git a/state/accountsDBApiWithHistory.go b/state/accountsDBApiWithHistory.go index e39a24ea7c7..6f1d00b86e7 100644 --- a/state/accountsDBApiWithHistory.go +++ b/state/accountsDBApiWithHistory.go @@ -74,6 +74,11 @@ func (accountsDB *accountsDBApiWithHistory) Commit() ([]byte, error) { return nil, ErrOperationNotPermitted } +// CommitInMemory is not a permitted operation in this implementation and thus, will return an error +func (accountsDB *accountsDBApiWithHistory) CommitInMemory() ([]byte, error) { + return nil, ErrOperationNotPermitted +} + // JournalLen will always return 0 func (accountsDB *accountsDBApiWithHistory) JournalLen() int { return 0 diff --git a/state/accountsDB_test.go b/state/accountsDB_test.go index 2397356b958..05f7f198958 100644 --- a/state/accountsDB_test.go +++ b/state/accountsDB_test.go @@ -3464,3 +3464,263 @@ func testAccountLoadInParallel( wg.Wait() } + +// ------- CommitInMemory + +func TestAccountsDB_CommitInMemoryShouldComputeRootHashWithoutPersisting(t *testing.T) { + t.Parallel() + + commitCalled := false + rootHashCalled := 0 + expectedRootHash := []byte("expectedRootHash") + + marshaller := &marshallerMock.MarshalizerMock{} + serializedAccount, _ := marshaller.Marshal(stateMock.AccountWrapMock{}) + trieStub := trieMock.TrieStub{ + CommitCalled: func() error { + commitCalled = true + return nil + }, + RootCalled: func() ([]byte, error) { + rootHashCalled++ + return expectedRootHash, nil + }, + GetCalled: func(_ []byte) ([]byte, uint32, error) { + return serializedAccount, 0, nil + }, + GetStorageManagerCalled: func() common.StorageManager { + return &storageManager.StorageManagerStub{} + }, + } + + adb := generateAccountDBFromTrie(&trieStub) + + accnt, _ := adb.LoadAccount(make([]byte, 32)) + _ = adb.SaveAccount(accnt) + + rootHash, err := adb.CommitInMemory() + assert.Nil(t, err) + assert.Equal(t, expectedRootHash, rootHash) + assert.False(t, commitCalled, "Commit should not be called in CommitInMemory") + assert.Equal(t, 1, rootHashCalled) +} + +func TestAccountsDB_CommitInMemoryShouldClearJournal(t *testing.T) { + t.Parallel() + + marshaller := &marshallerMock.MarshalizerMock{} + serializedAccount, _ := marshaller.Marshal(stateMock.AccountWrapMock{}) + trieStub := trieMock.TrieStub{ + RootCalled: func() ([]byte, error) { + return []byte("rootHash"), nil + }, + GetCalled: func(_ []byte) ([]byte, uint32, error) { + return serializedAccount, 0, nil + }, + UpdateCalled: func(key, value []byte) error { + return nil + }, + GetStorageManagerCalled: func() common.StorageManager { + return &storageManager.StorageManagerStub{} + }, + } + + adb := generateAccountDBFromTrie(&trieStub) + + accnt, err := adb.LoadAccount(make([]byte, 32)) + require.NoError(t, err) + err = adb.SaveAccount(accnt) + require.NoError(t, err) + + // Journal should have entries after SaveAccount + assert.True(t, adb.JournalLen() > 0) + + _, err = adb.CommitInMemory() + assert.Nil(t, err) + + // Journal should be cleared after CommitInMemory + assert.Equal(t, 0, adb.JournalLen()) +} + +func TestAccountsDB_CommitInMemoryErrorOnRootHashShouldFail(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("root hash error") + trieStub := trieMock.TrieStub{ + RootCalled: func() ([]byte, error) { + return nil, expectedErr + }, + GetStorageManagerCalled: func() common.StorageManager { + return &storageManager.StorageManagerStub{} + }, + } + + adb := generateAccountDBFromTrie(&trieStub) + + _, err := adb.CommitInMemory() + assert.Equal(t, expectedErr, err) +} + +func TestAccountsDB_CommitInMemoryFollowedByCommitShouldPersist(t *testing.T) { + t.Parallel() + + commitCalled := false + expectedRootHash := []byte("expectedRootHash") + + marshaller := &marshallerMock.MarshalizerMock{} + serializedAccount, _ := marshaller.Marshal(stateMock.AccountWrapMock{}) + trieStub := trieMock.TrieStub{ + CommitCalled: func() error { + commitCalled = true + return nil + }, + RootCalled: func() ([]byte, error) { + return expectedRootHash, nil + }, + GetCalled: func(_ []byte) ([]byte, uint32, error) { + return serializedAccount, 0, nil + }, + UpdateCalled: func(key, value []byte) error { + return nil + }, + GetStorageManagerCalled: func() common.StorageManager { + return &storageManager.StorageManagerStub{} + }, + } + + adb := generateAccountDBFromTrie(&trieStub) + + accnt, err := adb.LoadAccount(make([]byte, 32)) + require.NoError(t, err) + err = adb.SaveAccount(accnt) + require.NoError(t, err) + + // First, do in-memory commit + rootHash1, err := adb.CommitInMemory() + assert.Nil(t, err) + assert.Equal(t, expectedRootHash, rootHash1) + assert.False(t, commitCalled) + + // Then do a full commit - this should persist + rootHash2, err := adb.Commit() + assert.Nil(t, err) + assert.Equal(t, expectedRootHash, rootHash2) + assert.True(t, commitCalled, "Commit should be called on full Commit()") +} + +func TestAccountsDB_CommitInMemoryShouldNotResetDataTries(t *testing.T) { + t.Parallel() + + // This test verifies that CommitInMemory does NOT reset dataTries. + // Resetting data tries in CommitInMemory would cause data loss because + // when an account is loaded after in-memory commit, the trie cannot be + // recreated from the database (since it was never persisted). + + dataTrieRootHash := []byte("dataTrieRootHash") + dataTrieRootHashCalled := 0 + dataTrieResetCalled := false + + marshaller := &marshallerMock.MarshalizerMock{} + serializedAccount, _ := marshaller.Marshal(stateMock.AccountWrapMock{}) + + mainTrieStub := trieMock.TrieStub{ + RootCalled: func() ([]byte, error) { + return []byte("mainRootHash"), nil + }, + GetCalled: func(_ []byte) ([]byte, uint32, error) { + return serializedAccount, 0, nil + }, + UpdateCalled: func(key, value []byte) error { + return nil + }, + GetStorageManagerCalled: func() common.StorageManager { + return &storageManager.StorageManagerStub{} + }, + } + + dataTrieStub := &trieMock.TrieStub{ + RootCalled: func() ([]byte, error) { + dataTrieRootHashCalled++ + return dataTrieRootHash, nil + }, + } + + adb := generateAccountDBFromTrie(&mainTrieStub) + + // Add a data trie to simulate an account with storage + adb.SetDataTries(state.NewDataTriesHolder()) + adb.DataTries().Put([]byte("address"), dataTrieStub) + + // Create a custom holder to track if Reset was called + originalHolder := adb.DataTries() + customHolder := &dataTriesHolderStub{ + getAllCalled: func() []common.Trie { + return originalHolder.GetAll() + }, + putCalled: func(key []byte, tr common.Trie) { + originalHolder.Put(key, tr) + }, + getCalled: func(key []byte) common.Trie { + return originalHolder.Get(key) + }, + resetCalled: func() { + dataTrieResetCalled = true + }, + } + adb.SetDataTries(customHolder) + + // Do in-memory commit + _, err := adb.CommitInMemory() + require.NoError(t, err) + + // Verify dataTries.Reset() was NOT called + assert.False(t, dataTrieResetCalled, "dataTries.Reset() should NOT be called in CommitInMemory") + + // Verify RootHash was called on data tries (they are processed but not reset) + assert.Equal(t, 1, dataTrieRootHashCalled, "data trie RootHash should be called") +} + +// dataTriesHolderStub is a test stub for tracking calls to the data tries holder +type dataTriesHolderStub struct { + getAllCalled func() []common.Trie + putCalled func(key []byte, tr common.Trie) + replaceCalled func(key []byte, tr common.Trie) + getCalled func(key []byte) common.Trie + resetCalled func() +} + +func (d *dataTriesHolderStub) GetAll() []common.Trie { + if d.getAllCalled != nil { + return d.getAllCalled() + } + return nil +} + +func (d *dataTriesHolderStub) Put(key []byte, tr common.Trie) { + if d.putCalled != nil { + d.putCalled(key, tr) + } +} + +func (d *dataTriesHolderStub) Replace(key []byte, tr common.Trie) { + if d.replaceCalled != nil { + d.replaceCalled(key, tr) + } +} + +func (d *dataTriesHolderStub) Get(key []byte) common.Trie { + if d.getCalled != nil { + return d.getCalled(key) + } + return nil +} + +func (d *dataTriesHolderStub) Reset() { + if d.resetCalled != nil { + d.resetCalled() + } +} + +func (d *dataTriesHolderStub) IsInterfaceNil() bool { + return d == nil +} diff --git a/state/export_test.go b/state/export_test.go index bbc209312e4..cca1a3a61c4 100644 --- a/state/export_test.go +++ b/state/export_test.go @@ -106,3 +106,13 @@ type AccountHandlerWithDataTrieMigrationStatus interface { vmcommon.AccountHandler IsDataTrieMigrated() (bool, error) } + +// DataTries returns the data tries holder for testing +func (adb *AccountsDB) DataTries() common.TriesHolder { + return adb.dataTries +} + +// SetDataTries sets the data tries holder for testing +func (adb *AccountsDB) SetDataTries(holder common.TriesHolder) { + adb.dataTries = holder +} diff --git a/state/interface.go b/state/interface.go index 2e9893f1cc6..57126b8e8c5 100644 --- a/state/interface.go +++ b/state/interface.go @@ -76,6 +76,7 @@ type AccountsAdapter interface { RemoveAccount(address []byte) error CommitInEpoch(currentEpoch uint32, epochToCommit uint32) ([]byte, error) Commit() ([]byte, error) + CommitInMemory() ([]byte, error) JournalLen() int RevertToSnapshot(snapshot int) error GetCode(codeHash []byte) []byte diff --git a/testscommon/state/accountsAdapterStub.go b/testscommon/state/accountsAdapterStub.go index aa3d41e1355..9bc0f6ee33d 100644 --- a/testscommon/state/accountsAdapterStub.go +++ b/testscommon/state/accountsAdapterStub.go @@ -21,6 +21,7 @@ type AccountsStub struct { SaveAccountCalled func(account vmcommon.AccountHandler) error RemoveAccountCalled func(addressContainer []byte) error CommitCalled func() ([]byte, error) + CommitInMemoryCalled func() ([]byte, error) CommitInEpochCalled func(uint32, uint32) ([]byte, error) JournalLenCalled func() int RevertToSnapshotCalled func(snapshot int) error @@ -124,6 +125,15 @@ func (as *AccountsStub) Commit() ([]byte, error) { return nil, errNotImplemented } +// CommitInMemory - +func (as *AccountsStub) CommitInMemory() ([]byte, error) { + if as.CommitInMemoryCalled != nil { + return as.CommitInMemoryCalled() + } + + return nil, errNotImplemented +} + // GetExistingAccount - func (as *AccountsStub) GetExistingAccount(addressContainer []byte) (vmcommon.AccountHandler, error) { if as.GetExistingAccountCalled != nil {