Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ Bottom level categories:
- Fixed `workgroupUniformLoad` incorrectly returning an atomic when called on an atomic, it now returns the inner `T` as per the spec. By @cryvosh in [#8791](https://github.com/gfx-rs/wgpu/pull/8791).
- Fixed constant evaluation for `sign()` builtin to return zero when the argument is zero. By @mandryskowski in [#8942](https://github.com/gfx-rs/wgpu/pull/8942).
- Allow array generation to compile with the macOS 10.12 Metal compiler. By @madsmtm in [#8953](https://github.com/gfx-rs/wgpu/pull/8953)
- Naga now detects bitwise shifts by a constant exceeding the operand bit width at compile time, and disallows scalar-by-vector and vector-by-scalar shifts in constant evaluation. By @andyleiserson in [#8907](https://github.com/gfx-rs/wgpu/pull/8907).

#### Validation

Expand Down
2 changes: 1 addition & 1 deletion cts_runner/fail.lst
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ webgpu:shader,validation,expression,access,matrix:* // 93%, runtime OOB matrix a
webgpu:shader,validation,expression,access,vector:* // 52%, https://github.com/gfx-rs/wgpu/issues/4390, and missing swizzle validation
webgpu:shader,validation,expression,binary,add_sub_mul:* // 95%, u32 const-eval overflow incorrectly rejected, f16 const-eval overflow not rejected, atomics #5474
webgpu:shader,validation,expression,binary,and_or_xor:* // 96%, https://github.com/gfx-rs/wgpu/issues/5474
webgpu:shader,validation,expression,binary,bitwise_shift:* // 97%, atomics https://github.com/gfx-rs/wgpu/issues/5474, partial eval errors
webgpu:shader,validation,expression,binary,bitwise_shift:invalid_types:* // 93%, atomics #5474
webgpu:shader,validation,expression,binary,comparison:* // 74%, https://github.com/gfx-rs/wgpu/issues/5474
webgpu:shader,validation,expression,binary,div_rem:* // 86%, https://github.com/gfx-rs/wgpu/issues/5474
webgpu:shader,validation,expression,binary,short_circuiting_and_or:* // 92%, https://github.com/gfx-rs/wgpu/issues/8440
Expand Down
6 changes: 6 additions & 0 deletions cts_runner/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,12 @@ webgpu:shader,validation,expression,access,array:early_eval_errors:case="overrid
webgpu:shader,validation,expression,access,structure:*
webgpu:shader,validation,expression,binary,add_sub_mul:scalar_vector_out_of_range:lhs="i32";*
webgpu:shader,validation,expression,binary,add_sub_mul:scalar_vector_out_of_range:lhs="u32";*
webgpu:shader,validation,expression,binary,bitwise_shift:partial_eval_errors:*
webgpu:shader,validation,expression,binary,bitwise_shift:scalar_vector:*
webgpu:shader,validation,expression,binary,bitwise_shift:shift_left_abstract:*
webgpu:shader,validation,expression,binary,bitwise_shift:shift_left_concrete:*
webgpu:shader,validation,expression,binary,bitwise_shift:shift_right_abstract:*
webgpu:shader,validation,expression,binary,bitwise_shift:shift_right_concrete:*
webgpu:shader,validation,expression,binary,parse:*
webgpu:shader,validation,expression,binary,short_circuiting_and_or:array_override:op="%26%26";a_val=1;b_val=1
webgpu:shader,validation,expression,binary,short_circuiting_and_or:invalid_types:*
Expand Down
38 changes: 26 additions & 12 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2753,26 +2753,36 @@ impl<'a> ConstantEvaluator<'a> {
ty,
},
&Expression::Literal(_),
) => {
let mut components = src_components.clone();
for component in &mut components {
*component = self.binary_op(op, *component, right, span)?;
) => match op {
BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
}
Expression::Compose { ty, components }
}
_ => {
let mut components = src_components.clone();
for component in &mut components {
*component = self.binary_op(op, *component, right, span)?;
}
Expression::Compose { ty, components }
}
},
(
&Expression::Literal(_),
&Expression::Compose {
components: ref src_components,
ty,
},
) => {
let mut components = src_components.clone();
for component in &mut components {
*component = self.binary_op(op, left, *component, span)?;
) => match op {
BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
}
Expression::Compose { ty, components }
}
_ => {
let mut components = src_components.clone();
for component in &mut components {
*component = self.binary_op(op, left, *component, span)?;
}
Expression::Compose { ty, components }
}
},
(
&Expression::Compose {
components: ref left_components,
Expand Down Expand Up @@ -3030,6 +3040,10 @@ impl<'a> ConstantEvaluator<'a> {
h
}

/// Resolve the type of `expr` if it is a constant expression.
///
/// If `expr` was evaluated to a constant, returns its type.
/// Otherwise, returns an error.
fn resolve_type(
&self,
expr: Handle<Expression>,
Expand Down
7 changes: 6 additions & 1 deletion naga/src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,12 @@ pub struct GlobalCtx<'a> {
}

impl GlobalCtx<'_> {
/// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`.
/// Try to evaluate the expression in `self.global_expressions` using its `handle`
/// and return it as a `T: TryFrom<ir::Literal>`.
///
/// This currently only evaluates scalar expressions. If adding support for vectors,
/// consider changing `valid::expression::validate_constant_shift_amounts` to use that
/// support.
#[cfg_attr(
not(any(
feature = "glsl-in",
Expand Down
123 changes: 100 additions & 23 deletions naga/src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ pub enum ExpressionError {
UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes),
#[error("Invalid operand for cooperative op")]
InvalidCooperativeOperand(Handle<crate::Expression>),
#[error("Shift amount exceeds the bit width of {lhs_type:?}")]
ShiftAmountTooLarge {
lhs_type: crate::TypeInner,
rhs_expr: Handle<crate::Expression>,
},
}

#[derive(Clone, Debug, thiserror::Error)]
Expand Down Expand Up @@ -243,6 +248,74 @@ impl super::Validator {
Ok(())
}

/// Return an error if a constant shift amount in `right` exceeds the bit
/// width of `left_ty`.
///
/// This function promises to return an error in cases where (1) the
/// expression is well-typed, (2) `left_ty` is a concrete integer, and
/// (3) the shift will overflow. It does not return an error in cases where
/// the expression is not well-typed (e.g. vector dimension mismatch),
/// because those will be rejected elsewhere.
fn validate_constant_shift_amounts(
left_ty: &crate::TypeInner,
right: Handle<crate::Expression>,
module: &crate::Module,
function: &crate::Function,
) -> Result<(), ExpressionError> {
fn is_overflowing_shift(
left_ty: &crate::TypeInner,
right: Handle<crate::Expression>,
module: &crate::Module,
function: &crate::Function,
) -> bool {
let Some((vec_size, scalar)) = left_ty.vector_size_and_scalar() else {
return false;
};
if !matches!(
scalar.kind,
crate::ScalarKind::Sint | crate::ScalarKind::Uint
) {
return false;
}
let lhs_bits = u32::from(8 * scalar.width);
if vec_size.is_none() {
let shift_amount = module
.to_ctx()
.get_const_val_from::<u32, _>(right, &function.expressions);
shift_amount.ok().is_some_and(|s| s >= lhs_bits)
} else {
match function.expressions[right] {
crate::Expression::ZeroValue(_) => false, // zero shift does not overflow
crate::Expression::Splat { value, .. } => module
.to_ctx()
.get_const_val_from::<u32, _>(value, &function.expressions)
.ok()
.is_some_and(|s| s >= lhs_bits),
crate::Expression::Compose {
ty: _,
ref components,
} => components.iter().any(|comp| {
module
.to_ctx()
.get_const_val_from::<u32, _>(*comp, &function.expressions)
.ok()
.is_some_and(|s| s >= lhs_bits)
}),
_ => false,
}
}
}

if is_overflowing_shift(left_ty, right, module, function) {
Err(ExpressionError::ShiftAmountTooLarge {
lhs_type: left_ty.clone(),
rhs_expr: right,
})
} else {
Ok(())
}
}

#[allow(clippy::too_many_arguments)]
pub(super) fn validate_expression(
&self,
Expand Down Expand Up @@ -273,7 +346,7 @@ impl super::Validator {
| Ti::ValuePointer { size: Some(_), .. }
| Ti::BindingArray { .. } => {}
ref other => {
log::error!("Indexing of {other:?}");
log::debug!("Indexing of {other:?}");
return Err(ExpressionError::InvalidBaseType(base));
}
};
Expand All @@ -284,7 +357,7 @@ impl super::Validator {
..
}) => {}
ref other => {
log::error!("Indexing by {other:?}");
log::debug!("Indexing by {other:?}");
return Err(ExpressionError::InvalidIndexType(index));
}
}
Expand Down Expand Up @@ -342,7 +415,7 @@ impl super::Validator {
}
Ti::Struct { ref members, .. } => members.len() as u32,
ref other => {
log::error!("Indexing of {other:?}");
log::debug!("Indexing of {other:?}");
return Err(ExpressionError::InvalidBaseType(top));
}
};
Expand All @@ -358,7 +431,7 @@ impl super::Validator {
E::Splat { size: _, value } => match resolver[value] {
Ti::Scalar { .. } => ShaderStages::all(),
ref other => {
log::error!("Splat scalar type {other:?}");
log::debug!("Splat scalar type {other:?}");
return Err(ExpressionError::InvalidSplatType(value));
}
},
Expand All @@ -370,7 +443,7 @@ impl super::Validator {
let vec_size = match resolver[vector] {
Ti::Vector { size: vec_size, .. } => vec_size,
ref other => {
log::error!("Swizzle vector type {other:?}");
log::debug!("Swizzle vector type {other:?}");
return Err(ExpressionError::InvalidVectorType(vector));
}
};
Expand Down Expand Up @@ -414,7 +487,7 @@ impl super::Validator {
.contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
Ti::ValuePointer { .. } => {}
ref other => {
log::error!("Loading {other:?}");
log::debug!("Loading {other:?}");
return Err(ExpressionError::InvalidPointerType(pointer));
}
}
Expand Down Expand Up @@ -786,7 +859,7 @@ impl super::Validator {
| (Uo::LogicalNot, Some(Sk::Bool))
| (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {}
other => {
log::error!("Op {op:?} kind {other:?}");
log::debug!("Op {op:?} kind {other:?}");
return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
}
}
Expand Down Expand Up @@ -903,7 +976,7 @@ impl super::Validator {
Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
},
ref other => {
log::error!("Op {op:?} left type {other:?}");
log::debug!("Op {op:?} left type {other:?}");
false
}
}
Expand All @@ -915,7 +988,7 @@ impl super::Validator {
..
} => left_inner == right_inner,
ref other => {
log::error!("Op {op:?} left type {other:?}");
log::debug!("Op {op:?} left type {other:?}");
false
}
},
Expand All @@ -925,7 +998,7 @@ impl super::Validator {
Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
},
ref other => {
log::error!("Op {op:?} left type {other:?}");
log::debug!("Op {op:?} left type {other:?}");
false
}
},
Expand All @@ -935,7 +1008,7 @@ impl super::Validator {
Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
},
ref other => {
log::error!("Op {op:?} left type {other:?}");
log::debug!("Op {op:?} left type {other:?}");
false
}
},
Expand All @@ -944,7 +1017,7 @@ impl super::Validator {
Ti::Scalar(scalar) => (Ok(None), scalar),
Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
ref other => {
log::error!("Op {op:?} base type {other:?}");
log::debug!("Op {op:?} base type {other:?}");
(Err(()), Sc::BOOL)
}
};
Expand All @@ -955,7 +1028,7 @@ impl super::Validator {
scalar: Sc { kind: Sk::Uint, .. },
} => Ok(Some(size)),
ref other => {
log::error!("Op {op:?} shift type {other:?}");
log::debug!("Op {op:?} shift type {other:?}");
Err(())
}
};
Expand All @@ -966,12 +1039,12 @@ impl super::Validator {
}
};
if !good {
log::error!(
log::debug!(
"Left: {:?} of type {:?}",
function.expressions[left],
left_inner
);
log::error!(
log::debug!(
"Right: {:?} of type {:?}",
function.expressions[right],
right_inner
Expand All @@ -984,6 +1057,10 @@ impl super::Validator {
rhs_type: right_inner.clone(),
});
}
// For shift operations, check if the constant shift amount exceeds the bit width
if matches!(op, Bo::ShiftLeft | Bo::ShiftRight) {
Self::validate_constant_shift_amounts(left_inner, right, module, function)?;
}
ShaderStages::all()
}
E::Select {
Expand Down Expand Up @@ -1060,15 +1137,15 @@ impl super::Validator {
..
} => {}
ref other => {
log::error!("All/Any of type {other:?}");
log::debug!("All/Any of type {other:?}");
return Err(ExpressionError::InvalidBooleanVector(argument));
}
},
Rf::IsNan | Rf::IsInf => match *argument_inner {
Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
if scalar.kind == Sk::Float => {}
ref other => {
log::error!("Float test of type {other:?}");
log::debug!("Float test of type {other:?}");
return Err(ExpressionError::InvalidFloatArgument(argument));
}
},
Expand Down Expand Up @@ -1206,7 +1283,7 @@ impl super::Validator {
}
}
ref other => {
log::error!("Array length of {other:?}");
log::debug!("Array length of {other:?}");
return Err(ExpressionError::InvalidArrayType(expr));
}
},
Expand All @@ -1221,12 +1298,12 @@ impl super::Validator {
} => match resolver.types[base].inner {
Ti::RayQuery { .. } => ShaderStages::all(),
ref other => {
log::error!("Intersection result of a pointer to {other:?}");
log::debug!("Intersection result of a pointer to {other:?}");
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
ref other => {
log::error!("Intersection result of {other:?}");
log::debug!("Intersection result of {other:?}");
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
Expand All @@ -1242,12 +1319,12 @@ impl super::Validator {
vertex_return: true,
} => ShaderStages::all(),
ref other => {
log::error!("Intersection result of a pointer to {other:?}");
log::debug!("Intersection result of a pointer to {other:?}");
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
ref other => {
log::error!("Intersection result of {other:?}");
log::debug!("Intersection result of {other:?}");
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
Expand All @@ -1272,7 +1349,7 @@ impl super::Validator {
match resolver[operand] {
Ti::CooperativeMatrix { role, .. } if role == expected_role => {}
ref other => {
log::error!("{expected_role:?} operand type: {other:?}");
log::debug!("{expected_role:?} operand type: {other:?}");
return Err(ExpressionError::InvalidCooperativeOperand(a));
}
}
Expand Down
Loading