diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml
index 73b8d7a6..99c86b02 100644
--- a/.JuliaFormatter.toml
+++ b/.JuliaFormatter.toml
@@ -7,7 +7,7 @@ annotate_untyped_fields_with_any = false
 conditional_to_if = true
 format_docstrings = true
 join_lines_based_on_source = true
-long_to_short_function_def = true
+long_to_short_function_def = false
 margin = 100
 separate_kwargs_with_semicolon = true
 short_to_long_function_def = true
diff --git a/src/Fit/Fit.jl b/src/Fit/Fit.jl
index 7dc752ae..aaa8e319 100644
--- a/src/Fit/Fit.jl
+++ b/src/Fit/Fit.jl
@@ -7,12 +7,12 @@ Functions for fitting to data.
 """
 module Fit
 using Compat: @compat
-@compat public fit_spectrum, ews_to_abundances, ews_to_stellar_parameters,
-               ews_to_stellar_parameters_direct
-
 using ..Korg, ForwardDiff
 include("fit_via_synthesis.jl")
 include("fit_via_EWs.jl")
+include("fit_via_multimethod.jl")
 
+@compat public fit_spectrum, ews_to_abundances, ews_to_stellar_parameters,
+               ews_to_stellar_parameters_direct, multimethod_abundances
 
 end # module
diff --git a/src/Fit/fit_via_multimethod.jl b/src/Fit/fit_via_multimethod.jl
new file mode 100644
index 00000000..923a6874
--- /dev/null
+++ b/src/Fit/fit_via_multimethod.jl
@@ -0,0 +1,141 @@
+using ..Korg
+using Statistics: mean, std
+using Interpolations: linear_interpolation
+using Trapz: trapz
+using ProgressMeter
+
+"""
+Estimate per-line abundances of `element` via rough EWs, core depths, and χ² (TODO: document).
+TODO don't hardcode Fe / [m/H]
+TODO make sure solar abundances are flexible. I think they currently are.
+
+# Keyword arguments:
+
+  - `verbose` (default: `true`): whether to print progress information
+
+# TODO kwarg everything?
+"""
+function multimethod_abundances(linelist, element, guess, lines_to_fit, obs_flux, obs_wls, R,
+                                err=ones(length(obs_flux)); verbose=true,
+                                abundance_perturbations=-0.6:0.3:0.6, synth_kwargs...)
+    # TODO check no conflicting R specification
+    # check that lines_to_fit are within the range of obs_wls
+
+    # wavelengths, windows and LSF, TODO refactor
+    windows = [(l - 1, l + 1) for l in lines_to_fit] # TODO kwarg
+    merged_windows, _ = Korg.merge_bounds(windows, 0.0)
+    synthesis_wls, obs_wl_mask, LSF = Korg.Fit._setup_wavelengths_and_LSF(obs_wls,
+                                                                          nothing,
+                                                                          nothing, R,
+                                                                          merged_windows,
+                                                                          0)
+
+    if !issorted(lines_to_fit)
+        lines_to_fit = sort(lines_to_fit)
+    end
+    if lines_to_fit[1] < obs_wls[1] || lines_to_fit[end] > obs_wls[end]
+        throw(ArgumentError("lines_to_fit includes lines outside of obs_wls"))
+    end
+
+    obs_flux = obs_flux[obs_wl_mask]
+    err = err[obs_wl_mask]
+    obs_wls = obs_wls[obs_wl_mask]
+    line_indices = map(lines_to_fit) do λ
+        # this is dumb
+        searchsortedfirst(obs_wls, λ - 0.2):searchsortedlast(obs_wls, λ + 0.2)
+    end
+
+    linelist = Korg.filter_linelist(linelist, synthesis_wls, 10.0; warn_about_time=false)
+    abundances = abundance_perturbations .+ guess
+
+    # would be cool to multithread
+    model_spectra = @showprogress "synthesizing" enabled=verbose map(abundances) do X_H
+        params = Dict([
+            :wavelengths => synthesis_wls,
+            :linelist => linelist,
+            Symbol(element) => X_H,
+            synth_kwargs...
+        ])
+        flux = synth(; params...)[2]
+        LSF * flux
+    end
+    model_spectra = hcat(model_spectra...) # shape (n_wls, n_models)
+
+    # TODO better
+    pseudocontinuum_mask = abs.(model_spectra[:, end] - model_spectra[:, 1]) .< 0.01 # TODO kwarg
+    model_pseudocontinua = [calculate_pseudocontinuum(obs_wls, m, pseudocontinuum_mask)
+                            for m in eachcol(model_spectra)]
+    model_pseudocontinua = hcat(model_pseudocontinua...)
+    obs_pseudocontinuum = calculate_pseudocontinuum(obs_wls, obs_flux, pseudocontinuum_mask)
+
+    EWs = map(eachcol(model_spectra), eachcol(model_pseudocontinua)) do m, p
+        calculate_rough_EWs(obs_wls, m, p, line_indices)
+    end
+    EWs = hcat(EWs...)
+    obs_EWs = calculate_rough_EWs(obs_wls, obs_flux, obs_pseudocontinuum, line_indices)
+    EW_itps = interpolate_and_predict(abundances, obs_EWs, EWs)
+    EW_As = [itp(EW) for (itp, EW) in zip(EW_itps, obs_EWs)]
+
+    zscores = ((obs_flux .- model_spectra) ./ err) .^ 2
+    chi2 = Matrix{Float64}(undef, length(line_indices), length(abundances))
+    for (i, r) in enumerate(line_indices)
+        chi2[i, :] .= sum(zscores[r, :]; dims=1)[:]
+    end
+    chi2_itps, chi2_coeffs = quadratic_minimizers(abundances, chi2)
+    chi2_As = -chi2_coeffs[2, :] ./ 2chi2_coeffs[3, :]
+
+    depths = [calculate_line_core_depths(m, pc, line_indices)
+              for (m, pc) in zip(eachcol(model_spectra), eachcol(model_pseudocontinua))]
+    depths = hcat(depths...)
+    obs_depths = calculate_line_core_depths(obs_flux, obs_pseudocontinuum, line_indices)
+    depth_itps = interpolate_and_predict(abundances, obs_depths, depths)
+    depth_As = [itp(d) for (itp, d) in zip(depth_itps, obs_depths)]
+
+    (; EW_As, depth_As, chi2_As, obs_wl_mask, model_spectra, pseudocontinuum_mask,
+     obs_pseudocontinuum, EWs, obs_EWs, EW_itps, line_indices, model_pseudocontinua, chi2,
+     chi2_coeffs, chi2_itps, depths, obs_depths, depth_itps)
+end
+
+"""
+Fit a quadratic model to chi2(abundance_perturbation), and report the minimizer
+"""
+function quadratic_minimizers(abundance_perturbations, chi2)
+    # fit a quadratic model
+    A = hcat((abundance_perturbations .^ i for i in 0:2)...)
+    coeffs = A \ chi2'
+    map(eachcol(coeffs)) do c
+        x -> sum(c .* x .^ (0:2))
+    end, coeffs
+end
+
+function interpolate_and_predict(abundance_perturbations, observed_quantities, model_quantities)
+    map(eachrow(model_quantities)) do quantities
+        A = hcat((quantities .^ i for i in 0:3)...)
+        coeffs = A \ abundance_perturbations
+        x -> sum(coeffs .* x .^ (0:3))
+    end
+end
+
+function calculate_pseudocontinuum(obs_wls, flux, pseudocontinuum_mask)
+    linear_interpolation(obs_wls[pseudocontinuum_mask], flux[pseudocontinuum_mask];
+                         extrapolation_bc=1.0)(obs_wls)
+end
+
+function calculate_rough_EWs(obs_wls, spectrum, pseudocontinuum, line_indices)
+    absorption = pseudocontinuum - spectrum
+    map(line_indices) do r
+        trapz(obs_wls[r], absorption[r]) * 1e3 # convert to mÅ
+    end
+end
+
+function calculate_line_core_depths(spectrum, pseudocontinuum, line_indices)
+    map(line_indices) do r
+        # get the middle 5 pixels in the range
+        # TODO everything about this is terrible
+        firstind = r[1] + length(r) ÷ 2 - 2
+        lastind = firstind + 4
+        firstind = max(1, firstind)
+        lastind = min(length(spectrum), lastind)
+        mean(spectrum[firstind:lastind] - pseudocontinuum[firstind:lastind])
+    end
+end
diff --git a/src/synthesize.jl b/src/synthesize.jl
index c620c05a..133a1bfc 100644
--- a/src/synthesize.jl
+++ b/src/synthesize.jl
@@ -314,21 +314,43 @@ end
     filter_linelist(linelist, wls, line_buffer)
 
 Return a new linelist containing only lines within the provided wavelength ranges.
+
+# Arguments:
+
+  - `linelist`: The input linelist to filter.
+  - `wls`: The wavelengths to filter the linelist by. Type: [`Korg.Wavelengths`](@ref)
+  - `line_buffer`: The buffer around each wavelength range to include in the filtered linelist. Values greater than 1 are interpreted as Å and converted to cm. (Default: 10 Å)
+
+# Keyword Arguments:
+
+  - `warn_empty` (default: `true`): Whether to warn if the filtered linelist is empty.
+  - `warn_about_time` (default: `true`): Whether to warn if the filtering takes more than 0.2 s.
 """
