Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 270 additions & 0 deletions src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,276 @@ def GesvjOp : EnzymeXLA_Op<"lapack.gesvj", [Pure]> {
}];
}

// Special Functions - Bessel Functions

def BesselJ : EnzymeXLA_Op<"special.besselj", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Bessel function of the first kind of order nu at z";

let arguments = (ins
HLO_Tensor:$nu,
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def BesselJ0 : EnzymeXLA_Op<"special.besselj0", [Pure, SameOperandsAndResultType, Elementwise]> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason for us to specialize j0 here? If we have the same implementation it lowers into anyways, no need for now

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I imagined those might have custom approximation lowerings but I'm not sure about that. Will see if I can find something

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for one, mpfr seems to have separate implementations for the different bessels https://github.com/JuliaMath/SpecialFunctions.jl/blob/1743a8b7ac1565213e87de418765c594720929b6/src/bessel.jl#L682C1-L709C4
I'm not sure yet what the lowering in stablehlo would look like.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in julia right now it just looks like they call custom c library functinos [which we can't do ourselves here anyways]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

id probably just start with the more general one, and add the specialized one if we find a lowering for it later

let summary = "Bessel function of the first kind of order 0 at z";

let arguments = (ins
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def BesselJ1 : EnzymeXLA_Op<"special.besselj1", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Bessel function of the first kind of order 1 at z";

let arguments = (ins
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def BesselJX : EnzymeXLA_Op<"special.besseljx", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Scaled Bessel function of the first kind of order nu at z";

let arguments = (ins
HLO_Tensor:$nu,
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def SphericalBesselJ : EnzymeXLA_Op<"special.sphericalbesselj", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Spherical Bessel function of the first kind of order nu at z";

let arguments = (ins
HLO_Tensor:$nu,
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def BesselY : EnzymeXLA_Op<"special.bessely", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Bessel function of the second kind of order nu at z";

let arguments = (ins
HLO_Tensor:$nu,
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def BesselY0 : EnzymeXLA_Op<"special.bessely0", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Bessel function of the second kind of order 0 at z";

let arguments = (ins
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def BesselY1 : EnzymeXLA_Op<"special.bessely1", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Bessel function of the second kind of order 1 at z";

let arguments = (ins
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def BesselYX : EnzymeXLA_Op<"special.besselyx", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Scaled Bessel function of the second kind of order nu at z";

let arguments = (ins
HLO_Tensor:$nu,
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def SphericalBesselY : EnzymeXLA_Op<"special.sphericalbessely", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Spherical Bessel function of the second kind of order nu at z";

let arguments = (ins
HLO_Tensor:$nu,
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def BesselH : EnzymeXLA_Op<"special.besselh", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Bessel function of the third kind (Hankel function) of order nu at z";

let description = [{
Computes the Bessel function of the third kind, also known as the Hankel
function. The parameter k must be either 1 or 2, selecting between Hankel
functions of the first kind (H1) and second kind (H2).
}];

let arguments = (ins
HLO_Tensor:$nu,
HLO_Tensor:$k,
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def HankelH1 : EnzymeXLA_Op<"special.hankelh1", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Hankel function of the first kind of order nu at z";

let arguments = (ins
HLO_Tensor:$nu,
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def HankelH1X : EnzymeXLA_Op<"special.hankelh1x", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Scaled Hankel function of the first kind of order nu at z";

let arguments = (ins
HLO_Tensor:$nu,
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def HankelH2 : EnzymeXLA_Op<"special.hankelh2", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Hankel function of the second kind of order nu at z";

let arguments = (ins
HLO_Tensor:$nu,
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def HankelH2X : EnzymeXLA_Op<"special.hankelh2x", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Scaled Hankel function of the second kind of order nu at z";

let arguments = (ins
HLO_Tensor:$nu,
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def BesselI : EnzymeXLA_Op<"special.besseli", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Modified Bessel function of the first kind of order nu at z";

let arguments = (ins
HLO_Tensor:$nu,
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def BesselIX : EnzymeXLA_Op<"special.besselix", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Scaled modified Bessel function of the first kind of order nu at z";

let arguments = (ins
HLO_Tensor:$nu,
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def BesselK : EnzymeXLA_Op<"special.besselk", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Modified Bessel function of the second kind of order nu at z";

let arguments = (ins
HLO_Tensor:$nu,
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def BesselKX : EnzymeXLA_Op<"special.besselkx", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Scaled modified Bessel function of the second kind of order nu at z";

let arguments = (ins
HLO_Tensor:$nu,
HLO_Tensor:$z
);

let results = (outs
HLO_Tensor:$res
);
}

def Jinc : EnzymeXLA_Op<"special.jinc", [Pure, SameOperandsAndResultType, Elementwise]> {
let summary = "Jinc function (sombrero/besinc): scaled Bessel function of the first kind divided by x";

let description = [{
Computes the jinc function, also known as the sombrero or besinc function.
It is defined as J1(pi*x) / (2*x) where J1 is the Bessel function of the
first kind of order 1. At x=0, the function evaluates to pi/4.
}];

let arguments = (ins
HLO_Tensor:$x
);

let results = (outs
HLO_Tensor:$res
);
}

// Machine Learning Ops

def GeluOp: EnzymeXLA_Op<"ml.gelu", [Pure, SameOperandsAndResultType, Elementwise]> {
Expand Down
Loading