Skip to content

Commit 9dc2a42

Browse files
committed
feat(libsql): support more types in de::from_row
1 parent 19b1c7a commit 9dc2a42

File tree

2 files changed

+128
-11
lines changed

2 files changed

+128
-11
lines changed

libsql/src/de.rs

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Deserialization utilities.
22
33
use crate::{Row, Value};
4-
use serde::de::{value::Error as DeError, Error, IntoDeserializer, MapAccess, Visitor};
4+
use serde::de::{value::Error as DeError, Error, IntoDeserializer, MapAccess, SeqAccess, Visitor};
55
use serde::{Deserialize, Deserializer};
66

77
struct RowDeserializer<'de> {
@@ -15,15 +15,12 @@ impl<'de> Deserializer<'de> for RowDeserializer<'de> {
1515
where
1616
V: Visitor<'de>,
1717
{
18-
Err(DeError::custom("Expects a struct"))
18+
Err(DeError::custom(
19+
"Expects a map, newtype, sequence, struct, or tuple",
20+
))
1921
}
2022

21-
fn deserialize_struct<V>(
22-
self,
23-
_name: &'static str,
24-
_fields: &'static [&'static str],
25-
visitor: V,
26-
) -> Result<V::Value, Self::Error>
23+
fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2724
where
2825
V: Visitor<'de>,
2926
{
@@ -73,10 +70,83 @@ impl<'de> Deserializer<'de> for RowDeserializer<'de> {
7370
})
7471
}
7572

73+
fn deserialize_struct<V>(
74+
self,
75+
_name: &'static str,
76+
_fields: &'static [&'static str],
77+
visitor: V,
78+
) -> Result<V::Value, Self::Error>
79+
where
80+
V: Visitor<'de>,
81+
{
82+
self.deserialize_map(visitor)
83+
}
84+
85+
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
86+
where
87+
V: Visitor<'de>,
88+
{
89+
struct RowSeqAccess<'a> {
90+
row: &'a Row,
91+
idx: std::ops::Range<usize>,
92+
}
93+
94+
impl<'de> SeqAccess<'de> for RowSeqAccess<'de> {
95+
type Error = DeError;
96+
97+
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
98+
where
99+
T: serde::de::DeserializeSeed<'de>,
100+
{
101+
match self.idx.next() {
102+
None => Ok(None),
103+
Some(i) => {
104+
let value = self.row.get_value(i as i32).map_err(DeError::custom)?;
105+
seed.deserialize(value.into_deserializer()).map(Some)
106+
}
107+
}
108+
}
109+
}
110+
111+
visitor.visit_seq(RowSeqAccess {
112+
row: self.row,
113+
idx: 0..(self.row.column_count() as usize),
114+
})
115+
}
116+
117+
fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
118+
where
119+
V: Visitor<'de>,
120+
{
121+
self.deserialize_seq(visitor)
122+
}
123+
124+
fn deserialize_tuple_struct<V>(
125+
self,
126+
_name: &'static str,
127+
_len: usize,
128+
visitor: V,
129+
) -> Result<V::Value, Self::Error>
130+
where
131+
V: Visitor<'de>,
132+
{
133+
self.deserialize_seq(visitor)
134+
}
135+
136+
fn deserialize_newtype_struct<V>(
137+
self,
138+
_name: &'static str,
139+
visitor: V,
140+
) -> Result<V::Value, Self::Error>
141+
where
142+
V: Visitor<'de>,
143+
{
144+
visitor.visit_newtype_struct(self)
145+
}
146+
76147
serde::forward_to_deserialize_any! {
77148
bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
78-
bytes byte_buf option unit unit_struct newtype_struct seq tuple
79-
tuple_struct map enum identifier ignored_any
149+
bytes byte_buf option unit unit_struct enum identifier ignored_any
80150
}
81151
}
82152

libsql/tests/integration_tests.rs

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,9 +643,11 @@ async fn deserialize_row() {
643643
.await
644644
.unwrap();
645645

646+
use std::collections::HashMap;
647+
646648
use serde::Deserialize;
647649

648-
#[derive(Deserialize, Debug)]
650+
#[derive(Deserialize, Debug, PartialEq)]
649651
struct Data {
650652
id: i64,
651653
name: String,
@@ -684,6 +686,51 @@ async fn deserialize_row() {
684686
assert_eq!(data.none, None);
685687
assert_eq!(data.status, Status::Draft);
686688
assert_eq!(data.wrapper, Wrapper(Status::Published));
689+
690+
#[derive(Deserialize, Debug)]
691+
struct Newtype(Data);
692+
let newtype: Newtype = libsql::de::from_row(&row).unwrap();
693+
assert_eq!(newtype.0, data);
694+
695+
let tuple: (i64, String, f64, Vec<u8>, Option<i64>, Status, Wrapper) =
696+
libsql::de::from_row(&row).unwrap();
697+
assert_eq!(
698+
tuple,
699+
(
700+
123,
701+
"potato".to_string(),
702+
42.0,
703+
vec![0xde, 0xad, 0xbe, 0xef],
704+
None,
705+
Status::Draft,
706+
Wrapper(Status::Published)
707+
)
708+
);
709+
710+
let row2 = conn
711+
.query("SELECT name, status, wrapper FROM users", ())
712+
.await
713+
.unwrap()
714+
.next()
715+
.await
716+
.unwrap()
717+
.unwrap();
718+
let arr: Vec<String> = libsql::de::from_row(&row2).unwrap();
719+
assert_eq!(arr, vec!["potato", "Draft", "Published"]);
720+
721+
let map: HashMap<String, String> = libsql::de::from_row(&row2).unwrap();
722+
assert_eq!(
723+
map,
724+
HashMap::from_iter(
725+
[
726+
("name", "potato"),
727+
("status", "Draft"),
728+
("wrapper", "Published"),
729+
]
730+
.into_iter()
731+
.map(|(k, v)| (k.to_string(), v.to_string()))
732+
)
733+
);
687734
}
688735

689736
#[tokio::test]

0 commit comments

Comments
 (0)