Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ annotate_untyped_fields_with_any = false
conditional_to_if = true
format_docstrings = true
join_lines_based_on_source = true
long_to_short_function_def = true
long_to_short_function_def = false
margin = 100
separate_kwargs_with_semicolon = true
short_to_long_function_def = true
Expand Down
6 changes: 3 additions & 3 deletions src/Fit/Fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ Functions for fitting to data.
"""
module Fit
using Compat: @compat
@compat public fit_spectrum, ews_to_abundances, ews_to_stellar_parameters,
ews_to_stellar_parameters_direct

using ..Korg, ForwardDiff

include("fit_via_synthesis.jl")
include("fit_via_EWs.jl")
include("fit_via_multimethod.jl")

@compat public fit_spectrum, ews_to_abundances, ews_to_stellar_parameters,
ews_to_stellar_parameters_direct, multimethod_abundances
end # module
141 changes: 141 additions & 0 deletions src/Fit/fit_via_multimethod.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
using ..Korg
using Statistics: mean, std
using Interpolations: linear_interpolation
using Trapz: trapz
using ProgressMeter

"""
TODO
TODO don't hardcode Fe / [m/H]
TODO make sure solar abundances are flexible. I think they currently are.

# Keyword arguments:

- `verbose` (default: `true`: whether to print progress information

# TODO kwarg everything?
"""
function multimethod_abundances(linelist, element, guess, lines_to_fit, obs_flux, obs_wls, R,
err=ones(length(obs_flux)); verbose=true,
abundance_perturbations=-0.6:0.3:0.6, synth_kwargs...)
# TODO check no conflicting R specification
# check that lines_to_fit are within the range of obs_wls

# wavelengths, windows and LSF, TODO refactor
windows = [(l - 1, l + 1) for l in lines_to_fit] # TODO kwarg
merged_windows, _ = Korg.merge_bounds(windows, 0.0)
synthesis_wls, obs_wl_mask, LSF = Korg.Fit._setup_wavelengths_and_LSF(obs_wls,
nothing,
nothing, R,
merged_windows,
0)

if !issorted(lines_to_fit)
lines_to_fit = sort(lines_to_fit)
end
if lines_to_fit[1] < obs_wls[1] || lines_to_fit[end] > obs_wls[end]
throw(ArgumentError("lines_to_fit includes lines outside of obs_wls"))
end

obs_flux = obs_flux[obs_wl_mask]
err = err[obs_wl_mask]
obs_wls = obs_wls[obs_wl_mask]
line_indices = map(lines_to_fit) do λ
# this is dumb
searchsortedfirst(obs_wls, λ - 0.2):searchsortedlast(obs_wls, λ + 0.2)
end

linelist = Korg.filter_linelist(linelist, synthesis_wls, 10.0; warn_about_time=false)
abundances = abundance_perturbations .+ guess

# would be cool to multithread
model_spectra = @showprogress "synthesizing" ennabled=verbose map(abundances) do X_H
params = Dict([
:wavelengths => synthesis_wls,
:linelist => linelist,
Symbol(element) => X_H,
synth_kwargs...
])
flux = synth(; params...)[2]
LSF * flux
end
model_spectra = hcat(model_spectra...) # shape (n_wls, n_models)

# TODO better
pseudocontinuum_mask = abs.(model_spectra[:, end] - model_spectra[:, 1]) .< 0.01 # TODO kwarg
model_pseudocontinua = [calculate_pseudocontinuum(obs_wls, m, pseudocontinuum_mask)
for m in eachcol(model_spectra)]
model_pseudocontinua = hcat(model_pseudocontinua...)
obs_pseudocontinuum = calculate_pseudocontinuum(obs_wls, obs_flux, pseudocontinuum_mask)

EWs = map(eachcol(model_spectra), eachcol(model_pseudocontinua)) do m, p
calculate_rough_EWs(obs_wls, m, p, line_indices)
end
EWs = hcat(EWs...)
obs_EWs = calculate_rough_EWs(obs_wls, obs_flux, obs_pseudocontinuum, line_indices)
EW_itps = interpolate_and_predict(abundances, obs_EWs, EWs)
EW_As = [itp(EW) for (itp, EW) in zip(EW_itps, obs_EWs)]

zscores = ((obs_flux .- model_spectra) ./ err) .^ 2
chi2 = Matrix{Float64}(undef, length(line_indices), length(abundances))
for (i, r) in enumerate(line_indices)
chi2[i, :] .= sum(zscores[r, :]; dims=1)[:]
end
chi2_itps, chi2_coeffs = quadratic_minimizers(abundances, chi2)
chi2_As = -chi2_coeffs[2, :] ./ 2chi2_coeffs[3, :]

depths = [calculate_line_core_depths(m, pc, line_indices)
for (m, pc) in zip(eachcol(model_spectra), eachcol(model_pseudocontinua))]
depths = hcat(depths...)
obs_depths = calculate_line_core_depths(obs_flux, obs_pseudocontinuum, line_indices)
depth_itps = interpolate_and_predict(abundances, obs_depths, depths)
depth_As = [itp.(d) for (itp, d) in zip(depth_itps, obs_depths)]

(; EW_As, depth_As, chi2_As, obs_wl_mask, model_spectra, pseudocontinuum_mask,
obs_pseudocontinuum, EWs, obs_EWs, EW_itps, line_indices, model_pseudocontinua, chi2,
chi2_coeffs, chi2_itps, depths, obs_depths, depth_itps)
end

"""
Fit a quadratic model to chi2(abundance_perturbation), and report the minimizer
"""
function quadratic_minimizers(abundance_perturbations, chi2)
# fit a quadratic model
A = hcat((abundance_perturbations .^ i for i in 0:2)...)
coeffs = A \ chi2'
map(eachcol(coeffs)) do c
x -> sum(c .* x .^ (0:2))
end, coeffs
end

function interpolate_and_predict(abundance_perturbations, observed_quantities, model_quantities)
map(eachrow(model_quantities)) do quantities
A = hcat((quantities .^ i for i in 0:3)...)
coeffs = A \ abundance_perturbations
x -> sum(coeffs .* x .^ (0:3))
end
end

function calculate_pseudocontinuum(obs_wls, flux, pseudocontinuum_mask)
linear_interpolation(obs_wls[pseudocontinuum_mask], flux[pseudocontinuum_mask];
extrapolation_bc=1.0)(obs_wls)
end

function calculate_rough_EWs(obs_wls, spectrum, pseudocontinuum, line_indices)
absorption = pseudocontinuum - spectrum
map(line_indices) do r
trapz(obs_wls[r], absorption[r]) * 1e3 # convert to mÅ
end
end

function calculate_line_core_depths(spectrum, pseudocontinuum, line_indices)
map(line_indices) do r
# get the middle 5 pixels in the range
# TODO everthing about this is terrible
firstind = r[1] + length(r) ÷ 2 - 2
lastind = firstind + 4
firstind = max(1, firstind)
lastind = min(length(spectrum), lastind)
mean(spectrum[firstind:lastind] - pseudocontinuum[firstind:lastind])
end
end
38 changes: 30 additions & 8 deletions src/synthesize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,21 +314,43 @@ end
filter_linelist(linelist, wls, line_buffer)

Return a new linelist containing only lines within the provided wavelength ranges.

# Arguments:

- `linelist`: The input linelist to filter.
- `wls`: The wavelengths to filter the linelist by. Type: [`Korg.Wavelengths`](@ref)
- `line_buffer`: The buffer around each wavelength range to include in the filtered linelist. (Default: 10 Å)

# Keyword Arguments:

- `warn_empty`: Whether to warn if the filtered linelist is empty.
- `warn_about_time`: Whether to warn if the filtering process takes a long time.
"""
function filter_linelist(linelist, wls, line_buffer; warn_empty=true)
function filter_linelist(linelist, wls, line_buffer::AbstractFloat=10.0;
warn_empty=true, warn_about_time=true)
if line_buffer > 1 # it's in Å
line_buffer *= 1e-8
end
nlines_before = length(linelist)

last_line_index = 0 # we need to keep track of this across iterations to avoid double-counting lines.
sub_ranges = map(eachwindow(wls)) do (λstart, λstop)
first_line_index = searchsortedfirst(linelist, (; wl=λstart - line_buffer); by=l -> l.wl)
# ensure we don't double-count lines.
first_line_index = max(first_line_index, last_line_index + 1)
t = @elapsed begin
sub_ranges = map(eachwindow(wls)) do (λstart, λstop)
first_line_index = searchsortedfirst(linelist, (; wl=λstart - line_buffer);
by=l -> l.wl)
# ensure we don't double-count lines.
first_line_index = max(first_line_index, last_line_index + 1)

last_line_index = searchsortedlast(linelist, (; wl=λstop + line_buffer); by=l -> l.wl)

last_line_index = searchsortedlast(linelist, (; wl=λstop + line_buffer); by=l -> l.wl)
first_line_index:last_line_index
end
linelist = vcat((linelist[r] for r in sub_ranges)...)
end

first_line_index:last_line_index
if t > 0.2 && warn_about_time
@warn "Filtering the linelist to the requested wavelength range took $t seconds. Consider prefiltering the linelist with Korg.filter_linelist to avoid duplicating this work many times."
end
linelist = vcat((linelist[r] for r in sub_ranges)...)

if nlines_before != 0 && length(linelist) == 0 && warn_empty
@warn "The provided linelist was not empty, but none of the lines were within the provided wavelength range."
Expand Down
8 changes: 7 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Statistics: quantile
using Core: Argument
using Interpolations: linear_interpolation, Flat
using SparseArrays: spzeros

Expand Down Expand Up @@ -40,7 +41,12 @@ end

# Convert R to a value based on its type
# used in `_lsf_bounds_and_kernel`
_resolve_R(R::Real, λ0) = R
function _resolve_R(R::Real, λ0)
if !isfinite(R)
throw(ArgumentError("R (resolving power) must be a finite number (is $R)"))
end
R
end
_resolve_R(R::Function, λ0) = R(λ0 * 1e8) # R is a function of λ in Å

# Core LSF calculation shared by all variants
Expand Down
25 changes: 25 additions & 0 deletions test/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -398,4 +398,29 @@ using Random
@test_throws m Korg.Fit.ews_to_stellar_parameters(linelist, erroneous_EWs)
end
end

@testset "MultiMethod" begin
# synthesize a spectrum and get the answer back.
# Use synthesize to make this a better integration test,
# because multi_method_abundaces is based on synth

linelist = Korg.get_VALD_solar_linelist()
A_X = format_A_X(-0.3, -0.2, Dict("N" => -0.1))
atm = interpolate_marcs(5777, 4.44, A_X)
sol = synthesize(atm, linelist, A_X, 5000, 5500)

fe_lines = Korg.air_to_vacuum.([5044.211 5054.642 5127.359 5127.679 5198.711 5225.525 5242.491 5247.05 5250.208 5295.312 5322.041 5373.709 5379.574 5386.334 5466.396 5466.987][:])

obs_wl = sol.wavelengths
obs_flux = Korg.apply_LSF(sol.flux ./ sol.cntm, sol.wavelengths, 100_000)

out = Korg.Fit.multimethod_abundances(linelist, "m_H", 0.0, fe_lines, obs_flux, obs_wl,
100_000;
abundance_perturbations, alpha_H=-0.2, N=-0.1,
Teff=5777, logg=4.44)

@test all(out.EW_As ≈ -0.3) atol=0.001
@test all(out.depth_As ≈ -0.3) atol=0.001
@test all(out.chi2_As ≈ -0.3) atol=0.001
end
end
Loading