diff --git a/spec/std/channel_spec.cr b/spec/std/channel_spec.cr index b9fed95dbd42..aed01831ac87 100644 --- a/spec/std/channel_spec.cr +++ b/spec/std/channel_spec.cr @@ -59,6 +59,18 @@ describe Channel do Channel.receive_first(Channel(Int32).new, channel).should eq 1 end + it "raises when receive_first timeout exceeded" do + expect_raises Channel::TimeoutError do + Channel.receive_first(Channel(Int32).new, Channel(Int32).new, timeout: 1.nanosecond) + end + expect_raises Channel::TimeoutError do + Channel.receive_first([Channel(Int32).new, Channel(Int32).new], timeout: 1.nanosecond) + end + expect_raises Channel::TimeoutError do + Channel.receive_first(StaticArray[Channel(Int32).new, Channel(Int32).new], timeout: 1.nanosecond) + end + end + it "does send_first" do ch1 = Channel(Int32).new(1) ch2 = Channel(Int32).new(1) @@ -67,6 +79,18 @@ describe Channel do ch2.receive.should eq 2 end + it "raises when send_first timeout exceeded" do + expect_raises Channel::TimeoutError do + Channel.send_first(1, Channel(Int32).new, Channel(Int32).new, timeout: 1.nanosecond) + end + expect_raises Channel::TimeoutError do + Channel.send_first(1, [Channel(Int32).new, Channel(Int32).new], timeout: 1.nanosecond) + end + expect_raises Channel::TimeoutError do + Channel.send_first(1, StaticArray[Channel(Int32).new, Channel(Int32).new], timeout: 1.nanosecond) + end + end + it "does not raise or change its status when it is closed more than once" do ch = Channel(Int32).new ch.closed?.should be_false diff --git a/src/channel.cr b/src/channel.cr index e065f77c5bcc..ee5fe1ca3027 100644 --- a/src/channel.cr +++ b/src/channel.cr @@ -36,6 +36,9 @@ class Channel(T) end end + class TimeoutError < Exception + end + private module SenderReceiverCloseAction def close self.state = DeliveryState::Closed @@ -294,22 +297,83 @@ class Channel(T) pp.text inspect end - def self.receive_first(*channels) - receive_first channels + # Returns the first available value received from the given *channels*, or + # raises `Channel::TimeoutError` if given a *timeout* that expires before a + # value is received. + # + # ``` + # c1 = Channel(String).new(1) + # c2 = Channel(String).new(1) + # + # c2.send "hello" + # value = Channel.receive_first c1, c2 # => receives "hello" from c2 + # + # begin + # # will timeout after 1 second and raise Channel::TimeoutError because + # # no channels are ready to receive + # value = Channel.receive_first c1, c2, timeout: 1.second + # rescue ex : Channel::TimeoutError + # Log.error(exception: ex) + # end + # ``` + def self.receive_first(*channels, timeout : Time::Span? = nil) + receive_first channels, timeout: timeout end - def self.receive_first(channels : Enumerable(Channel)) - _, value = self.select(channels.map(&.receive_select_action)) - value + # :ditto: + def self.receive_first(channels : Enumerable(Channel), *, timeout : Time::Span? = nil) + actions = channels.map do |channel| + action = channel.receive_select_action + action.as(Union(typeof(action) | TimeoutAction)) + end + self.select_action_first(actions, timeout: timeout) end - def self.send_first(value, *channels) : Nil - send_first value, channels + # Sends the given *value* to the first channel ready to receive in *channels*, + # or raises `Channel::TimeoutError` if given a *timeout* that expires before + # a channel becomes ready to receive. + # + # ``` + # c1 = Channel(String).new(1) + # c2 = Channel(String).new(1) + # + # c1.send "hello" + # value = Channel.send_first "goodbye", c1, c2 # => sends "goodbye" to c2 + # + # begin + # # will timeout after 1 second and raise Channel::TimeoutError because + # # no channels are ready to receive + # value = Channel.send_first "ciao", c1, c2, timeout: 1.second + # rescue ex : Channel::TimeoutError + # Log.error(exception: ex) + # end + # ``` + def self.send_first(value, *channels, timeout : Time::Span? = nil) : Nil + send_first value, channels, timeout: timeout end - def self.send_first(value, channels : Enumerable(Channel)) : Nil - self.select(channels.map(&.send_select_action(value))) - nil + # :ditto: + def self.send_first(value, channels : Enumerable(Channel), *, timeout : Time::Span? = nil) : Nil + actions = channels.map do |channel| + action = channel.send_select_action(value) + action.as(Union(typeof(action) | TimeoutAction)) + end + self.select_action_first(actions, timeout: timeout) + end + + private def self.select_action_first(actions : Enumerable(SelectAction), *, timeout : Time::Span? = nil) + if timeout.nil? + _, value = self.select(actions) + else + timeout_action, timeout_index = TimeoutAction.new(timeout), actions.size + if actions.is_a?(Tuple) + index, value = self.select(*actions, timeout_action) + else + index, value = self.select(actions.to_a << timeout_action) + end + raise TimeoutError.new if index == timeout_index + end + value end # :nodoc: