diff --git a/src/vsc/model/solvegroup_swizzler_partsel.py b/src/vsc/model/solvegroup_swizzler_partsel.py index 39fc49f..769d9b5 100644 --- a/src/vsc/model/solvegroup_swizzler_partsel.py +++ b/src/vsc/model/solvegroup_swizzler_partsel.py @@ -71,7 +71,7 @@ def swizzle_field_l(self, field_l, rs : RandSet, bound_m, btor): if len(field_l) > 0: field_idx = self.randstate.randint(0, len(field_l)-1) f = field_l.pop(field_idx) - e_l = self.swizzle_field(f, rs, bound_m) + e_l = self.swizzle_field(f, rs, bound_m, btor) if e_l is not None: for e in e_l: swizzle_node_l.append(e.build(btor)) @@ -104,52 +104,75 @@ def swizzle_field_l(self, field_l, rs : RandSet, bound_m, btor): def swizzle_field(self, f : FieldScalarModel, rs : RandSet, - bound_m : VariableBoundModel)->ExprModel: + bound_m : VariableBoundModel, + btor + ) -> list[ExprModel]: ret = None - + if self.debug > 0: print("Swizzling field %s" % f.name) - + if f in rs.dist_field_m.keys(): + max_dist_samples = 4 + for _ in range(max_dist_samples): + e = self.sample_dist_weights(f, rs) + n = e.build(btor) + btor.Assume(n) + if self.solve_info is not None: + self.solve_info.n_sat_calls += 1 + if btor.Sat() == btor.SAT: + if self.debug > 0: print(" Dist constraint SAT") + btor.Assert(n) + return ret + else: + if self.debug > 0: print(" Dist constraint UNSAT") + if self.debug > 0: print(" max_dist_samples exceeded, falling back to rand domain") + + if f in bound_m.keys(): + f_bound = bound_m[f] + if not f_bound.isEmpty(): + ret = self.create_rand_domain_constraint(f, f_bound) + + return ret + + def sample_dist_weights(self, + f : FieldScalarModel, + rs : RandSet, + ) -> ExprModel: + if self.debug > 0: + print("Note: field %s is in dist map" % f.name) + for d in rs.dist_field_m[f]: + print(" Weight list %s" % d.weight_list) + + if len(rs.dist_field_m[f]) > 1: + target_d = self.randstate.randint(0, len(rs.dist_field_m[f])-1) + dist_scope_c = rs.dist_field_m[f][target_d] + else: + dist_scope_c = rs.dist_field_m[f][0] + + target_range = dist_scope_c.next_target_range(self.randstate) + target_w = dist_scope_c.dist_c.weights[target_range] + if target_w.rng_rhs is not None: + # Dual-bound range + val_l = target_w.rng_lhs.val() + val_r = target_w.rng_rhs.val() + val = self.randstate.randint(val_l, val_r) if self.debug > 0: - print("Note: field %s is in dist map" % f.name) - for d in rs.dist_field_m[f]: - print(" Weight list %s" % d.weight_list) - if len(rs.dist_field_m[f]) > 1: - target_d = self.randstate.randint(0, len(rs.dist_field_m[f])-1) - dist_scope_c = rs.dist_field_m[f][target_d] - else: - dist_scope_c = rs.dist_field_m[f][0] - - target_range = dist_scope_c.next_target_range(self.randstate) - target_w = dist_scope_c.dist_c.weights[target_range] - if target_w.rng_rhs is not None: - # Dual-bound range - val_l = target_w.rng_lhs.val() - val_r = target_w.rng_rhs.val() - val = self.randstate.randint(val_l, val_r) - if self.debug > 0: - print("Select dist-weight range: %d..%d ; specific value %d" % ( - int(val_l), int(val_r), int(val))) - ret = [ExprBinModel( + print("Select dist-weight range: %d..%d ; specific value %d" % ( + int(val_l), int(val_r), int(val))) + ret = ExprBinModel( ExprFieldRefModel(f), BinExprType.Eq, - ExprLiteralModel(val, f.is_signed, f.width))] - else: - # Single value - val = target_w.rng_lhs.val() - if self.debug > 0: - print("Select dist-weight value %d" % (int(val))) - ret = [ExprBinModel( + ExprLiteralModel(val, f.is_signed, f.width)) + else: + # Single value + val = target_w.rng_lhs.val() + if self.debug > 0: + print("Select dist-weight value %d" % (int(val))) + ret = ExprBinModel( ExprFieldRefModel(f), BinExprType.Eq, - ExprLiteralModel(int(val), f.is_signed, f.width))] - else: - if f in bound_m.keys(): - f_bound = bound_m[f] - if not f_bound.isEmpty(): - ret = self.create_rand_domain_constraint(f, f_bound) - + ExprLiteralModel(int(val), f.is_signed, f.width)) return ret def create_rand_domain_constraint(self,