Skip to content

Commit 773aea7

Browse files
committed
Global alpha-beta take 1.
refs #20.
1 parent 05a19a8 commit 773aea7

File tree

2 files changed

+125
-59
lines changed

2 files changed

+125
-59
lines changed

src/AI/AlphaBeta.hs

+124-59
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,18 @@ scoreMove (ScoreMoveInput {..}) = do
7676
let AlphaBeta params rules eval = smiAi
7777
score <- Monitoring.timed "ai.score.move" $ do
7878
let board' = applyMoveActions (pmResult smiMove) smiBoard
79-
score <- doScore rules eval smiCache params smiGameId (opposite smiSide) smiDepth board' smiAlpha smiBeta
79+
score <- doScore rules eval smiCache params smiGameId (opposite smiSide) smiDepth board' smiGlobalInterval smiAlpha smiBeta
8080
`catchError` (\(e :: Error) -> do
8181
$info "doScore: move {}, depth {}: {}" (show smiMove, dpTarget smiDepth, show e)
8282
throwError e
8383
)
8484
$info "Check: {} (depth {}) => {}" (show smiMove, dpTarget smiDepth, show score)
8585
return score
8686

87+
-- restrictInterval smiGlobalInterval smiSide score
8788
return (smiMove, score)
8889

89-
type DepthIterationInput = (AlphaBetaParams, [PossibleMove], Maybe DepthIterationOutput)
90+
type DepthIterationInput = (AlphaBetaParams, TVar (Score, Score), [PossibleMove], Maybe DepthIterationOutput)
9091
type DepthIterationOutput = [(PossibleMove, Score)]
9192
type AiOutput = ([PossibleMove], Score)
9293

@@ -232,17 +233,18 @@ runAI ai@(AlphaBeta params rules eval) handle gameId side board = do
232233
-- return result
233234

234235
depthDriver :: [PossibleMove] -> Checkers DepthIterationOutput
235-
depthDriver moves =
236+
depthDriver moves = do
237+
globalInterval <- liftIO $ atomically $ newTVar (loose, win)
236238
case abBaseTime params of
237239
Nothing -> do
238-
(result, _) <- go (params, moves, Nothing)
240+
(result, _) <- go (params, globalInterval, moves, Nothing)
239241
return result
240-
Just time -> repeatTimed' "runAI" time goTimed (params, moves, Nothing)
242+
Just time -> repeatTimed' "runAI" time goTimed (params, globalInterval, moves, Nothing)
241243

242244
goTimed :: DepthIterationInput
243245
-> Checkers (DepthIterationOutput, Maybe DepthIterationInput)
244-
goTimed (params, moves, prevResult) = do
245-
ret <- tryC $ go (params, moves, prevResult)
246+
goTimed (params, globalInterval, moves, prevResult) = do
247+
ret <- tryC $ go (params, globalInterval, moves, prevResult)
246248
case ret of
247249
Right result -> return result
248250
Left TimeExhaused ->
@@ -253,7 +255,7 @@ runAI ai@(AlphaBeta params rules eval) handle gameId side board = do
253255

