From 7f16613ed92be2a7b2482319f28e09220b442ccf Mon Sep 17 00:00:00 2001 From: Alex Wilson Date: Tue, 30 Sep 2025 08:19:23 -0600 Subject: [PATCH] swizzler: add dist constraint retry and fallback Extends the distribution constraint exploration of the swizzler to retry the values it selects some number of times (4 for now) and if those all fail, fall back to randomizing the field bits as any other variable. The swizzler was only picking one value, and if the given constraints excluded most values then it resulted in the solver tossing it and leaving the field unrandomized, which in many cases lets the solver default to 0 or some other value repeatedly. See #243 --- src/vsc/model/solvegroup_swizzler_partsel.py | 99 ++++++++++++-------- 1 file changed, 61 insertions(+), 38 deletions(-) diff --git a/src/vsc/model/solvegroup_swizzler_partsel.py b/src/vsc/model/solvegroup_swizzler_partsel.py index 39fc49f..769d9b5 100644 --- a/src/vsc/model/solvegroup_swizzler_partsel.py +++ b/src/vsc/model/solvegroup_swizzler_partsel.py @@ -71,7 +71,7 @@ def swizzle_field_l(self, field_l, rs : RandSet, bound_m, btor): if len(field_l) > 0: field_idx = self.randstate.randint(0, len(field_l)-1) f = field_l.pop(field_idx) - e_l = self.swizzle_field(f, rs, bound_m) + e_l = self.swizzle_field(f, rs, bound_m, btor) if e_l is not None: for e in e_l: swizzle_node_l.append(e.build(btor)) @@ -104,52 +104,75 @@ def swizzle_field_l(self, field_l, rs : RandSet, bound_m, btor): def swizzle_field(self, f : FieldScalarModel, rs : RandSet, - bound_m : VariableBoundModel)->ExprModel: + bound_m : VariableBoundModel, + btor + ) -> list[ExprModel]: ret = None - + if self.debug > 0: print("Swizzling field %s" % f.name) - + if f in rs.dist_field_m.keys(): + max_dist_samples = 4 + for _ in range(max_dist_samples): + e = self.sample_dist_weights(f, rs) + n = e.build(btor) + btor.Assume(n) + if self.solve_info is not None: + self.solve_info.n_sat_calls += 1 + if btor.Sat() == btor.SAT: + if self.debug > 0: print(" Dist constraint SAT") + btor.Assert(n) + return ret + else: + if self.debug > 0: print(" Dist constraint UNSAT") + if self.debug > 0: print(" max_dist_samples exceeded, falling back to rand domain") + + if f in bound_m.keys(): + f_bound = bound_m[f] + if not f_bound.isEmpty(): + ret = self.create_rand_domain_constraint(f, f_bound) + + return ret + + def sample_dist_weights(self, + f : FieldScalarModel, + rs : RandSet, + ) -> ExprModel: + if self.debug > 0: + print("Note: field %s is in dist map" % f.name) + for d in rs.dist_field_m[f]: + print(" Weight list %s" % d.weight_list) + + if len(rs.dist_field_m[f]) > 1: + target_d = self.randstate.randint(0, len(rs.dist_field_m[f])-1) + dist_scope_c = rs.dist_field_m[f][target_d] + else: + dist_scope_c = rs.dist_field_m[f][0] + + target_range = dist_scope_c.next_target_range(self.randstate) + target_w = dist_scope_c.dist_c.weights[target_range] + if target_w.rng_rhs is not None: + # Dual-bound range + val_l = target_w.rng_lhs.val() + val_r = target_w.rng_rhs.val() + val = self.randstate.randint(val_l, val_r) if self.debug > 0: - print("Note: field %s is in dist map" % f.name) - for d in rs.dist_field_m[f]: - print(" Weight list %s" % d.weight_list) - if len(rs.dist_field_m[f]) > 1: - target_d = self.randstate.randint(0, len(rs.dist_field_m[f])-1) - dist_scope_c = rs.dist_field_m[f][target_d] - else: - dist_scope_c = rs.dist_field_m[f][0] - - target_range = dist_scope_c.next_target_range(self.randstate) - target_w = dist_scope_c.dist_c.weights[target_range] - if target_w.rng_rhs is not None: - # Dual-bound range - val_l = target_w.rng_lhs.val() - val_r = target_w.rng_rhs.val() - val = self.randstate.randint(val_l, val_r) - if self.debug > 0: - print("Select dist-weight range: %d..%d ; specific value %d" % ( - int(val_l), int(val_r), int(val))) - ret = [ExprBinModel( + print("Select dist-weight range: %d..%d ; specific value %d" % ( + int(val_l), int(val_r), int(val))) + ret = ExprBinModel( ExprFieldRefModel(f), BinExprType.Eq, - ExprLiteralModel(val, f.is_signed, f.width))] - else: - # Single value - val = target_w.rng_lhs.val() - if self.debug > 0: - print("Select dist-weight value %d" % (int(val))) - ret = [ExprBinModel( + ExprLiteralModel(val, f.is_signed, f.width)) + else: + # Single value + val = target_w.rng_lhs.val() + if self.debug > 0: + print("Select dist-weight value %d" % (int(val))) + ret = ExprBinModel( ExprFieldRefModel(f), BinExprType.Eq, - ExprLiteralModel(int(val), f.is_signed, f.width))] - else: - if f in bound_m.keys(): - f_bound = bound_m[f] - if not f_bound.isEmpty(): - ret = self.create_rand_domain_constraint(f, f_bound) - + ExprLiteralModel(int(val), f.is_signed, f.width)) return ret def create_rand_domain_constraint(self,