Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

preprocessor directives followup #89

Merged
merged 8 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 53 additions & 10 deletions yt_idv/scene_components/base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from yt_idv.scene_data.base_data import SceneData
from yt_idv.shader_objects import (
PreprocessorDefinitionState,
ShaderProgram,
ShaderTrait,
component_shaders,
Expand Down Expand Up @@ -58,6 +59,8 @@ class SceneComponent(traitlets.HasTraits):
colormap = traitlets.Instance(ColormapTexture)
_program1 = traitlets.Instance(ShaderProgram, allow_none=True)
_program2 = traitlets.Instance(ShaderProgram, allow_none=True)
_program1_pp_defs = traitlets.Instance(PreprocessorDefinitionState, allow_none=True)
_program2_pp_defs = traitlets.Instance(PreprocessorDefinitionState, allow_none=True)
_program1_invalid = True
_program2_invalid = True
_cmap_bounds_invalid = True
Expand Down Expand Up @@ -144,15 +147,38 @@ def _default_display_name(self):
def _default_render_method(self):
return default_shader_combos[self.name]

@traitlets.default("_program1_pp_defs")
def _default_program1_pp_defs(self):
return PreprocessorDefinitionState()

@traitlets.default("_program2_pp_defs")
def _default_program2_pp_defs(self):
return PreprocessorDefinitionState()

@traitlets.observe("render_method")
def _change_render_method(self, change):
new_combo = component_shaders[self.name][change["new"]]
with self.hold_trait_notifications():
self.vertex_shader = new_combo["first_vertex"]
self.fragment_shader = new_combo["first_fragment"]
self.geometry_shader = new_combo.get("first_geometry", None)
self.colormap_vertex = new_combo["second_vertex"]
self.colormap_fragment = new_combo["second_fragment"]
self.vertex_shader = (
new_combo["first_vertex"],
self._program1_pp_defs["vertex"],
)
self.fragment_shader = (
new_combo["first_fragment"],
self._program1_pp_defs["fragment"],
)
self.geometry_shader = (
new_combo.get("first_geometry", None),
self._program1_pp_defs["geometry"],
)
self.colormap_vertex = (
new_combo["second_vertex"],
self._program2_pp_defs["vertex"],
)
self.colormap_fragment = (
new_combo["second_fragment"],
self._program2_pp_defs["fragment"],
)

@traitlets.observe("render_method")
def _add_initial_isolayer(self, change):
Expand Down Expand Up @@ -191,10 +217,23 @@ def _change_colormap_fragment(self, change):
self._program2_invalid = True

@traitlets.observe("use_db")
def _initialize_db(self, changed):
# invaldiate the colormap when the depth buffer selection changes
def _toggle_depth_buffer(self, changed):
# invalidate the colormap when the depth buffer selection changes
self._cmap_bounds_invalid = True

# update the preprocessor state: USE_DB only present in the second
# program, only update that one.
if changed["new"]:
self._program2_pp_defs.add_definition("fragment", ("USE_DB", ""))
else:
self._program2_pp_defs.clear_definition("fragment", ("USE_DB", ""))

# update the colormap fragment with current render method
current_combo = component_shaders[self.name][self.render_method]
pp_defs = self._program2_pp_defs["fragment"]
self.colormap_fragment = current_combo["second_fragment"], pp_defs
self._recompile_shader()

@traitlets.default("colormap")
def _default_colormap(self):
cm = ColormapTexture()
Expand Down Expand Up @@ -241,7 +280,9 @@ def program1(self):
self._program1.delete_program()
self._fragment_shader_default()
self._program1 = ShaderProgram(
self.vertex_shader, self.fragment_shader, self.geometry_shader
self.vertex_shader,
self.fragment_shader,
self.geometry_shader,
)
self._program1_invalid = False
return self._program1
Expand All @@ -254,7 +295,10 @@ def program2(self):
# The vertex shader will always be the same.
# The fragment shader will change based on whether we are
# colormapping or not.
self._program2 = ShaderProgram(self.colormap_vertex, self.colormap_fragment)
self._program2 = ShaderProgram(
self.colormap_vertex,
self.colormap_fragment,
)
self._program2_invalid = False
return self._program2

Expand Down Expand Up @@ -296,7 +340,6 @@ def run_program(self, scene):
p2._set_uniform("cmap", 0)
p2._set_uniform("fb_tex", 1)
p2._set_uniform("db_tex", 2)
p2._set_uniform("use_db", self.use_db)
# Note that we use cmap_min/cmap_max, not
# self.cmap_min/self.cmap_max.
p2._set_uniform("cmap_min", self.cmap_min)
Expand Down
116 changes: 105 additions & 11 deletions yt_idv/shader_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ctypes
import os
from collections import OrderedDict
from typing import List, Optional, Tuple

import traitlets
import yaml
Expand Down Expand Up @@ -79,28 +80,62 @@ class ShaderProgram:
geometry_shader : string
or :class:`yt_idv.shader_objects.GeometryShader`
The geometry shader used in the pipeline; optional.
preprocessor_defs : PreprocessorDefinitionState
a PreprocessorDefinitionState instance defining any preprocessor
definitions if used; optional.
"""

def __init__(self, vertex_shader=None, fragment_shader=None, geometry_shader=None):
def __init__(
self,
vertex_shader=None,
fragment_shader=None,
geometry_shader=None,
preprocessor_defs=None,
):
# Don't allow just one. Either neither or both.
if vertex_shader is None and fragment_shader is None:
pass
elif None not in (vertex_shader, fragment_shader):
# Geometry is optional
self.link(vertex_shader, fragment_shader, geometry_shader)
self.link(
vertex_shader,
fragment_shader,
geometry_shader,
preprocessor_defs=preprocessor_defs,
)
else:
raise RuntimeError
self._uniform_funcs = OrderedDict()

def link(self, vertex_shader, fragment_shader, geometry_shader=None):
def link(
self,
vertex_shader,
fragment_shader,
geometry_shader=None,
preprocessor_defs=None,
):
if preprocessor_defs is None:
preprocessor_defs = PreprocessorDefinitionState()

# We allow an optional geometry shader, but not tesselation (yet?)
self.program = GL.glCreateProgram()
if not isinstance(vertex_shader, Shader):
vertex_shader = Shader(source=vertex_shader)
vertex_shader = Shader(
source=vertex_shader,
preprocessor_defs=preprocessor_defs.get_shader_defs("vertex"),
)
if not isinstance(fragment_shader, Shader):
fragment_shader = Shader(source=fragment_shader)
fragment_shader = Shader(
source=fragment_shader,
preprocessor_defs=preprocessor_defs.get_shader_defs("fragment"),
)
if geometry_shader is not None and not isinstance(geometry_shader, Shader):
geometry_shader = Shader(source=geometry_shader)
geometry_shader = Shader(
source=geometry_shader,
preprocessor_defs=preprocessor_defs.get_shader_defs("geometry"),
)
self.vertex_shader = vertex_shader
self.fragment_shader = fragment_shader
self.geometry_shader = geometry_shader
Expand Down Expand Up @@ -277,6 +312,7 @@ def _get_source(self, source):
if ";" in source:
# This is probably safe, right? Enh, probably.
return source

# What this does is concatenate multiple (if available) source files.
# This gets around GLSL's composition issues, which means we can have
# functions that get called at each step in a ray tracing process, for
Expand Down Expand Up @@ -371,31 +407,89 @@ def __del__(self):
self.delete_shader()


def _validate_shader(shader_type, value, allow_null=True):
def _validate_shader(shader_type, value, allow_null=True, preprocessor_defs=None):
shader_info = known_shaders[shader_type][value]
shader_info.setdefault("shader_type", shader_type)
shader_info["use_separate_blend"] = bool("blend_func_separate" in shader_info)
shader_info.setdefault("shader_name", value)
shader = Shader(allow_null=allow_null, **shader_info)
return shader
if preprocessor_defs is not None:
shader_info["preprocessor_defs"] = preprocessor_defs
return Shader(allow_null=allow_null, **shader_info)


class ShaderTrait(traitlets.TraitType):
default_value = None
info_text = "A shader (vertex, fragment or geometry)"

def validate(self, obj, value):
if isinstance(value, str):
if isinstance(value, str) or isinstance(value, tuple):
try:
shader_type = self.metadata.get("shader_type", "vertex")
return _validate_shader(shader_type, value)
if isinstance(value, tuple):
preprocessor_defs = value[1]
value = value[0]
else:
preprocessor_defs = None
return _validate_shader(
shader_type, value, preprocessor_defs=preprocessor_defs
)
except KeyError:
self.error(obj, value)
elif isinstance(value, Shader):
return value
self.error(obj, value)


class PreprocessorDefinitionState:

_valid_shader_types = ("vertex", "geometry", "fragment")

def __init__(self):
self.vertex = {}
self.geometry = {}
self.fragment = {}

def _get_dict(self, shader_type: str) -> dict:
"""return the dict of definitions for specifed shader_type"""
return getattr(self, shader_type)

def add_definition(self, shader_type: str, value: Tuple[str, str]):
"""add a definition for specified shader_type, will overwrite
existing definitions.
"""
self._validate_shader_type(shader_type)
self._get_dict(shader_type)[value[0]] = value[1]

def clear_definition(self, shader_type: str, value: Tuple[str, str]):
"""remove the definition of value for specified shader_type"""
self._validate_shader_type(shader_type)
self._get_dict(shader_type).pop(value[0])

def get_shader_defs(self, shader_type: str) -> List[Tuple[str, str]]:
"""return the preprocessor definition list for specified shader_type"""
self._validate_shader_type(shader_type)
return list(self._get_dict(shader_type).items())

def _validate_shader_type(self, shader_type: str):
if shader_type not in self._valid_shader_types:
raise ValueError(
f"shader_type must be one of {self._valid_shader_types}, "
f"but found {shader_type}"
)

def __getitem__(self, item: str) -> List[Tuple[str, str]]:
return self.get_shader_defs(item)

def reset(self, shader_type: Optional[str] = None):
if shader_type is None:
self.vertex = {}
self.geometry = {}
self.fragment = {}
else:
self._validate_shader_type(shader_type)
setattr(self, shader_type, {})


known_shaders = {}
component_shaders = {}
default_shader_combos = {}
Expand Down
10 changes: 5 additions & 5 deletions yt_idv/shaders/apply_colormap.frag.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ out vec4 color;

void main(){
float scaled = 0;
if (use_db) {
scaled = texture(db_tex, UV).x;
} else {
scaled = texture(fb_tex, UV).x;
}
#ifdef USE_DB
scaled = texture(db_tex, UV).x;
#else
scaled = texture(fb_tex, UV).x;
#endif
float alpha = texture(fb_tex, UV).a; // the incoming framebuffer alpha
if (alpha == 0.0) discard;
float cm = cmap_min;
Expand Down
3 changes: 0 additions & 3 deletions yt_idv/shaders/known_uniforms.inc.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ uniform sampler3D ds_tex[6];
// ray tracing control
uniform float sample_factor;

// depth buffer control
uniform bool use_db;

// curve drawing control
uniform vec4 curve_rgba;

Expand Down
29 changes: 29 additions & 0 deletions yt_idv/tests/test_preprocessor_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest

from yt_idv.shader_objects import PreprocessorDefinitionState


def test_preprocessor_definition_state():

pds = PreprocessorDefinitionState()

pds.add_definition("fragment", ("USE_DB", ""))
assert ("USE_DB", "") in pds["fragment"]
pds.add_definition("vertex", ("placeholder", ""))
assert ("placeholder", "") in pds["vertex"]

with pytest.raises(ValueError, match="shader_type must be"):
pds.add_definition("not_a_shader_type", ("any_str", ""))

pds.clear_definition("fragment", ("USE_DB", ""))
assert ("USE_DB", "") not in pds["fragment"]

pds.reset("vertex")
assert len(pds.vertex) == 0

pds.add_definition("fragment", ("USE_DB", ""))
pds.add_definition("geometry", ("placeholder", ""))
pds.add_definition("vertex", ("placeholder", ""))
pds.reset()
for shadertype in pds._valid_shader_types:
assert len(pds._get_dict(shadertype)) == 0
5 changes: 5 additions & 0 deletions yt_idv/tests/test_yt_idv.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ def test_snapshots(osmesa_fake_amr, image_store):
image_store(osmesa_fake_amr)


def test_depth_buffer_toggle(osmesa_fake_amr, image_store):
osmesa_fake_amr.scene.components[0].use_db = True
image_store(osmesa_fake_amr)


def test_slice(osmesa_fake_amr, image_store):
osmesa_fake_amr.scene.components[0].render_method = "slice"
osmesa_fake_amr.scene.components[0].slice_position = (0.5, 0.5, 0.5)
Expand Down
Loading