From bb190eecce08c08b989c3c6a75b81e9959e3f868 Mon Sep 17 00:00:00 2001
From: I am goroot <iamgoroot@gmail.com>
Date: Wed, 24 Jul 2024 00:54:43 +0200
Subject: [PATCH] Fix: Closing kafka Writer during WriteMessages causes a
 potential hang

Fixes #1307

---
 writer.go      | 13 ++++++++++---
 writer_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 54 insertions(+), 3 deletions(-)

diff --git a/writer.go b/writer.go
index 3817bf53..cceace7e 100644
--- a/writer.go
+++ b/writer.go
@@ -663,7 +663,10 @@ func (w *Writer) WriteMessages(ctx context.Context, msgs ...Message) error {
 		assignments[key] = append(assignments[key], int32(i))
 	}
 
-	batches := w.batchMessages(msgs, assignments)
+	batches, err := w.batchMessages(msgs, assignments)
+	if err != nil {
+		return err
+	}
 	if w.Async {
 		return nil
 	}
@@ -695,7 +698,7 @@ func (w *Writer) WriteMessages(ctx context.Context, msgs ...Message) error {
 	return werr
 }
 
-func (w *Writer) batchMessages(messages []Message, assignments map[topicPartition][]int32) map[*writeBatch][]int32 {
+func (w *Writer) batchMessages(messages []Message, assignments map[topicPartition][]int32) (map[*writeBatch][]int32, error) {
 	var batches map[*writeBatch][]int32
 	if !w.Async {
 		batches = make(map[*writeBatch][]int32, len(assignments))
@@ -704,6 +707,10 @@ func (w *Writer) batchMessages(messages []Message, assignments map[topicPartitio
 	w.mutex.Lock()
 	defer w.mutex.Unlock()
 
+	if w.closed {
+		return nil, io.ErrClosedPipe
+	}
+
 	if w.writers == nil {
 		w.writers = map[topicPartition]*partitionWriter{}
 	}
@@ -721,7 +728,7 @@ func (w *Writer) batchMessages(messages []Message, assignments map[topicPartitio
 		}
 	}
 
-	return batches
+	return batches, nil
 }
 
 func (w *Writer) produce(key topicPartition, batch *writeBatch) (*ProduceResponse, error) {
diff --git a/writer_test.go b/writer_test.go
index 6f894ecd..7028ab84 100644
--- a/writer_test.go
+++ b/writer_test.go
@@ -191,6 +191,10 @@ func TestWriter(t *testing.T) {
 			scenario: "test write message with writer data",
 			function: testWriteMessageWithWriterData,
 		},
+		{
+			scenario: "test no new partition writers after close",
+			function: testWriterNoNewPartitionWritersAfterClose,
+		},
 	}
 
 	for _, test := range tests {
@@ -1030,6 +1034,46 @@ func testWriterOverrideConfigStats(t *testing.T) {
 	}
 }
 
+func testWriterNoNewPartitionWritersAfterClose(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	topic1 := makeTopic()
+	createTopic(t, topic1, 1)
+	defer deleteTopic(t, topic1)
+
+	w := newTestWriter(WriterConfig{
+		Topic: topic1,
+	})
+	defer w.Close() // close again on exit regardless of outcome; the balancer below also closes the writer
+
+	// Use the balancer as a hook to close the writer in the window after the mutex is first released and before it is re-taken to build the map of partition writers.
+	w.Balancer = mockBalancerFunc(func(m Message, i ...int) int {
+		go w.Close() // Close blocks, so run it in its own goroutine
+		for {        // spin until w.closed is observed as true under the mutex
+			w.mutex.Lock()
+			if w.closed {
+				w.mutex.Unlock()
+				break
+			}
+			w.mutex.Unlock()
+		}
+		return 0
+	})
+
+	msg := Message{Value: []byte("Hello World")} // no Topic on the message; the writer is configured with Topic above
+
+	if err := w.WriteMessages(ctx, msg); !errors.Is(err, io.ErrClosedPipe) {
+		t.Errorf("expected error: %v got: %v", io.ErrClosedPipe, err)
+		return
+	}
+}
+
+type mockBalancerFunc func(msg Message, partitions ...int) (partition int) // function adapter so a test closure can be assigned to Writer.Balancer
+
+func (b mockBalancerFunc) Balance(msg Message, partitions ...int) int { // Balance forwards the call to the wrapped function
+	return b(msg, partitions...)
+}
+
 type staticBalancer struct {
 	partition int
 }