Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NN Built-In for Embedding Layers #2237

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

MaximilianSchreff
Copy link
Contributor

This PR adds the embedding layer as a built-in operator in our nn/layers library. The functionality is similar to pytorch.nn.Embedding (https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html)

The layer receives indices as input which refer to indices of an embedding dictionary and returns an embedding matrix where row i refers to embedding vector indices[i] of the embedding dictionary.

This layer is used in every transformer architecture. Here the indices usually come from a tokenizer and the embedding matrix is the input to the actual transformer model.

Testing

  • Testing forward pass and backward pass for correctness
  • Implemented as a component test in NNComponentTest.java
  • Manually calculated test cases for the forward pass
  • For backward pass, comparison against pytorches autograd module

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
Copy link

codecov bot commented Feb 25, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 72.47%. Comparing base (78b23cf) to head (336ef19).

Additional details and impacted files
@@            Coverage Diff            @@
##               main    #2237   +/-   ##
=========================================
  Coverage     72.46%   72.47%           
- Complexity    45453    45465   +12     
=========================================
  Files          1469     1469           
  Lines        170893   170893           
  Branches      33325    33325           
=========================================
+ Hits         123846   123863   +17     
+ Misses        37630    37617   -13     
+ Partials       9417     9413    -4     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@phaniarnab
Copy link
Contributor

Thanks @MaximilianSchreff. I will merge it in.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Status: In Progress
Development

Successfully merging this pull request may close these issues.

None yet

2 participants