diff --git a/src/NestedSamplers.jl b/src/NestedSamplers.jl
index 1c631fa8..829619ca 100644
--- a/src/NestedSamplers.jl
+++ b/src/NestedSamplers.jl
@@ -6,6 +6,7 @@ using Random: AbstractRNG, GLOBAL_RNG
 
 using AbstractMCMC
 using AbstractMCMC: AbstractSampler,
+                    LogDensityModel,
                     AbstractModel,
                     samples,
                     save!!
diff --git a/src/staticsampler.jl b/src/staticsampler.jl
index 1c7be396..f9e2384e 100644
--- a/src/staticsampler.jl
+++ b/src/staticsampler.jl
@@ -1,6 +1,7 @@
 # Sampler and model implementations
 
 struct Nested{B, P <: AbstractProposal} <: AbstractSampler
+    type::Type
     ndims::Int
     nactive::Int
     bounds::B
@@ -53,6 +54,7 @@ function Nested(ndims,
     enlarge = 1.25,
     min_ncall=2nactive,
     min_eff=0.10,
+    type = Float64,
     kwargs...)
 
     nactive < 2ndims && @warn "Using fewer than 2ndim ($(2ndims)) active points is discouraged"
@@ -72,7 +74,9 @@ function Nested(ndims,
 
     update_interval_frac = get(kwargs, :update_interval, default_update_interval(proposal, ndims))
     update_interval = round(Int, update_interval_frac * nactive)
-    return Nested(ndims,
+    return Nested(
+        type,
+        ndims,
         nactive,
         bounds,
         enlarge,
diff --git a/src/step.jl b/src/step.jl
index 8041bb17..e0d4c8d7 100644
--- a/src/step.jl
+++ b/src/step.jl
@@ -1,8 +1,43 @@
+#TODO it would be cool to make these parametric 
+struct NestedState
+    it::Int
+    ncall::Int
+    us
+    vs
+    logl
+    logl_dead
+    logz
+    logzerr
+    h
+    logvol
+    since_update::Int
+    has_bounds::Bool
+    active_bound
+end
+
+struct NestedTransition
+    #θ
+    u
+    v
+    logwt
+    logl
+end
 
-function step(rng, model, sampler::Nested; kwargs...)
+function step(
+    rng::Random.AbstractRNG, 
+    model, 
+    sampler::Nested; 
+    init_params = nothing,
+    kwargs...)
     # Initialize particles
     # us are in unit space, vs are in prior space
-    us, vs, logl = init_particles(rng, model, sampler)
+    if init_params != nothing && size(init_params) == (sampler.ndims,sampler.nactive)
+        us = init_params
+    else
+        @warn "Init params have wrong dimensions, sampling new ones!"
+        us = rand(rng, sampler.type, sampler.ndims, sampler.nactive)
+    end
+    us, vs, logl = init_particles(rng, us, model)
 
     # Find least likely point
     logl_dead, idx_dead = findmin(logl)
@@ -31,15 +66,20 @@ function step(rng, model, sampler::Nested; kwargs...)
     logzerr = sqrt(h / sampler.nactive)
     logvol -= 1 / sampler.nactive
 
-    sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead)
-    state = (it = 1, ncall = ncall, us = us, vs = vs, logl = logl, logl_dead = logl_dead,
-             logz = logz, logzerr = logzerr, h = h, logvol = logvol,
-             since_update = since_update, has_bounds = false, active_bound = nothing)
+    sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead) 
+    #sample = NestedTransition(u_dead, v_dead, logwt, logl_dead)
+    state = NestedState(1, ncall, us, vs, logl, logl_dead,
+        logz, logzerr, h, logvol, since_update, false, nothing)
 
     return sample, state
 end
 
-function step(rng, model, sampler, state; kwargs...)
+function step(
+    rng::Random.AbstractRNG, 
+    model, 
+    sampler::Nested, 
+    state::NestedState; 
+    kwargs...)
     ## Update bounds
     pointvol = exp(state.logvol) / sampler.nactive
     # check if ready for first update
@@ -102,10 +142,10 @@ function step(rng, model, sampler, state; kwargs...)
     logvol = state.logvol - 1 / sampler.nactive
 
     ## prepare returns
-    sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead)
-    state = (it = it, ncall = ncall, us = state.us, vs = state.vs, logl = state.logl, logl_dead = logl_dead,
-             logz = logz, logzerr = logzerr, h = h, logvol = logvol,
-             since_update = since_update, has_bounds = has_bounds, active_bound = active_bound)
+    sample = (u = u_dead, v = v_dead, logwt = logwt, logl = logl_dead) 
+    #sample = NestedTransition(u_dead, v_dead, logwt, logl_dead)
+    state = NestedState(it, ncall, state.us, state.vs, state.logl, logl_dead,
+        logz, logzerr, h, logvol, since_update, has_bounds, active_bound)
 
     return sample, state
 end
@@ -113,7 +153,7 @@ end
 function bundle_samples(samples,
         model::AbstractModel,
         sampler::Nested,
-        state,
+        state::NestedState,
         ::Type{Chains};
         add_live=true,
         param_names=missing,
@@ -143,7 +183,7 @@ end
 function bundle_samples(samples,
         model::AbstractModel,
         sampler::Nested,
-        state,
+        state::NestedState,
         ::Type{Array};
         add_live=true,
         check_wsum=true,
@@ -175,8 +215,7 @@ init_particles(rng, model, sampler) =
 
 # loop and fill arrays, checking validity of points
 # will retry 100 times before erroring
-function init_particles(rng, T, ndims, nactive, model)
-    us = rand(rng, T, ndims, nactive)
+function init_particles(rng::Random.AbstractRNG, us::Matrix, model::AbstractModel)
     vs_and_logl = mapslices(
         Base.Fix1(prior_transform_and_loglikelihood, model), us;
         dims=1
@@ -204,6 +243,25 @@ function init_particles(rng, T, ndims, nactive, model)
     return us, vs, logl
 end
 
+function init_particles(rng::Random.AbstractRNG, us::Matrix, model::LogDensityModel)
+    logdensity = model.logdensity
+    dims, nactive = size(us)
+    logls = [logdensity(us[:, i]) for i in 1:nactive]
+
+    ntries = 1
+    while true
+        any(isfinite, logls) && break
+        rand!(rng, us)
+        logls = [logdensity(us[:, i]) for i in 1:nactive]
+        ntries += 1
+        ntries > 100 && error("After 100 attempts, could not initialize any live points with finite loglikelihood. Please check your prior transform and loglikelihood methods.")
+    end
+
+    # force -Inf to be a finite but small number to keep estimators from breaking
+    @. logls[logls == -Inf] = -1e300
+
+    return us, us, logls
+end
 
 # add remaining live points to `samples`
 function add_live_points(samples, model, sampler, state)