diff --git a/CHANGELOG.md b/CHANGELOG.md index bce3907a29..d7f9afc906 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -200,6 +200,7 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148] - Fix crash when a texture argument is missing. By @aedm in [#6486](https://github.com/gfx-rs/wgpu/pull/6486) - Emit an error in constant evaluation, rather than crash, in certain cases where `vecN` constructors have less than N arguments. By @ErichDonGubler in [#6508](https://github.com/gfx-rs/wgpu/pull/6508). +- Fix a leak by ensuring that types that depend on expressions are correctly compacted. By @KentSlaney in [#6806](https://github.com/gfx-rs/wgpu/pull/6806). #### D3D12 diff --git a/naga/src/arena/unique_arena.rs b/naga/src/arena/unique_arena.rs index c64bb302eb..db4348aa5a 100644 --- a/naga/src/arena/unique_arena.rs +++ b/naga/src/arena/unique_arena.rs @@ -23,7 +23,8 @@ use std::{fmt, hash, ops}; /// /// [`Arena`]: super::Arena #[derive(Clone)] -pub struct UniqueArena { +#[cfg_attr(test, derive(PartialEq))] +pub struct UniqueArena { set: FastIndexSet, /// Spans for the elements, indexed by handle. @@ -35,7 +36,7 @@ pub struct UniqueArena { span_info: Vec, } -impl UniqueArena { +impl UniqueArena { /// Create a new arena with no initial capacity allocated. pub fn new() -> Self { UniqueArena { @@ -182,7 +183,7 @@ impl UniqueArena { } } -impl Default for UniqueArena { +impl Default for UniqueArena { fn default() -> Self { Self::new() } @@ -194,7 +195,7 @@ impl fmt::Debug for UniqueArena { } } -impl ops::Index> for UniqueArena { +impl ops::Index> for UniqueArena { type Output = T; fn index(&self, handle: Handle) -> &T { &self.set[handle.index()] diff --git a/naga/src/block.rs b/naga/src/block.rs index 2e86a928f1..f2d2d027b3 100644 --- a/naga/src/block.rs +++ b/naga/src/block.rs @@ -6,6 +6,7 @@ use std::ops::{Deref, DerefMut, RangeBounds}; #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "serialize", serde(transparent))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub struct Block { body: Vec, #[cfg_attr(feature = "serialize", serde(skip))] diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index 7e611c35fe..b04476838f 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -1,7 +1,9 @@ use super::{HandleMap, HandleSet, ModuleMap}; -use crate::arena::{Arena, Handle}; +use crate::arena::{Arena, Handle, UniqueArena}; +use crate::compact::types::TypeTracer; pub struct ExpressionTracer<'tracer> { + pub types: Option<&'tracer UniqueArena>, pub constants: &'tracer Arena, /// The arena in which we are currently tracing expressions. @@ -28,6 +30,26 @@ pub struct ExpressionTracer<'tracer> { } impl ExpressionTracer<'_> { + fn types_used_insert(&mut self, x: Handle) { + if self.types.is_some() { + self.trace_type(x); + } + self.types_used.insert(x); + } + + fn trace_type(&mut self, x: Handle) { + fn handle2type(x: &mut TypeTracer, y: Handle) { + x.types_used.insert(y); + x.trace_type(&x.types[y], handle2type); + } + TypeTracer { + types: self.types.unwrap(), + types_used: self.types_used, + expressions_used: self.expressions_used, + } + .trace_type(&self.types.unwrap()[x], handle2type) + } + /// Propagate usage through `self.expressions`, starting with `self.expressions_used`. /// /// Treat `self.expressions_used` as the initial set of "known @@ -62,168 +84,171 @@ impl ExpressionTracer<'_> { } log::trace!("tracing new expression {:?}", expr); + self.trace_expression(expr); + } + } - use crate::Expression as Ex; - match *expr { - // Expressions that do not contain handles that need to be traced. - Ex::Literal(_) - | Ex::FunctionArgument(_) - | Ex::GlobalVariable(_) - | Ex::LocalVariable(_) - | Ex::CallResult(_) - | Ex::SubgroupBallotResult - | Ex::RayQueryProceedResult => {} + pub fn trace_expression(&mut self, expr: &crate::Expression) { + use crate::Expression as Ex; + match *expr { + // Expressions that do not contain handles that need to be traced. + Ex::Literal(_) + | Ex::FunctionArgument(_) + | Ex::GlobalVariable(_) + | Ex::LocalVariable(_) + | Ex::CallResult(_) + | Ex::SubgroupBallotResult + | Ex::RayQueryProceedResult => {} - Ex::Constant(handle) => { - self.constants_used.insert(handle); - // Constants and expressions are mutually recursive, which - // complicates our nice one-pass algorithm. However, since - // constants don't refer to each other, we can get around - // this by looking *through* each constant and marking its - // initializer as used. Since `expr` refers to the constant, - // and the constant refers to the initializer, it must - // precede `expr` in the arena. - let init = self.constants[handle].init; - match self.global_expressions_used { - Some(ref mut used) => used.insert(init), - None => self.expressions_used.insert(init), - }; - } - Ex::Override(_) => { - // All overrides are considered used by definition. We mark - // their types and initialization expressions as used in - // `compact::compact`, so we have no more work to do here. - } - Ex::ZeroValue(ty) => { - self.types_used.insert(ty); - } - Ex::Compose { ty, ref components } => { - self.types_used.insert(ty); - self.expressions_used - .insert_iter(components.iter().cloned()); - } - Ex::Access { base, index } => self.expressions_used.insert_iter([base, index]), - Ex::AccessIndex { base, index: _ } => { - self.expressions_used.insert(base); - } - Ex::Splat { size: _, value } => { - self.expressions_used.insert(value); - } - Ex::Swizzle { - size: _, - vector, - pattern: _, - } => { - self.expressions_used.insert(vector); - } - Ex::Load { pointer } => { - self.expressions_used.insert(pointer); - } - Ex::ImageSample { - image, - sampler, - gather: _, - coordinate, - array_index, - offset, - ref level, - depth_ref, - } => { - self.expressions_used - .insert_iter([image, sampler, coordinate]); - self.expressions_used.insert_iter(array_index); - match self.global_expressions_used { - Some(ref mut used) => used.insert_iter(offset), - None => self.expressions_used.insert_iter(offset), - } - use crate::SampleLevel as Sl; - match *level { - Sl::Auto | Sl::Zero => {} - Sl::Exact(expr) | Sl::Bias(expr) => { - self.expressions_used.insert(expr); - } - Sl::Gradient { x, y } => self.expressions_used.insert_iter([x, y]), - } - self.expressions_used.insert_iter(depth_ref); - } - Ex::ImageLoad { - image, - coordinate, - array_index, - sample, - level, - } => { - self.expressions_used.insert(image); - self.expressions_used.insert(coordinate); - self.expressions_used.insert_iter(array_index); - self.expressions_used.insert_iter(sample); - self.expressions_used.insert_iter(level); + Ex::Constant(handle) => { + self.constants_used.insert(handle); + // Constants and expressions are mutually recursive, which + // complicates our nice one-pass algorithm. However, since + // constants don't refer to each other, we can get around + // this by looking *through* each constant and marking its + // initializer as used. Since `expr` refers to the constant, + // and the constant refers to the initializer, it must + // precede `expr` in the arena. + let init = self.constants[handle].init; + match self.global_expressions_used { + Some(ref mut used) => used.insert(init), + None => self.expressions_used.insert(init), + }; + } + Ex::Override(_) => { + // All overrides are considered used by definition. We mark + // their types and initialization expressions as used in + // `compact::compact`, so we have no more work to do here. + } + Ex::ZeroValue(ty) => { + self.types_used_insert(ty); + } + Ex::Compose { ty, ref components } => { + self.types_used_insert(ty); + self.expressions_used + .insert_iter(components.iter().cloned()); + } + Ex::Access { base, index } => self.expressions_used.insert_iter([base, index]), + Ex::AccessIndex { base, index: _ } => { + self.expressions_used.insert(base); + } + Ex::Splat { size: _, value } => { + self.expressions_used.insert(value); + } + Ex::Swizzle { + size: _, + vector, + pattern: _, + } => { + self.expressions_used.insert(vector); + } + Ex::Load { pointer } => { + self.expressions_used.insert(pointer); + } + Ex::ImageSample { + image, + sampler, + gather: _, + coordinate, + array_index, + offset, + ref level, + depth_ref, + } => { + self.expressions_used + .insert_iter([image, sampler, coordinate]); + self.expressions_used.insert_iter(array_index); + match self.global_expressions_used { + Some(ref mut used) => used.insert_iter(offset), + None => self.expressions_used.insert_iter(offset), } - Ex::ImageQuery { image, ref query } => { - self.expressions_used.insert(image); - use crate::ImageQuery as Iq; - match *query { - Iq::Size { level } => self.expressions_used.insert_iter(level), - Iq::NumLevels | Iq::NumLayers | Iq::NumSamples => {} + use crate::SampleLevel as Sl; + match *level { + Sl::Auto | Sl::Zero => {} + Sl::Exact(expr) | Sl::Bias(expr) => { + self.expressions_used.insert(expr); } + Sl::Gradient { x, y } => self.expressions_used.insert_iter([x, y]), } - Ex::Unary { op: _, expr } => { - self.expressions_used.insert(expr); - } - Ex::Binary { op: _, left, right } => { - self.expressions_used.insert_iter([left, right]); - } - Ex::Select { - condition, - accept, - reject, - } => self - .expressions_used - .insert_iter([condition, accept, reject]), - Ex::Derivative { - axis: _, - ctrl: _, - expr, - } => { - self.expressions_used.insert(expr); - } - Ex::Relational { fun: _, argument } => { - self.expressions_used.insert(argument); - } - Ex::Math { - fun: _, - arg, - arg1, - arg2, - arg3, - } => { - self.expressions_used.insert(arg); - self.expressions_used.insert_iter(arg1); - self.expressions_used.insert_iter(arg2); - self.expressions_used.insert_iter(arg3); - } - Ex::As { - expr, - kind: _, - convert: _, - } => { - self.expressions_used.insert(expr); - } - Ex::ArrayLength(expr) => { - self.expressions_used.insert(expr); - } - Ex::AtomicResult { ty, comparison: _ } - | Ex::WorkGroupUniformLoadResult { ty } - | Ex::SubgroupOperationResult { ty } => { - self.types_used.insert(ty); - } - Ex::RayQueryGetIntersection { - query, - committed: _, - } => { - self.expressions_used.insert(query); + self.expressions_used.insert_iter(depth_ref); + } + Ex::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + self.expressions_used.insert(image); + self.expressions_used.insert(coordinate); + self.expressions_used.insert_iter(array_index); + self.expressions_used.insert_iter(sample); + self.expressions_used.insert_iter(level); + } + Ex::ImageQuery { image, ref query } => { + self.expressions_used.insert(image); + use crate::ImageQuery as Iq; + match *query { + Iq::Size { level } => self.expressions_used.insert_iter(level), + Iq::NumLevels | Iq::NumLayers | Iq::NumSamples => {} } } + Ex::Unary { op: _, expr } => { + self.expressions_used.insert(expr); + } + Ex::Binary { op: _, left, right } => { + self.expressions_used.insert_iter([left, right]); + } + Ex::Select { + condition, + accept, + reject, + } => self + .expressions_used + .insert_iter([condition, accept, reject]), + Ex::Derivative { + axis: _, + ctrl: _, + expr, + } => { + self.expressions_used.insert(expr); + } + Ex::Relational { fun: _, argument } => { + self.expressions_used.insert(argument); + } + Ex::Math { + fun: _, + arg, + arg1, + arg2, + arg3, + } => { + self.expressions_used.insert(arg); + self.expressions_used.insert_iter(arg1); + self.expressions_used.insert_iter(arg2); + self.expressions_used.insert_iter(arg3); + } + Ex::As { + expr, + kind: _, + convert: _, + } => { + self.expressions_used.insert(expr); + } + Ex::ArrayLength(expr) => { + self.expressions_used.insert(expr); + } + Ex::AtomicResult { ty, comparison: _ } + | Ex::WorkGroupUniformLoadResult { ty } + | Ex::SubgroupOperationResult { ty } => { + self.types_used_insert(ty); + } + Ex::RayQueryGetIntersection { + query, + committed: _, + } => { + self.expressions_used.insert(query); + } } } } diff --git a/naga/src/compact/functions.rs b/naga/src/compact/functions.rs index d523c1889f..9ab7bd792a 100644 --- a/naga/src/compact/functions.rs +++ b/naga/src/compact/functions.rs @@ -46,6 +46,7 @@ impl FunctionTracer<'_> { fn as_expression(&mut self) -> super::expressions::ExpressionTracer { super::expressions::ExpressionTracer { + types: None, constants: self.constants, expressions: &self.function.expressions, diff --git a/naga/src/compact/mod.rs b/naga/src/compact/mod.rs index 6b41a2c9e2..0348df35a4 100644 --- a/naga/src/compact/mod.rs +++ b/naga/src/compact/mod.rs @@ -1,4 +1,4 @@ -mod expressions; +pub mod expressions; mod functions; mod handle_set_map; mod statements; @@ -63,16 +63,6 @@ pub fn compact(module: &mut crate::Module) { } } - for (_, ty) in module.types.iter() { - if let crate::TypeInner::Array { - size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(size_expr)), - .. - } = ty.inner - { - module_tracer.global_expressions_used.insert(size_expr); - } - } - for e in module.entry_points.iter() { if let Some(sizes) = e.workgroup_size_overrides { for size in sizes.iter().filter_map(|x| *x) { @@ -112,12 +102,6 @@ pub fn compact(module: &mut crate::Module) { }) .collect(); - // Given that the above steps have marked all the constant - // expressions used directly by globals, constants, functions, and - // entry points, walk the constant expression arena to find all - // constant expressions used, directly or indirectly. - module_tracer.as_const_expression().trace_expressions(); - // Constants' initializers are taken care of already, because // expression tracing sees through constants. But we still need to // note type usage. @@ -135,8 +119,7 @@ pub fn compact(module: &mut crate::Module) { } } - // Propagate usage through types. - module_tracer.as_type().trace_types(); + module_tracer.type_expression_tandem(); // Now that we know what is used and what is never touched, // produce maps from the `Handle`s that appear in `module` now to @@ -272,15 +255,54 @@ impl<'module> ModuleTracer<'module> { } } + fn type_expression_tandem(&mut self) { + // assume there are no cycles in the type/expression graph (guaranteed by validator) + // assume that the expressions are well ordered since they're not merged like types are + // ie. expression A referring to a type referring to expression B has A > B. + // (also guaranteed by validator) + + // 1. iterate over types, skipping unused ones + // a. if the type references an expression, mark it used + // b. repeat `a` while walking referenced types, marking them as used + // 2. iterate backwards over expressions, skipping unused ones + // a. if the expression references a type + // i. walk the type's dependency tree, marking the types and their referenced + // expressions as used (types_used_insert instead of types_used.insert) + // b. if the expression references another expression, mark the latter as used + + // ┌───────────┐ ┌───────────┐ + // │Expressions│ │ Types │ + // │ ╵ ╵ │ + // │ covered by │ │ So that back/forths starting with a type now start with an + // │ step 1 │ │ expression instead. + // │ ◄────────────┘ │ + // │ │ │ │ + // │ │ │ │ + // │ ◄────────────┐ │ This arrow is only as needed. + // │ │ │ │ │ + // │ ┌────────────►│ │ + // │ │ covered by │ This covers back/forths starting with an expression. + // │ │ step 2 │ + // │ ╷ ╷ │ + // └───────────┘ └───────────┘ + + // 1 + self.as_type().trace_types(); + // 2 + self.as_const_expression().trace_expressions(); + } + fn as_type(&mut self) -> types::TypeTracer { types::TypeTracer { types: &self.module.types, types_used: &mut self.types_used, + expressions_used: &mut self.global_expressions_used, } } fn as_const_expression(&mut self) -> expressions::ExpressionTracer { expressions::ExpressionTracer { + types: Some(&self.module.types), expressions: &self.module.global_expressions, constants: &self.module.constants, types_used: &mut self.types_used, @@ -353,3 +375,101 @@ impl From> for FunctionMap { } } } + +#[test] +fn type_expression_interdependence() { + let mut module: crate::Module = Default::default(); + let u32 = module.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }), + }, + crate::Span::default(), + ); + let expr = module.global_expressions.append( + crate::Expression::Literal(crate::Literal::U32(0)), + crate::Span::default(), + ); + let type_needs_expression = |module: &mut crate::Module, handle| { + module.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Array { + base: u32, + size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(handle)), + stride: 4, + }, + }, + crate::Span::default(), + ) + }; + let expression_needs_type = |module: &mut crate::Module, handle| { + module + .global_expressions + .append(crate::Expression::ZeroValue(handle), crate::Span::default()) + }; + let expression_needs_expression = |module: &mut crate::Module, handle| { + module.global_expressions.append( + crate::Expression::Load { pointer: handle }, + crate::Span::default(), + ) + }; + let type_needs_type = |module: &mut crate::Module, handle| { + module.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Array { + base: handle, + size: crate::ArraySize::Dynamic, + stride: 0, + }, + }, + crate::Span::default(), + ) + }; + let mut type_name_counter = 0; + let mut type_needed = |module: &mut crate::Module, handle| { + let name = Some(format!("type{}", type_name_counter)); + type_name_counter += 1; + module.types.insert( + crate::Type { + name, + inner: crate::TypeInner::Array { + base: handle, + size: crate::ArraySize::Dynamic, + stride: 0, + }, + }, + crate::Span::default(), + ) + }; + let mut override_name_counter = 0; + let mut expression_needed = |module: &mut crate::Module, handle| { + let name = Some(format!("override{}", override_name_counter)); + override_name_counter += 1; + module.overrides.append( + crate::Override { + name, + id: None, + ty: u32, + init: Some(handle), + }, + crate::Span::default(), + ) + }; + // borrow checker breaks without the tmp variables as of Rust 1.83.0 + let expr_end = type_needs_expression(&mut module, expr); + let ty_trace = type_needs_type(&mut module, expr_end); + let expr_init = expression_needs_type(&mut module, ty_trace); + expression_needed(&mut module, expr_init); + let ty_end = expression_needs_type(&mut module, u32); + let expr_trace = expression_needs_expression(&mut module, ty_end); + let ty_init = type_needs_expression(&mut module, expr_trace); + type_needed(&mut module, ty_init); + let untouched = module.clone(); + compact(&mut module); + assert!(module == untouched); +} diff --git a/naga/src/compact/types.rs b/naga/src/compact/types.rs index ae4ae35580..e664ff7d63 100644 --- a/naga/src/compact/types.rs +++ b/naga/src/compact/types.rs @@ -4,6 +4,7 @@ use crate::{Handle, UniqueArena}; pub struct TypeTracer<'a> { pub types: &'a UniqueArena, pub types_used: &'a mut HandleSet, + pub expressions_used: &'a mut HandleSet, } impl TypeTracer<'_> { @@ -24,34 +25,58 @@ impl TypeTracer<'_> { continue; } - use crate::TypeInner as Ti; - match ty.inner { - // Types that do not contain handles. - Ti::Scalar { .. } - | Ti::Vector { .. } - | Ti::Matrix { .. } - | Ti::Atomic { .. } - | Ti::ValuePointer { .. } - | Ti::Image { .. } - | Ti::Sampler { .. } - | Ti::AccelerationStructure - | Ti::RayQuery => {} + self.trace_type(ty, |x, y| { + x.types_used.insert(y); + }); + } + } - // Types that do contain handles. - Ti::Pointer { base, space: _ } - | Ti::Array { - base, - size: _, - stride: _, - } - | Ti::BindingArray { base, size: _ } => { - self.types_used.insert(base); - } - Ti::Struct { - ref members, - span: _, - } => { - self.types_used.insert_iter(members.iter().map(|m| m.ty)); + pub fn trace_type( + &mut self, + ty: &crate::Type, + callback: impl Fn(&mut Self, Handle), + ) { + use crate::TypeInner as Ti; + match ty.inner { + // Types that do not contain handles. + Ti::Scalar { .. } + | Ti::Vector { .. } + | Ti::Matrix { .. } + | Ti::Atomic { .. } + | Ti::ValuePointer { .. } + | Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery => {} + + // Types that do contain handles. + Ti::Array { + base, + size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(expr)), + stride: _, + } + | Ti::BindingArray { + base, + size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(expr)), + } => { + self.expressions_used.insert(expr); + callback(self, base); + } + Ti::Pointer { base, space: _ } + | Ti::Array { + base, + size: _, + stride: _, + } + | Ti::BindingArray { base, size: _ } => { + callback(self, base); + } + Ti::Struct { + ref members, + span: _, + } => { + for m in members.iter() { + callback(self, m.ty); } } } diff --git a/naga/src/diagnostic_filter.rs b/naga/src/diagnostic_filter.rs index 2fa5464cdf..e3062bfa52 100644 --- a/naga/src/diagnostic_filter.rs +++ b/naga/src/diagnostic_filter.rs @@ -94,6 +94,7 @@ impl StandardFilterableTriggeringRule { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub struct DiagnosticFilter { pub new_severity: Severity, pub triggering_rule: FilterableTriggeringRule, @@ -235,6 +236,7 @@ pub(crate) struct ConflictingDiagnosticRuleError { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub struct DiagnosticFilterNode { pub inner: DiagnosticFilter, pub parent: Option>, diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 687dc5b441..34f22e622a 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -999,6 +999,7 @@ pub struct GlobalVariable { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub struct LocalVariable { /// Name of the variable, if any. pub name: Option, @@ -1720,6 +1721,7 @@ pub enum SwitchValue { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub struct SwitchCase { /// Value, upon which the case is considered true. pub value: SwitchValue, @@ -1738,6 +1740,7 @@ pub struct SwitchCase { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub enum RayQueryFunction { /// Initialize the `RayQuery` object. Initialize { @@ -1782,6 +1785,7 @@ pub enum RayQueryFunction { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub enum Statement { /// Emit a range of expressions, visible to all statements that follow in this block. /// @@ -2077,6 +2081,7 @@ pub enum Statement { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub struct FunctionArgument { /// Name of the argument, if any. pub name: Option, @@ -2092,6 +2097,7 @@ pub struct FunctionArgument { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub struct FunctionResult { /// Type of the result. pub ty: Handle, @@ -2105,6 +2111,7 @@ pub struct FunctionResult { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub struct Function { /// Name of the function, if any. pub name: Option, @@ -2187,6 +2194,7 @@ pub struct Function { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub struct EntryPoint { /// Name of this entry point, visible externally. /// @@ -2230,6 +2238,7 @@ pub enum PredeclaredType { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub struct SpecialTypes { /// Type for `RayDesc`. /// @@ -2321,6 +2330,7 @@ pub enum RayQueryIntersection { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub struct Module { /// Arena for the types defined in this module. pub types: UniqueArena, diff --git a/naga/src/span.rs b/naga/src/span.rs index 7c1ce17dca..7f3a461db3 100644 --- a/naga/src/span.rs +++ b/naga/src/span.rs @@ -345,7 +345,7 @@ impl SpanProvider for Arena { } } -impl SpanProvider for UniqueArena { +impl SpanProvider for UniqueArena { fn get_span(&self, handle: Handle) -> Span { self.get_span(handle) } diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index be4eb3dbac..fc25aa24a9 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -9,6 +9,9 @@ use crate::{ use crate::non_max_u32::NonMaxU32; use crate::{Arena, UniqueArena}; +#[cfg(feature = "compact")] +use crate::{arena::HandleSet, compact::expressions::ExpressionTracer}; + use super::ValidationError; use std::{convert::TryInto, hash::Hash}; @@ -74,6 +77,7 @@ impl super::Validator { for handle_and_expr in global_expressions.iter() { Self::validate_const_expression_handles(handle_and_expr, constants, overrides, types)?; + Self::well_ordered_deps(handle_and_expr, constants, global_expressions, types)?; } let validate_type = |handle| Self::validate_type_handle(handle, types); @@ -629,6 +633,40 @@ impl super::Validator { | crate::Statement::Barrier(_) => Ok(()), }) } + + #[cfg(feature = "compact")] + pub fn well_ordered_deps( + (handle, expression): (Handle, &crate::Expression), + constants: &Arena, + global_expressions: &Arena, + types: &UniqueArena, + ) -> Result<(), InvalidHandleError> { + let mut exprs = HandleSet::for_arena(global_expressions); + ExpressionTracer { + types: Some(types), + expressions: global_expressions, + constants, + types_used: &mut HandleSet::for_arena(types), + constants_used: &mut HandleSet::for_arena(constants), + expressions_used: &mut exprs, + global_expressions_used: None, + } + .trace_expression(expression); + if let Err(error) = handle.check_dep_iter(exprs.iter()) { + return Err(InvalidHandleError::ForwardDependency(error)); + } + Ok(()) + } + + #[cfg(not(feature = "compact"))] + pub const fn well_ordered_deps( + (_handle, _expression): (Handle, &crate::Expression), + _constants: &Arena, + _global_expressions: &Arena, + _types: &UniqueArena, + ) -> Result<(), InvalidHandleError> { + Ok(()) + } } impl From for ValidationError { @@ -788,3 +826,46 @@ fn constant_deps() { .is_err()); } } + +#[test] +fn well_ordered_expressions() { + use super::Validator; + use crate::{ArraySize, Expression, Literal, PendingArraySize, Scalar, Span, Type, TypeInner}; + + let nowhere = Span::default(); + + let mut m = crate::Module::default(); + + let ty_u32 = m.types.insert( + Type { + name: Some("u32".to_string()), + inner: TypeInner::Scalar(Scalar::U32), + }, + nowhere, + ); + let ex_zero = m + .global_expressions + .append(Expression::ZeroValue(ty_u32), nowhere); + let expr = m + .global_expressions + .append(Expression::Literal(Literal::U32(0)), nowhere); + let ty_arr = m.types.insert( + Type { + name: Some("bad_array".to_string()), + inner: TypeInner::Array { + base: ty_u32, + size: ArraySize::Pending(PendingArraySize::Expression(expr)), + stride: 4, + }, + }, + nowhere, + ); + + // Everything should be okay now. + assert!(Validator::validate_module_handles(&m).is_ok()); + + // Mutate `ex_zero`'s type to `ty_arr`, introducing an out of order dependency. + // Validation should catch the problem. + m.global_expressions[ex_zero] = Expression::ZeroValue(ty_arr); + assert!(Validator::validate_module_handles(&m).is_err()); +}