-function filter_linelist(linelist, wls, line_buffer; warn_empty=true)
+function filter_linelist(linelist, wls, line_buffer::Real=10.0;
+                         warn_empty=true, warn_about_time=true)
+    if line_buffer > 1 # it's in Å
+        line_buffer *= 1e-8
+    end
     nlines_before = length(linelist)
 
     last_line_index = 0 # we need to keep track of this across iterations to avoid double-counting lines.
-    sub_ranges = map(eachwindow(wls)) do (λstart, λstop)
-        first_line_index = searchsortedfirst(linelist, (; wl=λstart - line_buffer); by=l -> l.wl)
-        # ensure we don't double-count lines.
-        first_line_index = max(first_line_index, last_line_index + 1)
+    t = @elapsed begin
+        sub_ranges = map(eachwindow(wls)) do (λstart, λstop)
+            first_line_index = searchsortedfirst(linelist, (; wl=λstart - line_buffer);
+                                                 by=l -> l.wl)
+            # ensure we don't double-count lines.
+            first_line_index = max(first_line_index, last_line_index + 1)
+
+            last_line_index = searchsortedlast(linelist, (; wl=λstop + line_buffer); by=l -> l.wl)
 
-        last_line_index = searchsortedlast(linelist, (; wl=λstop + line_buffer); by=l -> l.wl)
+            first_line_index:last_line_index
+        end
+        linelist = vcat((linelist[r] for r in sub_ranges)...)
+    end
 