254256
go :: DepthIterationInput
255257
-> Checkers (DepthIterationOutput, Maybe DepthIterationInput)
256-
go (params, moves, prevResult) = do
258+
go (params, globalInterval, moves, prevResult) = do
257259
let depth = abDepth params
258260
if length moves <= 1 -- Just one move possible
259261
then do
@@ -276,14 +278,14 @@ runAI ai@(AlphaBeta params rules eval) handle gameId side board = do
276278
dpTarget = min (dpMax dp) (dpTarget dp + 1)
277279
}
278280
| otherwise = dp
279-
result <- widthController True True prevResult moves dp' =<< initInterval
281+
result <- widthController True True prevResult moves dp' globalInterval =<< initInterval
280282
-- In some corner cases, there might be 1 or 2 possible moves,
281283
-- so the timeout would allow us to calculate with very big depth;
282284
-- too big depth does not decide anything in such situations.
283285
if depth < 50
284286
then do
285287
let params' = params {abDepth = depth + 1, abStartDepth = Nothing}
286-
return (result, Just (params', moves, Just result))
288+
return (result, Just (params', globalInterval, moves, Just result))
287289
else return (result, Nothing)
288290

289291
score0 = evalBoard eval First board
@@ -317,22 +319,22 @@ runAI ai@(AlphaBeta params rules eval) handle gameId side board = do
317319
| s > 100 = 5
318320
| otherwise = 2
319321

320-
nextInterval :: (Score, Score) -> (Score, Score)
321-
nextInterval (alpha, beta) =
322+
nextInterval :: Score -> (Score, Score) -> (Score, Score)
323+
nextInterval good (alpha, beta) =
322324
let width = (beta - alpha)
323325
width' = selectScale width `scaleScore` width
324-
alpha' = prevScore alpha
325-
beta' = nextScore beta
326+
alpha' = min good (prevScore alpha)
327+
beta' = max good (nextScore beta)
326328
in if maximize
327329
then (beta', max beta' (beta' + width'))
328330
else (min alpha' (alpha' - width'), alpha')
329331

330-
prevInterval :: (Score, Score) -> (Score, Score)
331-
prevInterval (alpha, beta) =
332+
prevInterval :: Score -> (Score, Score) -> (Score, Score)
333+
prevInterval bad (alpha, beta) =
332334
let width = (beta - alpha)
333335
width' = selectScale width `scaleScore` width
334-
alpha' = prevScore alpha
335-
beta' = nextScore beta
336+
alpha' = min bad (prevScore alpha)
337+
beta' = max bad (nextScore beta)
336338
in if minimize
337339
then (beta', max beta' (beta' + width'))
338340
else (min alpha' (alpha' - width'), alpha')
@@ -342,24 +344,26 @@ runAI ai@(AlphaBeta params rules eval) handle gameId side board = do
342344
-> Maybe DepthIterationOutput -- ^ Results of previous depth iteration
343345
-> [PossibleMove]
344346
-> DepthParams
347+
-> TVar (Score, Score) -- ^ Global (alpha, beta)
345348
-> (Score, Score) -- ^ (Alpha, Beta)
346349
-> Checkers DepthIterationOutput
347-
widthController allowNext allowPrev prevResult moves dp interval@(alpha,beta) =
350+
widthController allowNext allowPrev prevResult moves dp globalInterval localInterval = do
351+
interval@(alpha, beta) <- getRestrictedInterval globalInterval localInterval
348352
if alpha == beta
349353
then do
350354
$info "Empty scores interval: [{}]. We have to think that all moves have this score." (Single alpha)
351355
return [(move, alpha) | move <- moves]
352356
else do
353-
results <- widthIteration prevResult moves dp interval
354-
let (good, badScore, badMoves) = selectBestEdge interval moves results
357+
results <- widthIteration prevResult moves dp globalInterval interval
358+
let (goodScore, good, badScore, badMoves) = selectBestEdge interval moves results
355359
(bestMoves, bestResults) = unzip good
356360
if length badMoves == length moves
357361
then
358362
if allowPrev
359363
then do
360-
let interval' = prevInterval interval
364+
let interval' = prevInterval badScore interval
361365
$info "All moves are `too bad'; consider worse scores interval: [{} - {}]" interval'
362-
widthController False True prevResult badMoves dp interval'
366+
widthController False True prevResult badMoves dp globalInterval interval'
363367
else do
364368
$info "All moves are `too bad' ({}), but we have already checked worse interval; so this is the real score." (Single badScore)
365369
return [(move, badScore) | move <- moves]
@@ -372,15 +376,19 @@ runAI ai@(AlphaBeta params rules eval) handle gameId side board = do
372376
_ ->
373377
if allowNext
374378
then do
375-
let interval'@(alpha',beta') = nextInterval interval
379+
let interval'@(alpha',beta') = nextInterval goodScore interval
376380
$info "Some moves ({} of them) are `too good'; consider better scores interval: [{} - {}]" (length bestMoves, alpha', beta')
377-
widthController True False prevResult bestMoves dp interval'
381+
widthController True False prevResult bestMoves dp globalInterval interval'
378382
else do
379383
$info "Some moves ({} of them) are `too good'; but we have already checked better interval; so this is the real score" (Single $ length bestMoves)
380384
return bestResults
381385

382-
scoreMoves :: [PossibleMove] -> DepthParams -> (Score, Score) -> Checkers [Either Error (PossibleMove, Score)]
383-
scoreMoves moves dp (alpha, beta) = do
386+
scoreMoves :: [PossibleMove]
387+
-> DepthParams
388+
-> TVar (Score, Score) -- ^ Global interval
389+
-> (Score, Score) -- ^ Local interval
390+
-> Checkers [Either Error (PossibleMove, Score)]
391+
scoreMoves moves dp globalInterval (localAlpha, localBeta) = do
384392
let var = aichData handle
385393
let processor = aichProcessor handle
386394
let inputs = [
@@ -392,22 +400,33 @@ runAI ai@(AlphaBeta params rules eval) handle gameId side board = do
392400
smiDepth = dp,
393401
smiBoard = board,
394402
smiMove = move,
395-
smiAlpha = alpha,
396-
smiBeta = beta
403+
smiGlobalInterval = globalInterval,
404+
smiAlpha = localAlpha,
405+
smiBeta = localBeta
397406
} | move <- moves ]
398407
process' processor inputs
399408

400-
scoreMoves' :: [PossibleMove] -> DepthParams -> (Score, Score) -> Checkers DepthIterationOutput
401-
scoreMoves' moves dp (alpha, beta) = do
402-
results <- scoreMoves moves dp (alpha, beta)
409+
scoreMoves' :: [PossibleMove]
410+
-> DepthParams
411+
-> TVar (Score, Score)
412+
-> (Score, Score)
413+
-> Checkers DepthIterationOutput
414+
scoreMoves' moves dp globalInterval localInterval = do
415+
results <- scoreMoves moves dp globalInterval localInterval
403416
case sequence results of
404417
Right result -> return result
405418
Left err -> throwError err
406419

407-
widthIteration :: Maybe DepthIterationOutput -> [PossibleMove] -> DepthParams -> (Score, Score) -> Checkers DepthIterationOutput
408-
widthIteration prevResult moves dp (alpha, beta) = do
420+
widthIteration :: Maybe DepthIterationOutput
421+
-> [PossibleMove]
422+
-> DepthParams
423+
-> TVar (Score, Score)
424+
-> (Score, Score)
425+
-> Checkers DepthIterationOutput
426+
widthIteration prevResult moves dp globalInterval localInterval = do
427+
(alpha, beta) <- getRestrictedInterval globalInterval localInterval
409428
$info "`- Considering scores interval: [{} - {}], depth = {}" (alpha, beta, dpTarget dp)
410-
results <- scoreMoves moves dp (alpha, beta)
429+
results <- scoreMoves moves dp globalInterval (alpha, beta)
411430
joinResults prevResult results
412431

413432
joinResults :: Maybe DepthIterationOutput -> [Either Error (PossibleMove, Score)] -> Checkers DepthIterationOutput
@@ -426,9 +445,16 @@ runAI ai@(AlphaBeta params rules eval) handle gameId side board = do
426445

427446
selectBestEdge (alpha, beta) moves results =
428447
let (good, bad) = if maximize then (beta, alpha) else (alpha, beta)
429-
goodResults = [(move, (goodMoves, score)) | (move, (goodMoves, score)) <- zip moves results, score == good]
430-
badResults = [move | (move, (_, score)) <- zip moves results, score == bad]
431-
in (goodResults, bad, badResults)
448+
goodResults = [(move, (goodMoves, score)) | (move, (goodMoves, score)) <- zip moves results, score >= good]
449+
badResults = [move | (move, (_, score)) <- zip moves results, score <= bad]
450+
scores = map snd results
451+
badScore = if maximize
452+
then minimum scores
453+
else maximum scores
454+
goodScore = if maximize
455+
then maximum scores
456+
else minimum scores
457+
in (goodScore, goodResults, bad, badResults)
432458

433459
select :: DepthIterationOutput -> Checkers AiOutput
434460
select pairs = do
@@ -447,10 +473,11 @@ doScore :: (GameRules rules, Evaluator eval)
447473
-> Side
448474
-> DepthParams
449475
-> Board
476+
-> TVar (Score, Score)
450477
-> Score -- ^ Alpha
451478
-> Score -- ^ Beta
452479
-> Checkers Score
453-
doScore rules eval var params gameId side dp board alpha beta = do
480+
doScore rules eval var params gameId side dp board globalInterval alpha beta = do
454481
initState <- mkInitState
455482
out <- evalStateT (cachedScoreAB var params input) initState
456483
return $ soScore out
@@ -461,13 +488,14 @@ doScore rules eval var params gameId side dp board alpha beta = do
461488
let timeout = case abBaseTime params of
462489
Nothing -> Nothing
463490
Just sec -> Just $ TimeSpec (fromIntegral sec) 0
464-
return $ ScoreState rules eval gameId [loose] now timeout
491+
return $ ScoreState rules eval gameId globalInterval [loose] now timeout
465492

466493
-- | State of ScoreM monad.
467494
data ScoreState rules eval = ScoreState {
468495
ssRules :: rules
469496
, ssEvaluator :: eval
470497
, ssGameId :: GameId
498+
, ssGlobalInterval :: TVar (Score, Score)
471499
, ssBestScores :: [Score] -- ^ At each level of depth-first search, there is own "best score"
472500
, ssStartTime :: TimeSpec -- ^ Start time of calculation
473501
, ssTimeout :: Maybe TimeSpec -- ^ Nothing for "no timeout"
@@ -518,6 +546,28 @@ clamp alpha beta score
518546
| score > beta = beta
519547
| otherwise = score
520548

549+
-- | Narrow the shared (alpha, beta) interval stored in the 'TVar' using a
-- newly obtained score. The read and the write happen in one STM transaction.
--
-- The interval is changed only when the score lies strictly inside it
-- (@globalAlpha < score < globalBeta@):
--
-- * for side 'First' (treated here as the maximizing side) the score becomes
--   the new lower bound;
-- * for any other side the score becomes the new upper bound.
--
-- A score on or outside the current bounds leaves the interval untouched.
--
-- NOTE(review): the only call visible in this diff (in @scoreMove@) is
-- commented out, so this function may not be invoked yet — confirm.
restrictInterval :: MonadIO m => TVar (Score, Score) -> Side -> Score -> m ()
550+
restrictInterval var side score = liftIO $ atomically $ do
551+
(globalAlpha, globalBeta) <- readTVar var
552+
when (globalAlpha < score && score < globalBeta) $
553+
if side == First -- maximize
554+
then writeTVar var (score, globalBeta)
555+
else writeTVar var (globalAlpha, score)
556+
557+
-- | Intersect the shared global (alpha, beta) interval with a caller-supplied
-- local interval.
--
-- The 'TVar' is read once; the result takes the larger of the two lower
-- bounds and the smaller of the two upper bounds.
--
-- * If the intersection is non-empty (@alpha1 <= beta1@), it is logged at
--   trace level and returned.
-- * Otherwise the intervals do not overlap, and a degenerate interval
--   @(mid, mid)@ is returned, where @mid@ is computed from the two crossed
--   bounds with 'divideScore' (presumably their midpoint — verify against
--   the definition of 'divideScore'). This branch writes no trace message.
--
-- The global 'TVar' is only read here, never written.
getRestrictedInterval :: (MonadIO m, HasLogger m, HasLogContext m) => TVar (Score, Score) -> (Score, Score) -> m (Score, Score)
558+
getRestrictedInterval global (localAlpha, localBeta) = do
559+
(globalAlpha, globalBeta) <- liftIO $ atomically $ readTVar global
560+
let alpha1 = max globalAlpha localAlpha
561+
beta1 = min globalBeta localBeta
562+
if alpha1 <= beta1
563+
then do
564+
$trace "Restrict: Global [{}, {}] x Local [{}, {}] => [{}, {}]"
565+
(globalAlpha, globalBeta, localAlpha, localBeta, alpha1, beta1)
566+
return (alpha1, beta1)
567+
else do
568+
let mid = (alpha1 + beta1) `divideScore` 2
569+
return (mid, mid)
570+
521571
-- | Calculate score of the board.
522572
-- This uses the cache. It is called in the recursive call also.
523573
cachedScoreAB :: forall rules eval. (GameRules rules, Evaluator eval)
@@ -541,7 +591,7 @@ cachedScoreAB var params input = do
541591
-- AB-section: alpha <= result <= beta. So here we clamp the value
542592
-- that we got from cache.
543593
case itemBound item of
544-
Exact -> return $ Just $ ScoreOutput (clamp alpha beta score) False
594+
Exact -> return $ Just $ ScoreOutput score False
545595
Alpha -> if score <= alpha
546596
then return $ Just $ ScoreOutput alpha False
547597
else return Nothing
@@ -632,36 +682,49 @@ scoreAB var params input
632682
-- target depth is achieved, calculate score of current board directly
633683
evaluator <- gets ssEvaluator
634684
let score0 = evalBoard' evaluator board
685+
(alpha, beta) <- getRestrictedInterval'
635686
$trace " X Side: {}, A = {}, B = {}, score0 = {}" (show side, show alpha, show beta, show score0)
636687
quiescene <- checkQuiescene
637688
return $ ScoreOutput score0 quiescene
638689
| otherwise = do
639690
-- first, let "best" be the worst possible value
640-
let best = if maximize then alpha else beta -- we assume alpha <= beta
641-
push best
642-
$trace "{}V Side: {}, A = {}, B = {}" (indent, show side, show alpha, show beta)
643-
rules <- gets ssRules
644-
moves <- lift $ getPossibleMoves var side board
645-
646-
-- this actually means that corresponding side lost.
647-
when (null moves) $
648-
$trace "{}`—No moves left." (Single indent)
649-
650-
dp' <- updateDepth params moves dp
651-
let prevMove = siPrevMove input
652-
moves' <- sortMoves prevMove moves
653-
out <- iterateMoves (zip [1..] moves') dp'
654-
pop
655-
return out
691+
let best = if maximize then loose else win -- we assume alpha <= beta
692+
(alpha, beta) <- getRestrictedInterval'
693+
if alpha == beta
694+
then do
695+
quiescene <- checkQuiescene
696+
return $ ScoreOutput best quiescene
697+
else do
698+
push best
699+
$trace "{}V Side: {}, A = {}, B = {}" (indent, show side, show alpha, show beta)
700+
rules <- gets ssRules
701+
moves <- lift $ getPossibleMoves var side board
702+
703+
-- this actually means that corresponding side lost.
704+
when (null moves) $
705+
$trace "{}`—No moves left." (Single indent)
706+
707+
dp' <- updateDepth params moves dp
708+
let prevMove = siPrevMove input
709+
moves' <- sortMoves prevMove moves
710+
out <- iterateMoves (zip [1..] moves') dp'
711+
pop
712+
return out
656713

657714
where
658715

659716
side = siSide input
660717
dp = siDepth input
661-
alpha = siAlpha input
662-
beta = siBeta input
718+
localAlpha = siAlpha input
719+
localBeta = siBeta input
663720
board = siBoard input
664721

722+
-- Local convenience wrapper around 'getRestrictedInterval': takes the global
-- interval TVar from the monad state ('ssGlobalInterval') and intersects it
-- with this call's local bounds (localAlpha / localBeta, i.e. siAlpha / siBeta
-- of the current input). The TVar is re-read on every call, so two calls may
-- return different intervals if the TVar was updated in between.
-- The names alpha and beta bound by the pattern are not used here; the pair is
-- returned unchanged.
getRestrictedInterval' = do
723+
globalInterval <- gets ssGlobalInterval
724+
result@(alpha, beta) <- getRestrictedInterval globalInterval (localAlpha, localBeta)
725+
return result
726+
727+
665728
evalBoard' :: eval -> Board -> Score
666729
evalBoard' evaluator board = result
667730
where
@@ -776,6 +839,7 @@ scoreAB var params input
776839
go (input : inputs) = do
777840
out <- cachedScoreAB var params input
778841
let score = soScore out
842+
(alpha, beta) <- getRestrictedInterval'
779843
if maximize && score >= beta || minimize && score <= alpha
780844
then go inputs
781845
else return out
@@ -796,6 +860,7 @@ scoreAB var params input
796860
evaluator <- gets ssEvaluator
797861
rules <- gets ssRules
798862
best <- getBest
863+
let (alpha, beta) = (localAlpha, localBeta)
799864
let input' = input {
800865
siSide = opposite side
801866
, siAlpha = if maximize

src/AI/AlphaBeta/Types.hs

+1
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ data ScoreMoveInput rules eval = ScoreMoveInput {
149149
smiAi :: AlphaBeta rules eval
150150
, smiCache :: AICacheHandle rules eval
151151
, smiGameId :: GameId
152+
, smiGlobalInterval :: TVar (Score, Score)
152153
, smiSide :: Side
153154
, smiDepth :: DepthParams
154155
, smiBoard :: Board

0 commit comments

Comments
 (0)