Currently, the skorch.dataset.get_len function does not accept dictionary input whose values have different lengths; it raises ValueError("Dataset does not have consistent lengths.").
This is a problem for GNNs built with PyTorch Geometric, where the forward method expects node features and an edge index whose sizes differ.
Accommodating this could involve modifying get_len so that a batch length can be specified when a custom collate_fn is used.
I could implement that change in the library.
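To make the failure mode concrete, here is a minimal sketch in plain Python. It is an illustrative stand-in, not skorch's actual get_len implementation: the function name get_len_sketch and its length argument are hypothetical, and edges are stored as (source, target) rows for simplicity rather than in PyTorch Geometric's layout. It only shows why a dictionary holding node features and an edge index fails a consistent-length check, and how an explicit length could bypass it.

```python
# Illustrative stand-in only: this is NOT skorch's implementation of get_len.
def get_len_sketch(data, length=None):
    """Return the number of samples in ``data``.

    ``length`` is a hypothetical override, not an existing skorch parameter.
    """
    if length is not None:
        # The caller states the number of samples explicitly.
        return length
    if isinstance(data, dict):
        lengths = {len(value) for value in data.values()}
    else:
        lengths = {len(data)}
    if len(lengths) != 1:
        raise ValueError("Dataset does not have consistent lengths.")
    return lengths.pop()


# One graph with 4 nodes and 5 edges (assumed toy data).
graph_input = {
    "x": [[0.1], [0.2], [0.3], [0.4]],
    "edge_index": [(0, 1), (1, 2), (2, 3), (3, 0), (0, 2)],
}

try:
    get_len_sketch(graph_input)
except ValueError as exc:
    # 4 node rows versus 5 edge rows: the check fails.
    error_message = str(exc)

# With the hypothetical override, the dictionary counts as a single sample.
n_samples = get_len_sketch(graph_input, length=1)
```

Whether such an override belongs in get_len itself or in the dataset class is a design question this sketch does not settle.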
Thanks for the report. We do have an example of using PyTorch Geometric that includes a section on data handling, but I'm not sure whether it applies to your use case. Also pinging @githubnemo since he added the notebook.
I've been playing with this for a bit too. The main problem here is being able to use torch_geometric within complex pipelines, which often expect tabular data.
Proposed solution: metadata routing. If we can successfully route the graph data (e.g., an array or pandas DataFrame of torch_geometric.data.Data objects), it would be trivial to pass it as additional fit params and use a collate_fn to handle the rest. This is also critical for complex, non-GNN architectures that receive multiple types of "X", so it is a nice thing to have anyway.
Metadata routing is not currently implemented in skorch's NeuralNet. Might it be a good idea to open a dedicated issue for it?
edit: added a dedicated issue: #1095
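As a rough illustration of what a collate_fn for graph data has to do, the sketch below merges several small graphs into one batch by concatenating node features and shifting each graph's edge indices by the number of nodes seen so far. Plain dictionaries stand in for torch_geometric.data.Data objects, and collate_graphs is a hypothetical helper, not a skorch or PyTorch Geometric function. It follows the general graph mini-batching idea only.

```python
def collate_graphs(graphs):
    """Merge a list of graphs into one batch made of disconnected subgraphs.

    Each graph is assumed to be a dict with "x" (one feature row per node)
    and "edge_index" (a list of (source, target) node index pairs).
    """
    x, edge_index, batch = [], [], []
    offset = 0
    for graph_id, graph in enumerate(graphs):
        x.extend(graph["x"])
        # Shift node indices so they point into the concatenated node list.
        edge_index.extend(
            (src + offset, dst + offset) for src, dst in graph["edge_index"]
        )
        # Record which graph each node came from.
        batch.extend([graph_id] * len(graph["x"]))
        offset += len(graph["x"])
    return {"x": x, "edge_index": edge_index, "batch": batch}


triangle = {"x": [[1.0], [2.0], [3.0]], "edge_index": [(0, 1), (1, 2), (2, 0)]}
pair = {"x": [[4.0], [5.0]], "edge_index": [(0, 1)]}

merged = collate_graphs([triangle, pair])
```

Here the batch holds two samples, yet merged["x"] has five rows and merged["edge_index"] has four, which is the kind of mismatch a consistent-length check rejects.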