-        first_line_index:last_line_index
+    if t > 0.2 && warn_about_time
+        @warn "Filtering the linelist to the requested wavelength range took $t seconds. Consider prefiltering the linelist with Korg.filter_linelist to avoid duplicating this work many times."
     end
-    linelist = vcat((linelist[r] for r in sub_ranges)...)
 
     if nlines_before != 0 && length(linelist) == 0 && warn_empty
         @warn "The provided linelist was not empty, but none of the lines were within the provided wavelength range."
diff --git a/src/utils.jl b/src/utils.jl
index b85e9ce1..b4129057 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,4 +1,5 @@
 using Statistics: quantile
+using Core: Argument
 using Interpolations: linear_interpolation, Flat
 using SparseArrays: spzeros
 
@@ -40,7 +41,12 @@ end
 
 # Convert R to a value based on its type
 # used in `_lsf_bounds_and_kernel`
-_resolve_R(R::Real, λ0) = R
+function _resolve_R(R::Real, λ0)
+    if !isfinite(R)
+        throw(ArgumentError("R (resolving power) must be a finite number (is $R)"))
+    end
+    R
+end
 _resolve_R(R::Function, λ0) = R(λ0 * 1e8) # R is a function of λ in Å
 
 # Core LSF calculation shared by all variants
diff --git a/test/fit.jl b/test/fit.jl
index ece588a6..527474d7 100644
--- a/test/fit.jl
+++ b/test/fit.jl
@@ -398,4 +398,29 @@ using Random
             @test_throws m Korg.Fit.ews_to_stellar_parameters(linelist, erroneous_EWs)
         end
     end
+
+    @testset "MultiMethod" begin
+        # synthesize a spectrum and get the answer back.
+        # Use synthesize to make this a better integration test,
+        # because multimethod_abundances is based on synth
+
+        linelist = Korg.get_VALD_solar_linelist()
+        A_X = format_A_X(-0.3, -0.2, Dict("N" => -0.1))
+        atm = interpolate_marcs(5777, 4.44, A_X)
+        sol = synthesize(atm, linelist, A_X, 5000, 5500)
+
+        fe_lines = Korg.air_to_vacuum.([5044.211 5054.642 5127.359 5127.679 5198.711 5225.525 5242.491 5247.05 5250.208 5295.312 5322.041 5373.709 5379.574 5386.334 5466.396 5466.987][:])
+
+        obs_wl = sol.wavelengths
+        obs_flux = Korg.apply_LSF(sol.flux ./ sol.cntm, sol.wavelengths, 100_000)
+
+        out = Korg.Fit.multimethod_abundances(linelist, "m_H", 0.0, fe_lines, obs_flux, obs_wl,
+                                              100_000;
+                                              alpha_H=-0.2, N=-0.1,
+                                              Teff=5777, logg=4.44)
+
+        @test all(isapprox.(out.EW_As, -0.3; atol=0.001))
+        @test all(isapprox.(out.depth_As, -0.3; atol=0.001))
+        @test all(isapprox.(out.chi2_As, -0.3; atol=0.001))
+    end
 end