diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 0608a83..8f1c515 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -19,7 +19,6 @@ jobs: matrix: version: - "1" - - "1.6" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/Project.toml b/Project.toml index 604e3d9..b398e32 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxTestUtils" uuid = "ac9de150-d08f-4546-94fb-7472b5760531" authors = ["Avik Pal "] -version = "0.1.13" +version = "0.1.14" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" @@ -25,7 +25,7 @@ ComponentArrays = "0.13, 0.14, 0.15" FiniteDifferences = "0.12" ForwardDiff = "0.10" Functors = "0.4" -JET = "0.4, 0.5, 0.6, 0.7, 0.8" +JET = "0.8" LuxCore = "0.1" LuxDeviceUtils = "0.1" Optimisers = "0.2, 0.3" @@ -33,7 +33,7 @@ Preferences = "1" ReverseDiff = "1" Tracker = "0.2" Zygote = "0.6" -julia = "1.6" +julia = "1.9" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/LuxTestUtils.jl b/src/LuxTestUtils.jl index d4083e1..9a29c1f 100644 --- a/src/LuxTestUtils.jl +++ b/src/LuxTestUtils.jl @@ -10,11 +10,15 @@ const JET_TARGET_MODULES = @load_preference("target_modules", nothing) try using JET global JET_TESTING_ENABLED = true + + import JET: JETTestFailure, get_reports catch @warn "JET not not precompiling. All JET tests will be skipped!!" maxlog=1 global JET_TESTING_ENABLED = false end +import Test: Error, Broken, Pass, Fail, get_testset + """ @jet f(args...) call_broken=false opt_broken=false @@ -56,7 +60,7 @@ end ``` """ macro jet(expr, args...) - @static if VERSION >= v"1.7" && JET_TESTING_ENABLED + if JET_TESTING_ENABLED all_args, call_extras, opt_extras = [], [], [] target_modules_set = false for kwexpr in args @@ -316,19 +320,11 @@ function __test_gradient_pair_check(__source__, orig_expr, v1, v2, name1, name2; end function __test_pass(test_type, orig_expr, source) - @static if VERSION >= v"1.7" - return Test.Pass(test_type, orig_expr, nothing, nothing, source) - else - return Test.Pass(test_type, orig_expr, nothing, nothing) - end + return Test.Pass(test_type, orig_expr, nothing, nothing, source) end function __test_fail(test_type, orig_expr, source) - @static if VERSION >= v"1.9.0-rc1" - return Test.Fail(test_type, orig_expr, nothing, nothing, nothing, source, false) - else - return Test.Fail(test_type, orig_expr, nothing, nothing, source) - end + return Test.Fail(test_type, orig_expr, nothing, nothing, nothing, source, false) end function __test_error(test_type, orig_expr, source)