From 5a389f3c6679f19f905d482085f539880a34889d Mon Sep 17 00:00:00 2001 From: Vaibhav Kumar Dixit Date: Wed, 29 May 2024 22:24:45 -0400 Subject: [PATCH] Move Shapley fix from #167 (#168) * Update shapley_sensitivity.jl * Update shapley_method.jl --- src/shapley_sensitivity.jl | 5 +++++ test/shapley_method.jl | 14 ++++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/shapley_sensitivity.jl b/src/shapley_sensitivity.jl index 9f790e69..b9c09879 100644 --- a/src/shapley_sensitivity.jl +++ b/src/shapley_sensitivity.jl @@ -203,6 +203,11 @@ function gsa(f, method::Shapley, input_distribution::SklarDist; batch = false) sample_complement = rand( Copulas.subsetdims(input_distribution, idx_minus), n_outer) + if size(sample_complement, 2) == 1 + sample_complement = reshape( + sample_complement, (1, length(sample_complement))) + end + for l in 1:n_outer curr_sample = @view sample_complement[:, l] # Sampling of the set conditionally to the complementary element diff --git a/test/shapley_method.jl b/test/shapley_method.jl index 06a382bd..4188baa6 100644 --- a/test/shapley_method.jl +++ b/test/shapley_method.jl @@ -25,24 +25,22 @@ n_perms = -1; n_var = 10_000; n_outer = 1000; n_inner = 3; -dim = 3; -margins = (Uniform(-pi, pi), Uniform(-pi, pi), Uniform(-pi, pi)); +dim = 4; +margins = (Uniform(-pi, pi), Uniform(-pi, pi), Uniform(-pi, pi), Uniform(-pi, pi)); dependency_matrix = Matrix(4 * I, dim, dim); C = GaussianCopula(dependency_matrix); input_distribution = SklarDist(C, margins); - method = Shapley(n_perms = n_perms, n_var = n_var, n_outer = n_outer, n_inner = n_inner); - #---> non batch @time result = gsa(ishi, method, input_distribution, batch = false) @test result.shapley_effects[1]≈0.43813841765976547 atol=1e-1 @test result.shapley_effects[2]≈0.44673952698721386 atol=1e-1 -@test result.shapley_effects[3]≈0.23144736934254417 atol=1e-1 -# @test result.shapley_effects[4]≈0.0 atol=1e-1 +@test result.shapley_effects[3]≈0.11855122481995543 atol=1e-1 +@test result.shapley_effects[4]≈0.0 atol=1e-1 #<---- non batch #---> batch @@ -50,8 +48,8 @@ result = gsa(ishi_batch, method, input_distribution, batch = true); @test result.shapley_effects[1]≈0.44080027198796035 atol=1e-1 @test result.shapley_effects[2]≈0.43029987176805085 atol=1e-1 -@test result.shapley_effects[3]≈0.23144736934254417 atol=1e-1 -# @test result.shapley_effects[4]≈0.0 atol=1e-1 +@test result.shapley_effects[3]≈0.11855122481995543 atol=1e-1 +@test result.shapley_effects[4]≈0.0 atol=1e-1 #<--- batch d = 3