diff --git a/R/chat-structured.R b/R/chat-structured.R index 4ef25a1f..0c240a84 100644 --- a/R/chat-structured.R +++ b/R/chat-structured.R @@ -1,13 +1,63 @@ -extract_data <- function(turn, type, convert = TRUE, needs_wrapper = FALSE) { +extract_data <- function( + turn, + type, + convert = TRUE, + needs_wrapper = FALSE, + prompt_index = NULL +) { is_json <- map_lgl(turn@contents, S7_inherits, ContentJson) n <- sum(is_json) - if (n != 1) { - cli::cli_abort("Data extraction failed: {n} data results recieved.") + if (n == 0) { + cli::cli_abort("Data extraction failed: 0 data results received.") + } else if (n == 1) { + # Normal case - exactly 1 JSON object + json <- turn@contents[[which(is_json)]] + out <- json@value + } else if (n == 2) { + # Check if the two JSON objects are identical (duplicate case) + json_indices <- which(is_json) + json1 <- turn@contents[[json_indices[1]]] + json2 <- turn@contents[[json_indices[2]]] + val1 <- json1@value + val2 <- json2@value + if (identical(val1, val2)) { + # Duplicate case - use the first one + index_msg <- if (!is.null(prompt_index)) { + paste0(" (prompt ", prompt_index, ")") + } else { + "" + } + warning( + "Found duplicate JSON responses, using the first one", + index_msg, + ".", + call. = FALSE, + immediate. = TRUE + ) + out <- val1 + } else { + # Different JSON objects - use the last one (likely the final response) + index_msg <- if (!is.null(prompt_index)) { + paste0(" (prompt ", prompt_index, ")") + } else { + "" + } + warning( + "Found multiple different JSON responses, using the last one", + index_msg, + ".", + call. = FALSE, + immediate. = TRUE + ) + out <- val2 + } + } else { + # More than 2 JSON objects - this is unexpected + cli::cli_abort( + "Data extraction failed: {n} data results received. Expected 1 or 2." + ) } - json <- turn@contents[[which(is_json)]] - out <- json@value - if (needs_wrapper) { out <- out$wrapper type <- type@properties[[1]] diff --git a/R/parallel-chat.R b/R/parallel-chat.R index d77b77e0..35ba2164 100644 --- a/R/parallel-chat.R +++ b/R/parallel-chat.R @@ -170,12 +170,13 @@ multi_convert <- function( ) { needs_wrapper <- type_needs_wrapper(type, provider) - rows <- map(turns, \(turn) { + rows <- imap(turns, \(turn, idx) { extract_data( turn = turn, type = wrap_type_if_needed(type, needs_wrapper), convert = FALSE, - needs_wrapper = needs_wrapper + needs_wrapper = needs_wrapper, + prompt_index = idx ) }) diff --git a/tests/testthat/_snaps/chat-structured.md b/tests/testthat/_snaps/chat-structured.md deleted file mode 100644 index 6812f46f..00000000 --- a/tests/testthat/_snaps/chat-structured.md +++ /dev/null @@ -1,8 +0,0 @@ -# useful error if no ContentJson - - Code - extract_data(turn) - Condition - Error in `extract_data()`: - ! Data extraction failed: 0 data results recieved. - diff --git a/tests/testthat/test-chat-structured.R b/tests/testthat/test-chat-structured.R index 95889ea5..5fa16ca6 100644 --- a/tests/testthat/test-chat-structured.R +++ b/tests/testthat/test-chat-structured.R @@ -2,7 +2,11 @@ test_that("useful error if no ContentJson", { turn <- Turn("assistant", list(ContentText("Hello"))) - expect_snapshot(extract_data(turn), error = TRUE) + expect_error( + extract_data(turn), + "Data extraction failed: 0 data results received.", + fixed = TRUE + ) }) test_that("can extract data from ContentJson", { @@ -24,6 +28,92 @@ test_that("can extract data when wrapper is used", { expect_equal(extract_data(turn, type, needs_wrapper = TRUE), list(x = 1)) }) +test_that("handles duplicate identical JSON responses", { + # This test covers the Bedrock duplicate JSON issue + json_data <- list(name = "John", age = 25) + turn <- Turn( + "assistant", + list( + ContentJson(json_data), + ContentJson(json_data) # Identical duplicate + ) + ) + type <- type_object(name = type_string(), age = type_integer()) + # Should warn about duplicates and use the first one + expect_warning( + result <- extract_data(turn, type), + "Found duplicate JSON responses, using the first one" + ) + expect_equal(result, list(name = "John", age = 25)) +}) + +test_that("handles duplicate identical JSON responses with prompt index", { + # Test that prompt index is included in warning message + json_data <- list(score = 42) + turn <- Turn( + "assistant", + list( + ContentJson(json_data), + ContentJson(json_data) + ) + ) + type <- type_object(score = type_integer()) + expect_warning( + extract_data(turn, type, prompt_index = 3), + "Found duplicate JSON responses, using the first one \\(prompt 3\\)" + ) +}) + +test_that("handles different JSON responses", { + # This test covers the case where two different JSON objects are returned + turn <- Turn( + "assistant", + list( + ContentJson(list(name = "John", age = 25)), + ContentJson(list(name = "Jane", age = 30)) # Different data + ) + ) + type <- type_object(name = type_string(), age = type_integer()) + # Should warn about multiple responses and use the last one + expect_warning( + result <- extract_data(turn, type), + "Found multiple different JSON responses, using the last one" + ) + expect_equal(result, list(name = "Jane", age = 30)) +}) + +test_that("handles different JSON responses with prompt index", { + turn <- Turn( + "assistant", + list( + ContentJson(list(value = 1)), + ContentJson(list(value = 2)) + ) + ) + type <- type_object(value = type_integer()) + expect_warning( + extract_data(turn, type, prompt_index = 5), + "Found multiple different JSON responses, using the last one \\(prompt 5\\)" + ) +}) + +test_that("errors on more than 2 JSON responses", { + # Should error if there are more than 2 JSON objects + turn <- Turn( + "assistant", + list( + ContentJson(list(x = 1)), + ContentJson(list(x = 2)), + ContentJson(list(x = 3)) + ) + ) + type <- type_object(x = type_integer()) + expect_error( + extract_data(turn, type), + "Data extraction failed: 3 data results received. Expected 1 or 2." + ) +}) + # Type coercion --------------------------------------------------------------- test_that("optional base types (scalars) stay as NULL", {