Skip to content

Commit

Permalink
Update ml_dtypes version and path.
Browse files Browse the repository at this point in the history
The include paths for headers within the ml_dtypes package have changed.
We therefore need to adjust the TF/XLA build rules and paths to account
for this.  Also updated the pip ml_dtypes version to match.

The main ml_dtypes repo name needed to be changed to avoid
conflicts with the ml_dtypes subfolder.  The subfolder contains the main
python package that needs to be added to the PYTHON_PATH.

PiperOrigin-RevId: 723654395
  • Loading branch information
cantonios authored and copybara-github committed Feb 5, 2025
1 parent 62f5dee commit fb28268
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 240 deletions.
1 change: 0 additions & 1 deletion opensource_only.files
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ third_party/py/BUILD.tpl:
third_party/py/BUILD:
third_party/py/manylinux_compliance_test.py:
third_party/py/ml_dtypes/BUILD:
third_party/py/ml_dtypes/LICENSE:
third_party/py/numpy/BUILD:
third_party/py/py_import.bzl:
third_party/py/py_manylinux_compliance_test.bzl:
Expand Down
202 changes: 0 additions & 202 deletions third_party/py/ml_dtypes/LICENSE

This file was deleted.

28 changes: 3 additions & 25 deletions third_party/py/ml_dtypes/ml_dtypes.BUILD
Original file line number Diff line number Diff line change
@@ -1,48 +1,26 @@
""" Main ml_dtypes library. """

load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

package(
default_visibility = ["//visibility:public"],
licenses = ["notice"],
)

exports_files(["LICENSE"])

cc_library(
name = "float8",
hdrs = ["include/float8.h"],
include_prefix = "ml_dtypes",
# Internal headers are all relative to . but other packages
# include these headers with the prefix.
includes = [
".",
"ml_dtypes",
],
deps = ["@eigen_archive//:eigen3"],
)

cc_library(
name = "intn",
hdrs = ["include/intn.h"],
include_prefix = "ml_dtypes",
# Internal headers are all relative to . but other packages
# include these headers with the prefix.
includes = [
".",
"ml_dtypes",
],
)

cc_library(
name = "mxfloat",
hdrs = ["include/mxfloat.h"],
include_prefix = "ml_dtypes",
# Internal headers are all relative to . but other packages
# include these headers with the prefix.
includes = [
".",
"float8",
"ml_dtypes",
],
deps = [
":float8",
"@eigen_archive//:eigen3",
Expand All @@ -60,7 +38,6 @@ pybind_extension(
"_src/numpy.h",
"_src/ufuncs.h",
],
includes = ["ml_dtypes"],
visibility = [":__subpackages__"],
deps = [
":float8",
Expand All @@ -78,5 +55,6 @@ py_library(
"_finfo.py",
"_iinfo.py",
],
imports = ["."], # Import relative to _this_ directory, not the root.
deps = [":_ml_dtypes_ext"],
)
8 changes: 8 additions & 0 deletions third_party/py/ml_dtypes/ml_dtypes_py.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
""" Root ml_dtypes_py package. """

package(
default_visibility = ["//visibility:public"],
licenses = ["notice"],
)

exports_files(["LICENSE"])
12 changes: 6 additions & 6 deletions third_party/py/ml_dtypes/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ float8 varieties, and int4.
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
ML_DTYPES_COMMIT = "0fa5313b65efe848c5968a15dd37dd220cc29567"
ML_DTYPES_SHA256 = "69c562bb961a21d92357c7709430553c226caac75a751c0aa52955ca14ce8641"
ML_DTYPES_COMMIT = "00d98cd92ade342fef589c0470379abb27baebe9"
ML_DTYPES_SHA256 = "f6e5880666661351e6cd084ac4178ddc4dabcde7e9a73722981c0d1500cf5937"
tf_http_archive(
name = "ml_dtypes",
build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD",
name = "ml_dtypes_py",
build_file = "//third_party/py/ml_dtypes:ml_dtypes_py.BUILD",
link_files = {
"//third_party/py/ml_dtypes:LICENSE": "LICENSE",
"//third_party/py/ml_dtypes:ml_dtypes.BUILD": "ml_dtypes/BUILD.bazel",
},
sha256 = ML_DTYPES_SHA256,
strip_prefix = "ml_dtypes-{commit}/ml_dtypes".format(commit = ML_DTYPES_COMMIT),
strip_prefix = "ml_dtypes-{commit}".format(commit = ML_DTYPES_COMMIT),
urls = tf_mirror_urls("https://github.com/jax-ml/ml_dtypes/archive/{commit}/ml_dtypes-{commit}.tar.gz".format(commit = ML_DTYPES_COMMIT)),
)
6 changes: 3 additions & 3 deletions tsl/platform/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -884,9 +884,9 @@ cc_library(
hdrs = ["ml_dtypes.h"],
compatible_with = get_compatible_with_portable(),
deps = [
"@ml_dtypes//:float8",
"@ml_dtypes//:intn",
"@ml_dtypes//:mxfloat",
"@ml_dtypes_py//ml_dtypes:float8",
"@ml_dtypes_py//ml_dtypes:intn",
"@ml_dtypes_py//ml_dtypes:mxfloat",
],
)

Expand Down
6 changes: 3 additions & 3 deletions tsl/platform/ml_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_TSL_PLATFORM_ML_DTYPES_H_
#define TENSORFLOW_TSL_PLATFORM_ML_DTYPES_H_

#include "ml_dtypes/include/float8.h" // from @ml_dtypes
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes
#include "ml_dtypes/include/float8.h" // from @ml_dtypes_py
#include "ml_dtypes/include/intn.h" // from @ml_dtypes_py
#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes_py

namespace tsl {
using float4_e2m1fn = ::ml_dtypes::float4_e2m1fn;
Expand Down

0 comments on commit fb28268

Please sign in to comment.