Skip to content

Commit e9cdc5b

Browse files
committed
Nicer Python enum initalization
1 parent 09da103 commit e9cdc5b

File tree

3 files changed

+86
-77
lines changed

3 files changed

+86
-77
lines changed

build.rs

+77-75
Original file line numberDiff line numberDiff line change
@@ -290,14 +290,32 @@ impl PythonBindGenerator {
290290
}
291291

292292
fn generate_union_definition(&mut self) {
293+
self.write_str("#[derive(pyo3::FromPyObject)]");
294+
self.write_string(format!("pub enum {}Union {{", self.struct_name));
295+
296+
for variable_info in self.types.iter().skip(1) {
297+
let variable_name = &variable_info[0];
298+
self.file_contents
299+
.push(Cow::Owned(format!(" {variable_name}(super::{variable_name}),")));
300+
}
301+
302+
self.write_str("}");
303+
self.write_str("");
304+
293305
self.write_str("#[pyclass(module = \"rlbot_flatbuffers\")]");
294306
self.write_str("#[derive(Debug, Default, Clone, Copy, GetSize)]");
295307
self.write_string(format!("pub enum {}Type {{", self.struct_name));
296308
self.write_str(" #[default]");
297309

298310
for variable_info in &self.types {
299311
let variable_name = &variable_info[0];
300-
self.file_contents.push(Cow::Owned(format!(" {variable_name},")));
312+
313+
if variable_name == "NONE" {
314+
self.file_contents.push(Cow::Borrowed(" #[pyo3(name = \"NONE\")]"));
315+
self.file_contents.push(Cow::Borrowed(" None,"));
316+
} else {
317+
self.file_contents.push(Cow::Owned(format!(" {variable_name},")));
318+
}
301319
}
302320

303321
self.write_str("}");
@@ -342,7 +360,11 @@ impl PythonBindGenerator {
342360
self.write_str(" match value {");
343361

344362
for (i, variable_info) in self.types.iter().enumerate() {
345-
let variable_name = &variable_info[0];
363+
let mut variable_name = variable_info[0].as_str();
364+
365+
if variable_name == "NONE" {
366+
variable_name = "None";
367+
}
346368

347369
self.file_contents
348370
.push(Cow::Owned(format!(" {i} => Self::{variable_name},")));
@@ -397,7 +419,7 @@ impl PythonBindGenerator {
397419

398420
if variable_name == "NONE" {
399421
self.file_contents.push(Cow::Owned(format!(
400-
" flat::{}::{variable_name} => Self::NONE,",
422+
" flat::{}::{variable_name} => Self::None,",
401423
self.struct_t_name,
402424
)));
403425
} else {
@@ -547,12 +569,12 @@ impl PythonBindGenerator {
547569
if variable_value.is_empty() {
548570
if is_box_type {
549571
self.file_contents.push(Cow::Owned(format!(
550-
" {}Type::NONE => flat::{}::NONE,",
572+
" {}Type::None => flat::{}::NONE,",
551573
self.struct_name, self.struct_t_name
552574
)));
553575
} else {
554576
self.file_contents.push(Cow::Owned(format!(
555-
" {}Type::NONE => Self::NONE,",
577+
" {}Type::None => Self::NONE,",
556578
self.struct_name
557579
)));
558580
}
@@ -698,74 +720,48 @@ impl PythonBindGenerator {
698720
self.write_str(" #[new]");
699721
assert!(u8::try_from(self.types.len()).is_ok());
700722

701-
let mut signature_parts = Vec::new();
702-
703-
for variable_info in &self.types {
704-
let variable_type = &variable_info[1];
705-
706-
if variable_type.is_empty() {
707-
continue;
708-
}
709-
710-
let snake_case_name = &variable_info[2];
711-
712-
signature_parts.push(format!("{snake_case_name}=None"));
713-
}
714-
715-
self.write_string(format!(" #[pyo3(signature = ({}))]", signature_parts.join(", ")));
716-
self.write_str(" pub fn new(");
717-
718-
for variable_info in &self.types {
719-
let variable_type = &variable_info[1];
720-
721-
if variable_type.is_empty() {
722-
continue;
723-
}
724-
725-
let snake_case_name = &variable_info[2];
726-
727-
self.file_contents.push(Cow::Owned(format!(
728-
" {snake_case_name}: Option<super::{variable_type}>,"
729-
)));
730-
}
723+
self.write_str(" #[pyo3(signature = (item = None))]");
724+
self.write_string(format!(" pub fn new(item: Option<{}Union>) -> Self {{", self.struct_name));
725+
self.write_str(" match item {");
731726

732-
self.write_str(" ) -> Self {");
733-
734-
self.write_string(format!(" let mut item_type = {}Type::default();", self.struct_name));
735727
for variable_info in &self.types {
736728
let variable_name = &variable_info[0];
737-
let variable_type = &variable_info[1];
738729

739-
if variable_type.is_empty() {
740-
continue;
741-
}
742-
743-
let snake_case_name = &variable_info[2];
730+
if variable_name == "NONE" {
731+
self.file_contents.push(Cow::Borrowed(" None => Self::default(),"));
732+
} else {
733+
let wanted_snake_case_name = &variable_info[2];
744734

745-
self.file_contents.push(Cow::Borrowed(""));
746-
self.file_contents
747-
.push(Cow::Owned(format!(" if {snake_case_name}.is_some() {{")));
748-
self.file_contents.push(Cow::Owned(format!(
749-
" item_type = {}Type::{variable_name};",
750-
self.struct_name
751-
)));
752-
self.file_contents.push(Cow::Borrowed(" }"));
753-
}
735+
self.file_contents.push(Cow::Owned(format!(
736+
" Some({}Union::{}({wanted_snake_case_name})) => Self {{",
737+
self.struct_name, variable_name
738+
)));
739+
self.file_contents.push(Cow::Owned(format!(
740+
" item_type: {}Type::{variable_name},",
741+
self.struct_name
742+
)));
754743

755-
self.write_str("");
756-
self.write_str(" Self {");
757-
self.write_str(" item_type,");
744+
for variable_info in &self.types {
745+
let variable_type = &variable_info[1];
758746

759-
for variable_info in &self.types {
760-
let variable_type = &variable_info[1];
747+
if variable_type.is_empty() {
748+
continue;
749+
}
761750

762-
if variable_type.is_empty() {
763-
continue;
764-
}
751+
let snake_case_name = &variable_info[2];
765752

766-
let snake_case_name = &variable_info[2];
753+
if wanted_snake_case_name == snake_case_name {
754+
self.file_contents.push(Cow::Owned(format!(
755+
" {snake_case_name}: Some({snake_case_name}),",
756+
)));
757+
} else {
758+
self.file_contents
759+
.push(Cow::Owned(format!(" {snake_case_name}: None,",)));
760+
}
761+
}
767762

768-
self.file_contents.push(Cow::Owned(format!(" {snake_case_name},")));
763+
self.file_contents.push(Cow::Borrowed(" },"));
764+
}
769765
}
770766

771767
self.write_str(" }");
@@ -901,8 +897,8 @@ impl PythonBindGenerator {
901897

902898
if variable_type.is_empty() {
903899
self.file_contents.push(Cow::Owned(format!(
904-
" {}Type::NONE => String::from(\"()\"),",
905-
self.struct_name
900+
" {}Type::None => String::from(\"{}()\"),",
901+
self.struct_name, self.struct_name
906902
)));
907903
} else {
908904
let snake_case_name = &variable_info[2];
@@ -916,10 +912,8 @@ impl PythonBindGenerator {
916912
" {}Type::{variable_name} => format!(",
917913
self.struct_name
918914
)));
919-
self.file_contents.push(Cow::Owned(format!(
920-
" \"{}({snake_case_name}={{}})\",",
921-
self.struct_name
922-
)));
915+
self.file_contents
916+
.push(Cow::Owned(format!(" \"{}({{}})\",", self.struct_name)));
923917

924918
self.file_contents
925919
.push(Cow::Owned(format!(" self.{snake_case_name}")));
@@ -1287,6 +1281,18 @@ fn pyi_generator(type_data: &[(String, String, Vec<Vec<String>>)]) -> io::Result
12871281

12881282
if is_enum {
12891283
file_contents.push(Cow::Borrowed(" def __init__(self, value: int = 0): ..."));
1284+
} else if is_union {
1285+
file_contents.push(Cow::Borrowed(" def __init__("));
1286+
1287+
let types = types
1288+
.iter()
1289+
.map(|variable_info| variable_info[0].as_str())
1290+
.filter(|variable_name| *variable_name != "NONE")
1291+
.collect::<Vec<_>>();
1292+
let union_str = types.join(" | ");
1293+
1294+
file_contents.push(Cow::Owned(format!(" self, item: Optional[{union_str}] = None")));
1295+
file_contents.push(Cow::Borrowed(" ): ..."));
12901296
} else {
12911297
file_contents.push(Cow::Borrowed(" def __init__("));
12921298
file_contents.push(Cow::Borrowed(" self,"));
@@ -1297,13 +1303,9 @@ fn pyi_generator(type_data: &[(String, String, Vec<Vec<String>>)]) -> io::Result
12971303
continue;
12981304
}
12991305

1300-
let variable_name = if is_union { &variable_info[2] } else { &variable_info[0] };
1306+
let variable_name = &variable_info[0];
13011307

1302-
let variable_type = if is_union {
1303-
format!("Option<{}>", variable_info[1])
1304-
} else {
1305-
variable_info[1].clone()
1306-
};
1308+
let variable_type = variable_info[1].clone();
13071309

13081310
let default_value = match variable_type.as_str() {
13091311
"bool" => Cow::Borrowed("False"),

pytest.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030
print(dgs)
3131
print()
3232

33-
render_type = RenderType(line_3_d=Line3D(Vector3(0, 0, 0), Vector3(1, 1, 1), Color(255)))
33+
render_type = RenderType(Line3D(Vector3(0, 0, 0), Vector3(1, 1, 1), Color(255)))
3434
render_type.line_3_d.color.a = 150
3535

36+
print(repr(RenderType()))
37+
3638
print(repr(render_type))
3739
print(render_type)
3840
print()

src/lib.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
)]
1010
pub mod generated;
1111

12-
#[allow(clippy::too_many_arguments, clippy::upper_case_acronyms, non_camel_case_types)]
12+
#[allow(
13+
clippy::too_many_arguments,
14+
clippy::upper_case_acronyms,
15+
clippy::enum_variant_names,
16+
non_camel_case_types
17+
)]
1318
mod python;
1419

1520
use pyo3::prelude::*;

0 commit comments

Comments
 (0)