Skip to content
This repository was archived by the owner on Oct 31, 2023. It is now read-only.

Commit 1d81a5a

Browse files
committed
Allow async state transition to be canceled
This adds a context and cancelation facility to the type `AsyncError`. Async state transitions can now be canceled by calling `CancelTransition` on the AsyncError returned by `fsm.Event`. The context on that error can also be handed off as described in looplab#77 (comment).
1 parent b3fb114 commit 1d81a5a

5 files changed

+167
-2
lines changed

errors.go

+7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414

1515
package fsm
1616

17+
import (
18+
"context"
19+
)
20+
1721
// InvalidEventError is returned by FSM.Event() when the event cannot be called
1822
// in the current state.
1923
type InvalidEventError struct {
@@ -82,6 +86,9 @@ func (e CanceledError) Error() string {
8286
// asynchronous state transition.
8387
type AsyncError struct {
8488
Err error
89+
90+
Ctx context.Context
91+
CancelTransition func()
8592
}
8693

8794
func (e AsyncError) Error() string {

fsm.go

+12-2
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,16 @@ func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) erro
364364
if err = f.leaveStateCallbacks(ctx, e); err != nil {
365365
if _, ok := err.(CanceledError); ok {
366366
f.transition = nil
367+
} else if asyncError, ok := err.(AsyncError); ok {
368+
// setup a new context in order for async state transitions to work correctly
369+
// this "uncancels" the original context which ignores its cancelation
370+
// but keeps the values of the original context available to callers
371+
ctx, cancel := uncancelContext(ctx)
372+
e.cancelFunc = cancel
373+
asyncError.Ctx = ctx
374+
asyncError.CancelTransition = cancel
375+
f.transition = transitionFunc(ctx, true)
376+
return asyncError
367377
}
368378
return err
369379
}
@@ -434,15 +444,15 @@ func (f *FSM) leaveStateCallbacks(ctx context.Context, e *Event) error {
434444
if e.canceled {
435445
return CanceledError{e.Err}
436446
} else if e.async {
437-
return AsyncError{e.Err}
447+
return AsyncError{Err: e.Err}
438448
}
439449
}
440450
if fn, ok := f.callbacks[cKey{"", callbackLeaveState}]; ok {
441451
fn(ctx, e)
442452
if e.canceled {
443453
return CanceledError{e.Err}
444454
} else if e.async {
445-
return AsyncError{e.Err}
455+
return AsyncError{Err: e.Err}
446456
}
447457
}
448458
return nil

fsm_test.go

+36
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,42 @@ func TestAsyncTransitionNotInProgress(t *testing.T) {
473473
}
474474
}
475475

476+
func TestCancelAsyncTransition(t *testing.T) {
477+
fsm := NewFSM(
478+
"start",
479+
Events{
480+
{Name: "run", Src: []string{"start"}, Dst: "end"},
481+
},
482+
Callbacks{
483+
"leave_start": func(_ context.Context, e *Event) {
484+
e.Async()
485+
},
486+
},
487+
)
488+
err := fsm.Event(context.Background(), "run")
489+
asyncError, ok := err.(AsyncError)
490+
if !ok {
491+
t.Errorf("expected error to be 'AsyncError', got %v", err)
492+
}
493+
var asyncStateTransitionWasCanceled bool
494+
go func() {
495+
<-asyncError.Ctx.Done()
496+
asyncStateTransitionWasCanceled = true
497+
}()
498+
asyncError.CancelTransition()
499+
time.Sleep(20 * time.Millisecond)
500+
501+
if err = fsm.Transition(); err != nil {
502+
t.Errorf("expected no error, got %v", err)
503+
}
504+
if !asyncStateTransitionWasCanceled {
505+
t.Error("expected async state transition cancelation to have propagated")
506+
}
507+
if fsm.Current() != "start" {
508+
t.Error("expected state to be 'start'")
509+
}
510+
}
511+
476512
func TestCallbackNoError(t *testing.T) {
477513
fsm := NewFSM(
478514
"start",

uncancel_context.go

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package fsm
2+
3+
import (
4+
"context"
5+
"time"
6+
)
7+
8+
type uncancel struct {
9+
context.Context
10+
}
11+
12+
func (*uncancel) Deadline() (deadline time.Time, ok bool) { return }
13+
func (*uncancel) Done() <-chan struct{} { return nil }
14+
func (*uncancel) Err() error { return nil }
15+
16+
// uncancelContext returns a context which ignores the cancellation of the parent and only keeps the values.
17+
// Also returns a new cancel function.
18+
// This is useful to keep a background task running while the initial request is finished.
19+
func uncancelContext(ctx context.Context) (context.Context, context.CancelFunc) {
20+
return context.WithCancel(&uncancel{ctx})
21+
}

uncancel_context_test.go

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package fsm
2+
3+
import (
4+
"context"
5+
"testing"
6+
)
7+
8+
func TestUncancel(t *testing.T) {
9+
t.Run("create a new context", func(t *testing.T) {
10+
t.Run("and cancel it", func(t *testing.T) {
11+
ctx := context.Background()
12+
ctx = context.WithValue(ctx, "key1", "value1")
13+
ctx, cancelFunc := context.WithCancel(ctx)
14+
cancelFunc()
15+
16+
if ctx.Err() != context.Canceled {
17+
t.Errorf("expected context error 'context canceled', got %v", ctx.Err())
18+
}
19+
select {
20+
case <-ctx.Done():
21+
default:
22+
t.Error("expected context to be done but it wasn't")
23+
}
24+
25+
t.Run("and uncancel it", func(t *testing.T) {
26+
ctx, newCancelFunc := uncancelContext(ctx)
27+
if ctx.Err() != nil {
28+
t.Errorf("expected context error to be nil, got %v", ctx.Err())
29+
}
30+
select {
31+
case <-ctx.Done():
32+
t.Fail()
33+
default:
34+
}
35+
36+
t.Run("now it should still contain the values", func(t *testing.T) {
37+
if ctx.Value("key1") != "value1" {
38+
t.Errorf("expected context value of key 'key1' to be 'value1', got %v", ctx.Value("key1"))
39+
}
40+
})
41+
t.Run("and cancel the child", func(t *testing.T) {
42+
newCancelFunc()
43+
if ctx.Err() != context.Canceled {
44+
t.Errorf("expected context error 'context canceled', got %v", ctx.Err())
45+
}
46+
select {
47+
case <-ctx.Done():
48+
default:
49+
t.Error("expected context to be done but it wasn't")
50+
}
51+
})
52+
})
53+
})
54+
t.Run("and uncancel it", func(t *testing.T) {
55+
ctx := context.Background()
56+
parent := ctx
57+
ctx, newCancelFunc := uncancelContext(ctx)
58+
if ctx.Err() != nil {
59+
t.Errorf("expected context error to be nil, got %v", ctx.Err())
60+
}
61+
select {
62+
case <-ctx.Done():
63+
t.Fail()
64+
default:
65+
}
66+
67+
t.Run("and cancel the child", func(t *testing.T) {
68+
newCancelFunc()
69+
if ctx.Err() != context.Canceled {
70+
t.Errorf("expected context error 'context canceled', got %v", ctx.Err())
71+
}
72+
select {
73+
case <-ctx.Done():
74+
default:
75+
t.Error("expected context to be done but it wasn't")
76+
}
77+
78+
t.Run("and ensure the parent is not affected", func(t *testing.T) {
79+
if parent.Err() != nil {
80+
t.Errorf("expected parent context error to be nil, got %v", ctx.Err())
81+
}
82+
select {
83+
case <-parent.Done():
84+
t.Fail()
85+
default:
86+
}
87+
})
88+
})
89+
})
90+
})
91+
}

0 commit comments

Comments
 (0)