Skip to content

Commit

Permalink
Unions can no longer be None
Browse files Browse the repository at this point in the history
  • Loading branch information
VirxEC committed Feb 3, 2025
1 parent ae278a5 commit 9c14d33
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 43 deletions.
12 changes: 7 additions & 5 deletions codegen/pyi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,16 @@ pub fn generator(type_data: &[PythonBindType]) -> io::Result<()> {
.map(|variable_info| variable_info.name.as_str())
.filter(|variable_name| *variable_name != "NONE")
.collect::<Vec<_>>();
let default_value = types.first().unwrap();
let union_str = types.join(" | ");

write_fmt!(file, " item: Optional[{union_str}]");
write_fmt!(file, " item: {union_str}");
write_str!(file, "");
write_str!(file, " def __new__(");
write_fmt!(file, " cls, item: Optional[{union_str}] = None");
write_fmt!(file, " cls, item: {union_str} = {default_value}()");
write_str!(file, " ): ...");
write_str!(file, " def __init__(");
write_fmt!(file, " self, item: Optional[{union_str}] = None");
write_fmt!(file, " self, item: {union_str} = {default_value}()");
write_str!(file, " ): ...\n");
}
PythonBindType::Enum(gen) => {
Expand Down Expand Up @@ -228,8 +229,7 @@ pub fn generator(type_data: &[PythonBindType]) -> io::Result<()> {
_ => None,
})
.unwrap();
let types = union_types.join(" | ");
python_types.push(format!("Optional[{types}]"));
python_types.push(union_types.join(" | "));
}
RustType::Custom(type_name)
| RustType::Other(type_name)
Expand Down Expand Up @@ -284,6 +284,8 @@ pub fn generator(type_data: &[PythonBindType]) -> io::Result<()> {
|| t.starts_with("Option<")
{
Cow::Borrowed("None")
} else if let Some(pos) = python_type.find('|') {
Cow::Owned(format!("{}()", &python_type[..pos - 1]))
} else if t.starts_with("Vec<") {
Cow::Borrowed("[]")
} else if t.starts_with("Box<") {
Expand Down
5 changes: 2 additions & 3 deletions codegen/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,12 +404,11 @@ impl StructBindGenerator {
RustType::Union(inner_type) => {
write_fmt!(
self,
" {variable_name}: Py::new(py, super::{}::new({variable_name})).unwrap(),",
inner_type
" {variable_name}: {variable_name}.map(|u| Py::new(py, super::{inner_type}::new(u)).unwrap()).unwrap_or_else(|| super::{inner_type}::py_default(py)),"
);
}
RustType::Box(inner_type) | RustType::Custom(inner_type) => {
write_fmt!(self, " {variable_name}: {variable_name}.unwrap_or_else(|| super::{inner_type}::py_default(py)),",);
write_fmt!(self, " {variable_name}: {variable_name}.unwrap_or_else(|| super::{inner_type}::py_default(py)),");
}
RustType::Vec(InnerVecType::U8) => {
write_fmt!(
Expand Down
65 changes: 30 additions & 35 deletions codegen/unions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,9 @@ impl UnionBindGenerator {
assert!(u8::try_from(self.types.len()).is_ok());

write_str!(self, " #[new]");
write_str!(self, " #[pyo3(signature = (item = None))]");
write_fmt!(
self,
" pub fn new(item: Option<{}Union>) -> Self {{",
" pub fn new(item: {}Union) -> Self {{",
self.struct_name
);
write_str!(self, " Self { item }");
Expand All @@ -66,17 +65,15 @@ impl UnionBindGenerator {
self,
" pub fn get(&self, py: Python) -> Option<Py<PyAny>> {"
);
write_str!(self, " match self.item.as_ref() {");
write_str!(self, " match &self.item {");

for variable_info in &self.types {
let variable_name = variable_info.name.as_str();

if variable_name == "NONE" {
write_str!(self, " None => None,");
} else {
if variable_name != "NONE" {
write_fmt!(
self,
" Some({}Union::{variable_name}(item)) => Some(item.clone_ref(py).into_any()),",
" {}Union::{variable_name}(item) => Some(item.clone_ref(py).into_any()),",
self.struct_name
);
}
Expand All @@ -94,19 +91,17 @@ impl UnionBindGenerator {

fn generate_inner_repr_method(&mut self) {
write_str!(self, " pub fn inner_repr(&self, py: Python) -> String {");
write_str!(self, " match self.item.as_ref() {");
write_str!(self, " match &self.item {");

for variable_info in &self.types {
let variable_name = variable_info.name.as_str();

if variable_info.value.is_some() {
write_fmt!(
self,
" Some({}Union::{variable_name}(item)) => item.borrow(py).__repr__(py),",
" {}Union::{variable_name}(item) => item.borrow(py).__repr__(py),",
self.struct_name
);
} else {
write_str!(self, " None => crate::none_str(),");
}
}

Expand All @@ -116,24 +111,18 @@ impl UnionBindGenerator {

fn generate_repr_method(&mut self) {
write_str!(self, " pub fn __repr__(&self, py: Python) -> String {");
write_str!(self, " match self.item.as_ref() {");
write_str!(self, " match &self.item {");

for variable_info in &self.types {
let variable_name = variable_info.name.as_str();

if variable_info.value.is_some() {
write_fmt!(
self,
" Some({}Union::{variable_name}(item)) => format!(\"{}({{}})\", item.borrow(py).__repr__(py)),",
" {}Union::{variable_name}(item) => format!(\"{}({{}})\", item.borrow(py).__repr__(py)),",
self.struct_name,
self.struct_name
);
} else {
write_fmt!(
self,
" None => String::from(\"{}()\"),",
self.struct_name
);
}
}

Expand Down Expand Up @@ -177,7 +166,7 @@ impl Generator for UnionBindGenerator {
}

fn generate_definition(&mut self) {
write_fmt!(self, "#[derive(Debug, pyo3::FromPyObject)]");
write_fmt!(self, "#[derive(pyo3::FromPyObject)]");
write_fmt!(self, "pub enum {}Union {{", self.struct_name);

for variable_info in self.types.iter().skip(1) {
Expand All @@ -192,17 +181,26 @@ impl Generator for UnionBindGenerator {
if self.is_frozen {
write_str!(self, "#[pyclass(module = \"rlbot_flatbuffers\", frozen)]");
} else {
write_str!(self, "#[pyclass(module = \"rlbot_flatbuffers\")]");
write_str!(self, "#[pyclass(module = \"rlbot_flatbuffers\", set_all)]");
}

write_fmt!(self, "#[derive(Debug, Default)]");
write_fmt!(self, "pub struct {} {{", self.struct_name);
write_fmt!(self, " item: {}Union,", self.struct_name);
write_str!(self, "}");
write_str!(self, "");

if !self.is_frozen {
write_str!(self, " #[pyo3(set)]");
}

write_fmt!(self, " pub item: Option<{}Union>,", self.struct_name);
write_fmt!(self, "impl crate::PyDefault for {} {{", self.struct_name);
write_str!(self, " fn py_default(py: Python) -> Py<Self> {");
write_str!(self, " Py::new(py, Self {");
write_fmt!(
self,
" item: {}Union::{}(super::{}::py_default(py)),",
self.struct_name,
self.types[1].name,
self.types[1].name
);
write_str!(self, " }).unwrap()");
write_str!(self, " }");
write_str!(self, "}");
write_str!(self, "");
}
Expand All @@ -228,9 +226,8 @@ impl Generator for UnionBindGenerator {
if variable_name == "NONE" {
write_fmt!(
self,
" flat::{}::NONE => {}::default(),",
" flat::{}::NONE => unreachable!(),",
self.struct_t_name,
self.struct_name
);
} else {
write_fmt!(
Expand All @@ -242,7 +239,7 @@ impl Generator for UnionBindGenerator {

write_fmt!(
self,
" item: Some({}Union::{variable_name}(",
" item: {}Union::{variable_name}(",
self.struct_name
);

Expand All @@ -251,7 +248,7 @@ impl Generator for UnionBindGenerator {
" Py::new(py, super::{variable_name}::from_gil(py, *item)).unwrap(),"
);

write_fmt!(self, " )),");
write_fmt!(self, " ),");
write_fmt!(self, " }},");
}
}
Expand All @@ -275,15 +272,15 @@ impl Generator for UnionBindGenerator {
self.struct_name
);

write_str!(self, " match py_type.item.as_ref() {");
write_str!(self, " match &py_type.item {");

for variable_info in &self.types {
let variable_name = variable_info.name.as_str();

if let Some(ref value) = variable_info.value {
write_fmt!(
self,
" Some({}Union::{value}(item)) => {{",
" {}Union::{value}(item) => {{",
self.struct_name,
);

Expand All @@ -294,8 +291,6 @@ impl Generator for UnionBindGenerator {
);

write_str!(self, " },");
} else {
write_str!(self, " None => Self::NONE,");
}
}

Expand Down

0 comments on commit 9c14d33

Please sign in to comment.