@@ -700,6 +700,153 @@ def tanh_projection(x: np.ndarray, beta: float, eta: float) -> np.ndarray:
700700 )
701701
702702
703+ def smoothed_projection (
704+ x_smoothed : ArrayLikeType ,
705+ beta : float ,
706+ eta : float ,
707+ resolution : float ,
708+ ):
709+ """Project using subpixel smoothing, which allows for β→∞.
710+
711+ This technique integrates out the discontinuity within the projection
712+ function, allowing the user to smoothly increase β from 0 to ∞ without
713+ losing the gradient. Effectively, a level set is created, and from this
714+ level set, first-order subpixel smoothing is applied to the interfaces (if
715+ any are present).
716+
717+ In order for this to work, the input array must already be smooth (e.g. by
718+ filtering).
719+
720+ While the original approach involves numerical quadrature, this approach
721+ performs a "trick" by assuming that the user is always infinitely projecting
722+ (β=∞). In this case, the expensive quadrature simplifies to an analytic
723+ fill-factor expression. When to use this fill factor requires some careful
724+ logic.
725+
726+ For one, we want to make sure that the user can indeed project at any level
727+ (not just infinity). So in these cases, we simply check if in interface is
728+ within the pixel. If not, we revert to the standard filter plus project
729+ technique.
730+
731+ If there is an interface, we want to make sure the derivative remains
732+ continuous both as the interface leaves the cell, *and* as it crosses the
733+ center. To ensure this, we need to account for the different possibilities.
734+
735+ Args:
736+ x: The (2D) input design parameters.
737+ beta: The thresholding parameter in the range [0, inf]. Determines the
738+ degree of binarization of the output.
739+ eta: The threshold point in the range [0, 1].
740+ resolution: resolution of the design grid (not the Meep grid
741+ resolution).
742+ Returns:
743+ The projected and smoothed output.
744+
745+ Example:
746+ >>> Lx = 2; Ly = 2
747+ >>> resolution = 50
748+ >>> eta_i = 0.5; eta_e = 0.75
749+ >>> lengthscale = 0.1
750+ >>> filter_radius = get_conic_radius_from_eta_e(lengthscale, eta_e)
751+ >>> Nx = onp.round(Lx * resolution) + 1
752+ >>> Ny = onp.round(Ly * resolution) + 1
753+ >>> A = onp.random.rand(Nx, Ny)
754+ >>> beta = npa.inf
755+ >>> A_smoothed = conic_filter(A, filter_radius, Lx, Ly, resolution)
756+ >>> A_projected = smoothed_projection(A_smoothed, beta, eta_i, resolution)
757+ """
758+ # Note that currently, the underlying assumption is that the smoothing
759+ # kernel is a circle, which means dx = dy.
760+ dx = dy = 1 / resolution
761+ pixel_radius = dx / 2
762+
763+ x_projected = tanh_projection (x_smoothed , beta = beta , eta = eta )
764+
765+ # Compute the spatial gradient (using finite differences) of the *filtered*
766+ # field, which will always be smooth and is the key to our approach. This
767+ # gradient essentially represents the normal direction pointing the the
768+ # nearest inteface.
769+ x_grad = npa .gradient (x_smoothed )
770+ x_grad_helper = (x_grad [0 ] / dx ) ** 2 + (x_grad [1 ] / dy ) ** 2
771+
772+ # Note that a uniform field (norm=0) is problematic, because it creates
773+ # divide by zero issues and makes backpropagation difficult, so we sanitize
774+ # and determine where smoothing is actually needed. The value where we don't
775+ # need smoothings doesn't actually matter, since all our computations our
776+ # purely element-wise (no spatial locality) and those pixels will instead
777+ # rely on the standard projection. So just use 1, since it's well behaved.
778+ nonzero_norm = npa .abs (x_grad_helper ) > 0
779+
780+ x_grad_norm = npa .sqrt (npa .where (nonzero_norm , x_grad_helper , 1 ))
781+ x_grad_norm_eff = npa .where (nonzero_norm , x_grad_norm , 1 )
782+
783+ # The distance for the center of the pixel to the nearest interface
784+ d = (eta - x_smoothed ) / x_grad_norm_eff
785+
786+ # Only need smoothing if an interface lies within the voxel. Since d is
787+ # actually an "effective" d by this point, we need to ignore values that may
788+ # have been sanitized earlier on.
789+ needs_smoothing = nonzero_norm & (npa .abs (d ) <= pixel_radius )
790+
791+ # The fill factor is used to perform simple, first-order subpixel smoothing.
792+ # We use the (2D) analytic expression that comes when assuming the smoothing
793+ # kernel is a circle. Note that because the kernel contains some
794+ # expressions that are sensitive to NaNs, we have to use the "double where"
795+ # trick to avoid the Nans in the backward trace. This is a common problem
796+ # with array-based AD tracers, apparently. See here:
797+ # https://github.com/google/jax/issues/1052#issuecomment-5140833520
798+
799+ arccos_term = pixel_radius ** 2 * npa .arccos (
800+ npa .where (
801+ needs_smoothing ,
802+ d / pixel_radius ,
803+ 0.0 ,
804+ )
805+ )
806+
807+ sqrt_term = d * npa .sqrt (
808+ npa .where (
809+ needs_smoothing ,
810+ pixel_radius ** 2 - d ** 2 ,
811+ 1 ,
812+ )
813+ )
814+
815+ fill_factor = npa .where (
816+ needs_smoothing ,
817+ (1 / (npa .pi * pixel_radius ** 2 )) * (arccos_term - sqrt_term ),
818+ 1 ,
819+ )
820+
821+ # Determine the upper and lower bounds of materials in the current pixel.
822+ x_minus = x_smoothed - x_grad_norm * pixel_radius
823+ x_plus = x_smoothed + x_grad_norm * pixel_radius
824+
825+ # Create an "effective" set of materials that will ensure everything is
826+ # piecewise differentiable, even if an interface moves out of a pixel, or
827+ # through the pixel center.
828+ x_minus_eff_pert = (x_smoothed * d + x_minus * (pixel_radius - d )) / pixel_radius
829+ x_minus_eff = npa .where (
830+ (d > 0 ),
831+ x_minus_eff_pert ,
832+ x_minus ,
833+ )
834+ x_plus_eff_pert = (- x_smoothed * d + x_plus * (pixel_radius + d )) / pixel_radius
835+ x_plus_eff = npa .where (
836+ (d > 0 ),
837+ x_plus ,
838+ x_plus_eff_pert ,
839+ )
840+
841+ # Only apply smoothing to interfaces
842+ x_projected_smoothed = (1 - fill_factor ) * x_minus_eff + (fill_factor ) * x_plus_eff
843+ return npa .where (
844+ needs_smoothing ,
845+ x_projected_smoothed ,
846+ x_projected ,
847+ )
848+
849+
703850def heaviside_projection (x , beta , eta ):
704851 """Projection filter that thresholds the input parameters between 0 and 1.
705852
0 commit comments