diff --git a/stream_reader.go b/stream_reader.go index ecfa2680..dfc0cbbc 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -11,8 +11,12 @@ import ( ) var ( - headerData = []byte("data: ") - errorPrefix = []byte(`data: {"error":`) + dataField = []byte("data") + errorPrefix = []byte(`{"error":`) +) + +const ( + splitParts = 2 ) type streamable interface { @@ -55,13 +59,14 @@ func (stream *streamReader[T]) RecvRaw() ([]byte, error) { //nolint:gocognit func (stream *streamReader[T]) processLines() ([]byte, error) { var ( - emptyMessagesCount uint - hasErrorPrefix bool + emptyMessagesCount uint + dataFieldNotFound bool + valueHasErrorPrefix bool ) for { rawLine, readErr := stream.reader.ReadBytes('\n') - if readErr != nil || hasErrorPrefix { + if readErr != nil || valueHasErrorPrefix { respErr := stream.unmarshalError() if respErr != nil { return nil, fmt.Errorf("error, %w", respErr.Error) @@ -70,12 +75,28 @@ func (stream *streamReader[T]) processLines() ([]byte, error) { } noSpaceLine := bytes.TrimSpace(rawLine) - if bytes.HasPrefix(noSpaceLine, errorPrefix) { - hasErrorPrefix = true + + var value []byte + + split := bytes.SplitN(noSpaceLine, []byte(":"), splitParts) + + if len(split) != splitParts || !bytes.Equal(split[0], dataField) { + dataFieldNotFound = true + } else { + value = split[1] + + if bytes.HasPrefix(value, []byte(" ")) { + value = value[1:] + } + + if bytes.HasPrefix(value, errorPrefix) { + valueHasErrorPrefix = true + } } - if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix { - if hasErrorPrefix { - noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData) + + if dataFieldNotFound || valueHasErrorPrefix { + if valueHasErrorPrefix { + noSpaceLine = value } writeErr := stream.errAccumulator.Write(noSpaceLine) if writeErr != nil { @@ -85,11 +106,12 @@ func (stream *streamReader[T]) processLines() ([]byte, error) { if emptyMessagesCount > stream.emptyMessagesLimit { return nil, ErrTooManyEmptyStreamMessages } + dataFieldNotFound = false continue } - noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData) + noPrefixLine := value if string(noPrefixLine) == "[DONE]" { stream.isFinished = true return nil, io.EOF diff --git a/stream_reader_test.go b/stream_reader_test.go index 449a14b4..a0cef5e0 100644 --- a/stream_reader_test.go +++ b/stream_reader_test.go @@ -76,3 +76,16 @@ func TestStreamReaderRecvRaw(t *testing.T) { t.Fatalf("Did not return raw line: %v", string(rawLine)) } } + +func TestStreamReaderRecvRawWithNoSpaceInFieldAndValue(t *testing.T) { + stream := &streamReader[ChatCompletionStreamResponse]{ + reader: bufio.NewReader(bytes.NewReader([]byte("data:{\"key\": \"value\"}\n"))), + } + rawLine, err := stream.RecvRaw() + if err != nil { + t.Fatalf("Did not return raw line: %v", err) + } + if !bytes.Equal(rawLine, []byte("{\"key\": \"value\"}")) { + t.Fatalf("Did not return raw line: %v", string(rawLine)) + } +}