@@ -242,7 +242,7 @@ struct SCookTorrance
242242 template<typename C=bool_constant<IsBSDF> >
243243 enable_if_t<C::value && IsBSDF, sample_type> generate (NBL_CONST_REF_ARG (anisotropic_interaction_type) interaction, const vector3_type u, NBL_REF_ARG (anisocache_type) cache)
244244 {
245- const vector3_type localV = hlsl:: normalize ( interaction.getTangentSpaceV () );
245+ const vector3_type localV = interaction.getTangentSpaceV ();
246246 const scalar_type NdotV = localV.z;
247247
248248 fresnel_type _f = impl::getOrientedFresnel<fresnel_type, IsBSDF>::__call (fresnel, NdotV);
@@ -270,90 +270,37 @@ struct SCookTorrance
270270 bool transmitted = math::partitionRandVariable (reflectance, z, rcpChoiceProb);
271271
272272 ray_dir_info_type V = interaction.getV ();
273- const vector3_type H = hlsl::normalize (hlsl::mul (interaction.getFromTangentSpace (), localH));
274-
275- // TODO: UNDER CONSTRUCTION, will uncomment when sure basic stuff passes tests
276- // Refract<scalar_type> r = Refract<scalar_type>::create(V.getDirection(), H);
277- // const scalar_type LdotH = hlsl::mix(VdotH, r.getNdotT(rcpEta.value2[0]), transmitted);
278-
279- // // fail if samples have invalid paths
280- // const scalar_type viewShortenFactor = hlsl::mix(scalar_type(1.0), rcpEta.value[0], transmitted);
281- // const scalar_type NdotL = localH.z * (VdotH * viewShortenFactor + hlsl::abs(LdotH)) - NdotV * viewShortenFactor;
282- // // VNDF sampling guarantees that `VdotH` has same sign as `NdotV`
283- // // and `transmitted` controls the sign of `LdotH` relative to `VdotH` by construction (reflect -> same sign, or refract -> opposite sign)
284- // if (ComputeMicrofacetNormal<scalar_type>::isTransmissionPath(NdotV, NdotL) != transmitted)
285- // return sample_type::createInvalid(); // should check if sample direction is invalid
286-
287- // cache = anisocache_type::createPartial(VdotH, LdotH, localH.z, transmitted, rcpEta);
288- // assert(cache.isValid(_f.getRefractionOrientedEta()));
289-
290- // struct reflect_refract_wrapper // so we don't recalculate LdotH
291- // {
292- // vector3_type operator()(const bool doRefract, const scalar_type rcpOrientedEta) NBL_CONST_MEMBER_FUNC
293- // {
294- // return rr(NdotTorR, rcpOrientedEta);
295- // }
296- // bxdf::ReflectRefract<scalar_type> rr;
297- // scalar_type NdotTorR;
298- // };
299- // bxdf::ReflectRefract<scalar_type> rr; // rr.getNdotTorR() and calls to mix as well as a good part of the computations should CSE with our computation of NdotL above
300- // rr.refract = r;
301- // reflect_refract_wrapper rrw;
302- // rrw.rr = rr;
303- // rrw.NdotTorR = LdotH;
304- // ray_dir_info_type L = V.template reflectRefract<reflect_refract_wrapper>(rrw, transmitted, rcpEta.value[0]);
305-
306- ray_dir_info_type L;
307- if (transmitted)
308- {
309- // scalar_type eta = rcpEta.value[0]; // refraction takes eta as ior_incoming/ior_transmitted due to snell's law
310- // vector3_type orientedH = ieee754::flipSignIfRHSNegative<vector3_type>(H, hlsl::promote<vector3_type>(NdotV));
311- // scalar_type cosThetaI = hlsl::dot(V.getDirection(), orientedH);
312- // scalar_type sin2ThetaI = hlsl::max(scalar_type(0), scalar_type(1) - cosThetaI * cosThetaI);
313- // scalar_type sin2ThetaT = eta * eta * sin2ThetaI;
314-
315- // if (sin2ThetaT >= 1) return sample_type::createInvalid();
316- // scalar_type cosThetaT = hlsl::sqrt(scalar_type(1) - sin2ThetaT);
317- // L.direction = eta * -V.getDirection() + (eta * cosThetaI - cosThetaT) * orientedH;
318-
319-
320- Refract<scalar_type> r = Refract<scalar_type>::create (V.getDirection (), H);
321- bxdf::ReflectRefract<scalar_type> rr;
322- rr.refract = r;
323- L = V.reflectRefract (rr, transmitted, rcpEta.value[0 ]);
324- L.direction = hlsl::normalize (L.direction);
325- }
326- else
327- {
328- bxdf::Reflect<scalar_type> r = bxdf::Reflect<scalar_type>::create (V.getDirection (), H);
329- L = V.reflect (r);
330- L.direction = hlsl::normalize (L.direction);
331- }
332-
333- vector3_type _N = hlsl::normalize (interaction.getN ());
334- scalar_type NdotL = hlsl::dot (_N, L.getDirection ());
273+ const vector3_type H = hlsl::mul (interaction.getFromTangentSpace (), localH);
274+ assert (hlsl::abs (hlsl::length (H) - scalar_type (1.0 )) < scalar_type (1e-4 ));
275+ Refract<scalar_type> r = Refract<scalar_type>::create (V.getDirection (), H);
276+ const scalar_type LdotH = hlsl::mix (VdotH, r.getNdotT (rcpEta.value2[0 ]), transmitted);
277+
278+ // fail if samples have invalid paths
279+ const scalar_type NdotL = hlsl::mix (scalar_type (2.0 ) * VdotH * localH.z - NdotV,
280+ localH.z * (VdotH * rcpEta.value[0 ] + LdotH) - NdotV * rcpEta.value[0 ], transmitted);
281+ // VNDF sampling guarantees that `VdotH` has same sign as `NdotV`
282+ // and `transmitted` controls the sign of `LdotH` relative to `VdotH` by construction (reflect -> same sign, or refract -> opposite sign)
335283 if (ComputeMicrofacetNormal<scalar_type>::isTransmissionPath (NdotV, NdotL) != transmitted)
336284 return sample_type::createInvalid (); // should check if sample direction is invalid
337285
338- cache.iso_cache.VdotH = VdotH;
339- cache.iso_cache.LdotH = hlsl::dot (L.getDirection (), H);
340- // cache.iso_cache.VdotL = hlsl::dot(V.getDirection(), L.getDirection());
341- cache.iso_cache.VdotL = hlsl::mix (scalar_type (2.0 ) * VdotH * VdotH - scalar_type (1.0 ),
342- VdotH * (VdotH * rcpEta.value[0 ] + cache.iso_cache.LdotH) - rcpEta.value[0 ], transmitted);
343- // const scalar_type viewShortenFactor = hlsl::mix(scalar_type(1.0), rcpEta.value[0], transmitted);
344- // scalar_type _VdotL = VdotH * (VdotH * viewShortenFactor + cache.iso_cache.LdotH) - viewShortenFactor;
345- // scalar_type VdotL = hlsl::dot(V.getDirection(), L.getDirection());
346- // assert(hlsl::abs(VdotL - cache.iso_cache.VdotL) < 1e-4);
347- assert (localH.z > scalar_type (0.0 ));
348- cache.iso_cache.absNdotH = hlsl::abs (localH.z);
349- cache.iso_cache.NdotH2 = localH.z * localH.z;
350-
286+ cache = anisocache_type::createPartial (VdotH, LdotH, localH.z, transmitted, rcpEta);
351287 assert (cache.isValid (_f.getRefractionOrientedEta ()));
352288
353- // const scalar_type _viewShortenFactor = hlsl::mix(scalar_type(1.0), rcpEta.value[0], transmitted);
354- // const scalar_type _NdotL = localH.z * (VdotH * _viewShortenFactor + cache.iso_cache.LdotH) - NdotV * _viewShortenFactor;
355- scalar_type _NdotL = hlsl::dot (_N, L.getDirection ());
356- assert (hlsl::abs (_NdotL - NdotL) < 1e-4 );
289+ struct reflect_refract_wrapper // so we don't recalculate LdotH
290+ {
291+ vector3_type operator ()(const bool doRefract, const scalar_type rcpOrientedEta) NBL_CONST_MEMBER_FUNC
292+ {
293+ return rr (NdotTorR, rcpOrientedEta);
294+ }
295+ bxdf::ReflectRefract<scalar_type> rr;
296+ scalar_type NdotTorR;
297+ };
298+ bxdf::ReflectRefract<scalar_type> rr; // rr.getNdotTorR() and calls to mix as well as a good part of the computations should CSE with our computation of NdotL above
299+ rr.refract = r;
300+ reflect_refract_wrapper rrw;
301+ rrw.rr = rr;
302+ rrw.NdotTorR = LdotH;
303+ ray_dir_info_type L = V.template reflectRefract<reflect_refract_wrapper>(rrw, transmitted, rcpEta.value[0 ]);
357304
358305 const vector3_type T = interaction.getT ();
359306 const vector3_type B = interaction.getB ();
0 commit comments