Mutual Information Sensitivity #176

Merged
merged 36 commits on Aug 12, 2024
Commits (36)
e2d4616
Add KS Rank sensitivity method
max-de-rooij Jul 15, 2024
183cf4a
Delete settings.json
max-de-rooij Jul 15, 2024
9bb24d8
Update ks_rank_sensitivity.jl
max-de-rooij Jul 16, 2024
e92c304
Update tests
max-de-rooij Jul 16, 2024
dfca2bf
Update tests and docs
max-de-rooij Jul 17, 2024
b45b446
Formatting
max-de-rooij Jul 17, 2024
5f7325e
Update docs
max-de-rooij Jul 17, 2024
cf641ae
Update ks_rank_sensitivity.jl
max-de-rooij Jul 18, 2024
4d02cae
Create docs page for KS Rank method
max-de-rooij Jul 19, 2024
4f2b6f1
Merge branch 'SciML:master' into master
max-de-rooij Jul 19, 2024
62cd867
Update ks_rank_sensitivity.jl
max-de-rooij Jul 19, 2024
d42e834
Update ks_rank_sensitivity.jl
max-de-rooij Jul 19, 2024
222808b
Update ks_rank_sensitivity.jl
max-de-rooij Jul 22, 2024
3b4b552
Rename ks_rank to rsa and fix docs
max-de-rooij Jul 22, 2024
555ccf5
Change name of docs page
max-de-rooij Jul 22, 2024
ec82e5e
Bump lower bound of StatsBase to 0.33.7
max-de-rooij Jul 24, 2024
4017c46
Merge pull request #1 from max-de-rooij/compat
max-de-rooij Jul 24, 2024
a5bc6b0
Make tests faster and modify default
max-de-rooij Jul 25, 2024
cb8c3be
format
max-de-rooij Jul 25, 2024
4c54aaf
Merge branch 'master' of https://github.com/max-de-rooij/GlobalSensit…
max-de-rooij Jul 30, 2024
727d6b8
Add mutual information method
max-de-rooij Jul 30, 2024
8a992a4
Update mutual_information_method.jl
max-de-rooij Jul 30, 2024
24c1cfa
Update dependency to ComplexityMeasures
max-de-rooij Aug 1, 2024
d9847e6
Update src/mutual_information_sensitivity.jl
max-de-rooij Aug 1, 2024
5284d5e
Update mutual_information_sensitivity.jl
max-de-rooij Aug 1, 2024
3c276ca
Merge branch 'master' of https://github.com/max-de-rooij/GlobalSensit…
max-de-rooij Aug 1, 2024
e4d6ffe
Merge branch 'SciML:master' into master
max-de-rooij Aug 2, 2024
8769b16
Update mutual_information_sensitivity.jl
max-de-rooij Aug 2, 2024
1b6e633
Update mutual_information_sensitivity.jl
max-de-rooij Aug 2, 2024
25deb61
Bump version to 2.7
max-de-rooij Aug 7, 2024
0d87dc3
Simplify and remove higher orders
max-de-rooij Aug 8, 2024
ed8dd55
Update mutual_information_method.jl
max-de-rooij Aug 8, 2024
7e74fd9
Update mutual_information_method.jl
max-de-rooij Aug 8, 2024
dd9c4fc
Apply formatter
max-de-rooij Aug 8, 2024
9e4e7e1
Update Project.toml
max-de-rooij Aug 9, 2024
bda6801
Update mutual_information_sensitivity.jl
max-de-rooij Aug 9, 2024
2 changes: 2 additions & 0 deletions Project.toml
@@ -5,6 +5,7 @@ version = "2.6.2"

[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
ComplexityMeasures = "ab4b797d-85ee-42ba-b621-05d793b346a2"
Copulas = "ae264745-0b69-425e-9d9d-cf662c5eec93"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
@@ -23,6 +24,7 @@ Trapz = "592b5752-818d-11e9-1e9a-2b8ca4a44cd1"
[compat]
Aqua = "0.8"
Combinatorics = "1"
ComplexityMeasures = "3.6"
Copulas = "0.1.22"
Distributions = "0.25.87"
FFTW = "1.3"
3 changes: 2 additions & 1 deletion docs/pages.jl
@@ -12,5 +12,6 @@ pages = [
"methods/easi.md",
"methods/fractional.md",
"methods/rbdfast.md",
"methods/rsa.md"]
"methods/rsa.md",
"methods/mutualinformation.md"]
]
5 changes: 5 additions & 0 deletions docs/src/methods/mutualinformation.md
@@ -0,0 +1,5 @@
# Mutual Information Method

```@docs
MutualInformation
```
4 changes: 3 additions & 1 deletion src/GlobalSensitivity.jl
@@ -5,6 +5,7 @@ using QuasiMonteCarlo, ForwardDiff, KernelDensity, Trapz
using Parameters: @unpack
using FFTW, Distributions, StatsBase
using Copulas, Combinatorics, ThreadsX
using ComplexityMeasures: entropy, ValueHistogram, StateSpaceSet

abstract type GSAMethod end

@@ -19,6 +20,7 @@ include("rbd-fast_sensitivity.jl")
include("fractional_factorial_sensitivity.jl")
include("shapley_sensitivity.jl")
include("rsa_sensitivity.jl")
include("mutual_information_sensitivity.jl")

"""
gsa(f, method::GSAMethod, param_range; samples, batch=false)
@@ -61,7 +63,7 @@ function gsa(f, method::GSAMethod, param_range; samples, batch = false) end
export gsa

export Sobol, Morris, RegressionGSA, DGSM, eFAST, DeltaMoment, EASI, FractionalFactorial,
RBDFAST, Shapley, RSA
RBDFAST, Shapley, RSA, MutualInformation
# Added for shapley_sensitivity

end # module
272 changes: 272 additions & 0 deletions src/mutual_information_sensitivity.jl
@@ -0,0 +1,272 @@
@doc raw"""

MutualInformation(; order = [0, 1], nboot = 1, conf_level = 0.95, n_bin_configurations = 800, n_samples_per_configuration = 100)

- `order`: A vector of integers specifying which sensitivity indices to calculate. Default is `[0, 1]`. The vector may contain
`0` for total order indices, `1` for first order indices, and `2` for second order indices.
- `nboot`: Number of bootstraps to be used for confidence interval estimation. Default is `1`.
- `conf_level`: Confidence level for the confidence interval estimation. Default is `0.95`.
- `n_bin_configurations`: Number of random bin configurations used to estimate the discretization entropy. Default is `800`.
- `n_samples_per_configuration`: Number of model evaluations per bin configuration, used to estimate the discretization entropy.
Default is `100`.


## Method Details

Mutual information sensitivity analysis is an information-theoretic alternative to variance-based methods: the output uncertainty
is quantified by the entropy of the output distribution rather than by its variance. The Shannon entropy of the output is
given by:

```math
H(Y) = -\sum_y p(y) \log p(y)
```
where ``p(y)`` is the probability mass function of the (discretized) output ``Y``. By fixing an input ``X_i``, the conditional entropy of the output ``Y`` is given by:

```math
H(Y|X_i) = \sum_{x} p(x) H(Y|X_i = x)
```

The mutual information between the input ``X_i`` and the output ``Y`` is then given by:

```math
I(X_i;Y) = H(Y) - H(Y|X_i) = H(X_i) + H(Y) - H(X_i,Y)
```

where ``H(X_i,Y)`` is the joint entropy of the input and output. These mutual information values are normalized to obtain the sensitivity indices of the input parameters.

### First Order Sensitivity Indices
The first order sensitivity indices are calculated as the mutual information between the input ``X_i`` and the output ``Y`` divided by the entropy of the output ``Y``:

```math
S_{1,i} = \frac{I(X_i;Y)}{H(Y)}
```

This measure is introduced in Lüdtke et al. (2007)[^1] and also present in Datseris & Parlitz (2022)[^2] in an unnormalized form.
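As a quick numerical illustration of the two formulas above, here is a minimal NumPy sketch, independent of this package; the function names and the `sqrt(n)` bin rule are illustrative choices, not the package API:

```python
import numpy as np

def entropy_1d(x, bins):
    # plug-in Shannon entropy of an equal-width histogram, in nats
    counts, _ = np.histogram(x, bins=bins)
    p = counts[counts > 0] / counts.sum()
    return -np.sum(p * np.log(p))

def entropy_2d(x, y, bins):
    # joint entropy H(X, Y) from a 2-D histogram
    counts, _, _ = np.histogram2d(x, y, bins=bins)
    p = counts[counts > 0] / counts.sum()
    return -np.sum(p * np.log(p))

def first_order_index(x, y):
    # S1 = I(X; Y) / H(Y), with I(X; Y) = H(X) + H(Y) - H(X, Y)
    bins = int(round(np.sqrt(len(y))))  # sqrt(n) bins as a rule of thumb
    h_y = entropy_1d(y, bins)
    mi = entropy_1d(x, bins) + h_y - entropy_2d(x, y, bins)
    return mi / h_y

rng = np.random.default_rng(0)
x = rng.uniform(-np.pi, np.pi, 10_000)
y_dependent = np.sin(x)                                   # Y fully driven by X
y_independent = np.sin(rng.uniform(-np.pi, np.pi, 10_000))  # unrelated to x
```

For a deterministic relation such as `y = sin(x)` the estimate should be close to 1, while for an unrelated input it stays near the small positive bias of the plug-in estimator.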

### Second Order Sensitivity Indices
To account for interactions between input parameters and their joint effect on the output, second order sensitivity indices can be calculated from the
conditional mutual information between two input parameters ``X_i`` and ``X_j`` given the output ``Y``, corrected for the mutual information between the
input parameters themselves. The second order sensitivity indices according to Lüdtke et al. (2007)[^1] are given by:

```math
S_{2,ij} = \frac{I(X_i;X_j|Y) - I(X_i;X_j)}{H(Y)}
```
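The same plug-in entropy estimates yield the second order index. A hedged NumPy sketch of the formula above (hypothetical names, not the package implementation):

```python
import numpy as np

def entropy_nd(*cols, bins):
    # plug-in joint entropy over one or more 1-D columns
    counts, _ = np.histogramdd(np.column_stack(cols), bins=bins)
    p = counts[counts > 0] / counts.sum()
    return -np.sum(p * np.log(p))

def second_order_index(xi, xj, y, bins):
    # I(Xi; Xj | Y) = H(Xi, Y) + H(Xj, Y) - H(Xi, Xj, Y) - H(Y)
    cmi = (entropy_nd(xi, y, bins=bins) + entropy_nd(xj, y, bins=bins)
           - entropy_nd(xi, xj, y, bins=bins) - entropy_nd(y, bins=bins))
    # I(Xi; Xj) = H(Xi) + H(Xj) - H(Xi, Xj)
    mi = (entropy_nd(xi, bins=bins) + entropy_nd(xj, bins=bins)
          - entropy_nd(xi, xj, bins=bins))
    return (cmi - mi) / entropy_nd(y, bins=bins)
```

For a purely interactive output such as `y = xi * xj` with independent inputs, `I(Xi; Xj)` is near zero while `I(Xi; Xj | Y)` is large, so the index comes out positive.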

### Total Order Sensitivity Indices
From Lüdtke et al. (2007)[^1], the total order sensitivity indices can be calculated as:

```math
S_{\text{total},i} = \frac{H(Y) - H(Y \mid \{X_1, \ldots, X_n\} \setminus X_i)}{H(Y) - H_{\Delta}}
```

where ``H_{\Delta}`` is the discretization entropy of the output, introduced to account for the discretization of the input space.
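The baseline ``H_{\Delta}`` can be estimated by confining every input to one randomly chosen bin, evaluating the model inside that cell, and averaging the resulting output entropies. A hedged NumPy sketch of this idea (all names and defaults are illustrative, and the package derives its bin edges from the sampled design, which may differ in detail):

```python
import numpy as np

def discretization_entropy(f, lb, ub, n_bins, n_configs=200, n_samples=100, seed=0):
    rng = np.random.default_rng(seed)
    lb, ub = np.asarray(lb, float), np.asarray(ub, float)
    span = (ub - lb) / n_bins                   # bin width per input dimension
    total = 0.0
    for _ in range(n_configs):
        config = rng.integers(0, n_bins, size=lb.size)  # one bin per input
        lo = lb + span * config
        hi = lo + span
        x = rng.uniform(lo, hi, size=(n_samples, lb.size))  # sample inside the cell
        y = np.apply_along_axis(f, 1, x)
        counts, _ = np.histogram(y, bins=int(round(np.sqrt(n_samples))))
        p = counts[counts > 0] / counts.sum()
        total += -np.sum(p * np.log(p))         # plug-in entropy of this cell's output
    return total / n_configs
```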

## API

gsa(f, method::MutualInformation, p_range; samples, batch = false)

Returns a `MutualInformationResult` object containing the resulting sensitivity indices for the parameters and the corresponding confidence intervals.
The `MutualInformationResult` object contains the following fields:
- `S1`: First order sensitivity indices.
- `S1_Conf_Int`: Confidence intervals for the first order sensitivity indices.
- `S2`: Second order sensitivity indices.
- `S2_Conf_Int`: Confidence intervals for the second order sensitivity indices.
- `ST`: Total order sensitivity indices.
- `ST_Conf_Int`: Confidence intervals for the total order sensitivity indices.

For fields that are not calculated, the corresponding field in the result will be an array of zeros.

### Example

```julia
using GlobalSensitivity

function ishi_batch(X)
    A = 7
    B = 0.1
    @. sin(X[1, :]) + A * sin(X[2, :])^2 + B * X[3, :]^4 * sin(X[1, :])
end
function ishi(X)
    A = 7
    B = 0.1
    sin(X[1]) + A * sin(X[2])^2 + B * X[3]^4 * sin(X[1])
end

lb = -ones(4) * π
ub = ones(4) * π

res1 = gsa(ishi, MutualInformation(), [[lb[i], ub[i]] for i in 1:4], samples = 15000)
res2 = gsa(ishi_batch, MutualInformation(), [[lb[i], ub[i]] for i in 1:4],
    samples = 15000, batch = true)
```

### References
[^1]: Lüdtke, N., Panzeri, S., Brown, M., Broomhead, D. S., Knowles, J., Montemurro, M. A., & Kell, D. B. (2007). Information-theoretic sensitivity analysis: a general method for credit assignment in complex networks. Journal of The Royal Society Interface, 5(19), 223–235.
[^2]: Datseris, G., & Parlitz, U. (2022). Nonlinear Dynamics: A Concise Introduction Interlaced with Code. Springer, Ch. 7, pp. 105–119.
"""
struct MutualInformation <: GSAMethod
order::Vector{Int}
nboot::Int
conf_level::Real
n_bin_configurations::Int
n_samples_per_configuration::Int
end

MutualInformation(; order = [0, 1], nboot = 1, conf_level = 0.95, n_bin_configurations = 800, n_samples_per_configuration = 100) = MutualInformation(order, nboot, conf_level, n_bin_configurations, n_samples_per_configuration)

struct MutualInformationResult{T1, T2, T3, T4}
S1::T1
S1_Conf_Int::T2
S2::T3
S2_Conf_Int::T4
ST::T1
ST_Conf_Int::T2
end

function _total_order_ci(Xi, Y, entropy_Y, discretization_entropy; n_bootstraps = 100, conf_level = 0.95)
# perform permutations of Y and calculate total order
mi_values = zeros(n_bootstraps)
est = ValueHistogram(Int(round(sqrt(length(Y)))))
for i in 1:n_bootstraps
Y_perm = Y[randperm(length(Y))]

conditional_entropy = entropy(est, StateSpaceSet(Xi, Y_perm)) - entropy(est, Xi)
mi_values[i] = (entropy_Y - conditional_entropy) / (entropy_Y - discretization_entropy)
end

α = 1 - conf_level

return quantile(mi_values, [α/2, 1 - α/2])
end

function _first_order_ci(Xi, Y; n_bootstraps = 100, conf_level = 0.95)

# perform permutations of Y and calculate mutual information
mi_values = zeros(n_bootstraps)
est = ValueHistogram(Int(round(sqrt(length(Y)))))
for i in 1:n_bootstraps
Y_perm = Y[randperm(length(Y))]
mutual_information = entropy(est, Xi) + entropy(est, Y_perm) - entropy(est, StateSpaceSet(Xi, Y_perm))
mi_values[i] = mutual_information / entropy(est, Y_perm)
end

α = 1 - conf_level

return quantile(mi_values, [α/2, 1 - α/2])
end

function _second_order_ci(Xi, Xj, Y; n_bootstraps = 100, conf_level = 0.95)

# perform permutations of Y and calculate second order mutual information
mi_values = zeros(n_bootstraps)
est = ValueHistogram(Int(round(sqrt(length(Y)))))
for i in 1:n_bootstraps
Y_perm = Y[randperm(length(Y))]
conditional_mutual_information = entropy(est, StateSpaceSet(Xi, Y_perm)) + entropy(est, StateSpaceSet(Xj, Y_perm)) - entropy(est, StateSpaceSet(Xi, Xj, Y_perm)) - entropy(est, Y_perm)
mutual_information = entropy(est, Xi) + entropy(est, Xj) - entropy(est, StateSpaceSet(Xi, Xj))
mi_values[i] = (conditional_mutual_information - mutual_information) / entropy(est, Y_perm)
end

α = 1 - conf_level


return quantile(mi_values, [α/2, 1 - α/2])
end

function _discretization_entropy(X::AbstractArray, f, batch; n_bin_configurations = 800, n_samples_per_configuration = 100)
n_bins = Int(round(sqrt(size(X, 2))))
n_dims = size(X, 1)

span = (maximum(X, dims = 2) .- minimum(X, dims = 2)) ./ n_bins

entropy_Y = 0.0
for _ in 1:n_bin_configurations
config = rand(1:n_bins, n_dims)
bin_edges = [span .* (config .- 1) span .* config]
if batch
samples_Y = f(hcat([rand.(Uniform.(bin_edges[:,1], bin_edges[:,2])) for _ in 1:n_samples_per_configuration]...))
else
samples_Y = [f(rand.(Uniform.(bin_edges[:,1], bin_edges[:,2]))) for _ in 1:n_samples_per_configuration]
end
entropy_Y += entropy(ValueHistogram(Int(round(sqrt(length(samples_Y))))), samples_Y)
end

return entropy_Y / n_bin_configurations
end

function _compute_mi(X::AbstractArray, f, batch::Bool, method::MutualInformation)

discretization_entropy = _discretization_entropy(X, f, batch; n_bin_configurations = method.n_bin_configurations, n_samples_per_configuration = method.n_samples_per_configuration)
if batch
Y = f(X)
multioutput = Y isa AbstractMatrix
else
Y = [f(X[:, j]) for j in axes(X, 2)]
multioutput = !(eltype(Y) <: Number)
if eltype(Y) <: RecursiveArrayTools.AbstractVectorOfArray
y_size = size(Y[1])
Y = vec.(Y)
desol = true
end
end

# K is the number of variables, samples is the number of simulations
K = size(X, 1)

if method.nboot > size(X, 2)
throw(ArgumentError("Number of bootstraps must be less than or equal to the number of samples"))
end
est = ValueHistogram(Int(round(sqrt(size(Y, 1)))))
entropy_Y = entropy(est, Y)

total_order = zeros(K)
total_order_ci = zeros(K, 2)

first_order = zeros(K)
first_order_ci = zeros(K, 2)

second_order = zeros(K, K)
second_order_ci = zeros(K, K, 2)

# calculate total order
if 0 in method.order
@inbounds for i in 1:K
Xi = @view X[i, :]
conditional_entropy = entropy(est, StateSpaceSet(Xi, Y)) - entropy(est, Xi)
total_order[i] = (entropy_Y - conditional_entropy) / (entropy_Y - discretization_entropy)
total_order_ci[i, :] = _total_order_ci(Xi, Y, entropy_Y, discretization_entropy, n_bootstraps = method.nboot, conf_level = method.conf_level)
end
end

if 1 in method.order
# calculate mutual information
@inbounds for i in 1:K
Xi = @view X[i, :]
mutual_information = entropy(est, Xi) + entropy_Y - entropy(est, StateSpaceSet(Xi, Y))
first_order[i] = mutual_information / entropy_Y
first_order_ci[i, :] .= _first_order_ci(Xi, Y, n_bootstraps = method.nboot, conf_level = method.conf_level)
end
end

if 2 in method.order
for (i, j) in combinations(1:K, 2)
Xi = @view X[i, :]
Xj = @view X[j, :]
conditional_mutual_information = entropy(est, StateSpaceSet(Xi, Y)) + entropy(est, StateSpaceSet(Xj, Y)) - entropy(est, StateSpaceSet(Xi, Xj, Y)) - entropy_Y
mutual_information = entropy(est, Xi) + entropy(est, Xj) - entropy(est, StateSpaceSet(Xi, Xj))
second_order[i,j] = second_order[j,i] = (conditional_mutual_information - mutual_information) / entropy_Y
second_order_ci[i,j,:] = second_order_ci[j,i,:] = _second_order_ci(Xi, Xj, Y, n_bootstraps = method.nboot, conf_level = method.conf_level)
end
end

total_order, total_order_ci, first_order, first_order_ci, second_order, second_order_ci
end

function gsa(f, method::MutualInformation, p_range; samples, batch = false)
lb = [i[1] for i in p_range]
ub = [i[2] for i in p_range]

X = QuasiMonteCarlo.sample(samples, lb, ub, QuasiMonteCarlo.SobolSample())
total_order, total_order_ci, first_order, first_order_ci, second_order, second_order_ci = _compute_mi(X, f, batch, method)

return MutualInformationResult(first_order, first_order_ci, second_order, second_order_ci, total_order, total_order_ci)
end
67 changes: 67 additions & 0 deletions test/mutual_information_method.jl
@@ -0,0 +1,67 @@
using GlobalSensitivity, Test, QuasiMonteCarlo

function ishi_batch(X)
A = 7
B = 0.1
@. sin(X[1, :]) + A * sin(X[2, :])^2 + B * X[3, :]^4 * sin(X[1, :])
end
function ishi(X)
A = 7
B = 0.1
sin(X[1]) + A * sin(X[2])^2 + B * X[3]^4 * sin(X[1])
end

function linear_batch(X)
A = 7
B = 0.1
@. A * X[1, :] + B * X[2, :]
end
function linear(X)
A = 7
B = 0.1
A * X[1] + B * X[2]
end

lb = -ones(4) * π
ub = ones(4) * π

res1 = gsa(
ishi, MutualInformation(order=[0,1,2]), [[lb[i], ub[i]] for i in 1:4], samples = 10_000)

res2 = gsa(
ishi_batch, MutualInformation(order=[0,1,2]), [[lb[i], ub[i]] for i in 1:4],
samples = 10_000, batch = true)


@test res1.S1 ≈ [0.1416, 0.1929, 0.1204, 0.0925] atol = 1e-3
@test all(0.09 .<= res1.S1_Conf_Int[:, 1] .<= 0.1)
@test res2.S1 ≈ [0.1416, 0.1929, 0.1204, 0.0925] atol = 1e-3
@test all(0.09 .<= res2.S1_Conf_Int[:, 1] .<= 0.1)

@test sortperm(res1.ST) == [4,3,1,2]
@test sortperm(res2.ST) == [4,3,1,2]

@test res1.S2 ≈ [
0.0 0.576849 0.656412 0.681677
0.576849 0.0 0.609111 0.615966
0.656412 0.609111 0.0 0.661516
0.681677 0.615966 0.661516 0.0
] atol = 1e-2

@test res2.S2 ≈ [
0.0 0.576849 0.656412 0.681677
0.576849 0.0 0.609111 0.615966
0.656412 0.609111 0.0 0.661516
0.681677 0.615966 0.661516 0.0
] atol = 1e-2

res1 = gsa(
linear, MutualInformation(), [[lb[i], ub[i]] for i in 1:4], samples = 10_000)
res2 = gsa(
linear_batch, MutualInformation(), [[lb[i], ub[i]] for i in 1:4], batch = true,
samples = 10_000)

@test res1.S1 ≈ [0.8155, 0.08997, 0.09096, 0.09747] atol = 1e-3
@test res2.S1 ≈ [0.8155, 0.08997, 0.09096, 0.09747] atol = 1e-3
