From b1c829896425d9f279876264f4dc71a1f6bc50bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Basile=20Cl=C3=A9ment?= Date: Fri, 26 Jan 2024 13:12:14 +0100 Subject: [PATCH] cleanup(CP): Simplify constraint handling This patch rewrites the `Constraints_make` module in order to make it more flexible (this is preparatory step for dealing with arithmetic bit-vector constraints, see #903). More precisely: - Constraints no longer need to keep track of their own explanations, this is now entirely done by the `Constraints_make` functor. This makes it simpler to write `Constraint` modules, and avoid repeating boilerplate code to deal with explanation storage. Instead, the explanations are provided to the `Constraint` module in its `propagate` function. - The `Constraints_make` functor no longer need to know about constraint propagation. Instead, it simply keeps track of constraints that need to be propagated (pending constraints), and provides an API to iterate (and remove) the set of constraints to be propagated, letting the user take care of propagation proper. - The `Constraints_make` functor now tracks separately the constraint arguments and the leaves of said arguments. The leaves are used to know which constraints need to be updated during a substitution, and the arguments are used to mark as pending all constraints that apply to a given representative when its domain is updated (note that, for the bit-list domains, we actually store the domains at the leaves, so the arguments mapping is not used -- but this still makes the module more flexible in general, and in particular will allow to introduce arithmetic domains that need to be stored for all values, not only leaves, for the purpose of #903 in a future PR) The new design should make it easier to write `Constraint` modules. It also fixes a bug in the contract between the `Constraint` and `Domain` modules regarding substitution: the `Domain` modules was written under the assumption that constraints applying to `changed` domains would always be marked as pending upon substitution, but that is not true because we use functional representations where such updates are delayed, and hence the `changed` flag needs to be propagated after substitution (if appropriate). --- src/lib/reasoners/bitv_rel.ml | 124 +++++++------- src/lib/reasoners/rel_utils.ml | 292 ++++++++++++++++++--------------- 2 files changed, 226 insertions(+), 190 deletions(-) diff --git a/src/lib/reasoners/bitv_rel.ml b/src/lib/reasoners/bitv_rel.ml index 0bfc973692..6b3ed42f1c 100644 --- a/src/lib/reasoners/bitv_rel.ml +++ b/src/lib/reasoners/bitv_rel.ml @@ -223,12 +223,18 @@ end = struct let subst ex rr nrr t = match MX.find rr t.bitlists with | bl -> - (* Note: even if [rr] had changed its domain, we don't need to keep that - information because if the constraints that used to apply to [rr] were - not already valid, they will be marked as fresh in the [Constraints.t] - after substitution there and propagated already. *) + (* The substitution code for constraints requires that we properly update + the [changed] field here: if the domain of [rr] has changed, + constraints that applied to [rr] will apply to [nrr] after + substitution and must be propagated again. *) + let changed = + if SX.mem rr t.changed then + SX.add nrr t.changed + else + t.changed + in let t = - { changed = SX.remove rr t.changed + { changed = SX.remove rr changed ; bitlists = MX.remove rr t.bitlists } in @@ -243,19 +249,25 @@ end = struct end module Constraint : sig - include Rel_utils.Constraint with type domain = Domains.t + include Rel_utils.Constraint + + val bvand : X.r -> X.r -> X.r -> t + (** [bvand x y z] is the constraint [x = y & z] *) - val bvand : ex:Ex.t -> X.r -> X.r -> X.r -> t - (** [bvand ~ex x y z] is the constraint [x = y & z] *) + val bvor : X.r -> X.r -> X.r -> t + (** [bvor x y z] is the constraint [x = y | z] *) - val bvor : ex:Ex.t -> X.r -> X.r -> X.r -> t - (** [bvor ~ex x y z] is the constraint [x = y | z] *) + val bvxor : X.r -> X.r -> X.r -> t + (** [bvxor x y z] is the constraint [x ^ y ^ z = 0] *) - val bvxor : ex:Ex.t -> X.r -> X.r -> X.r -> t - (** [bvxor ~ex x y z] is the constraint [x ^ y ^ z = 0] *) + val bvnot : X.r -> X.r -> t + (** [bvnot x y] is the constraint [x = not y] *) - val bvnot : ex:Ex.t -> X.r -> X.r -> t - (** [bvnot ~ex x y] is the constraint [x = not y] *) + val propagate : ex:Ex.t -> t -> Domains.t -> Domains.t + (** [propagate ~ex t dom] propagates the constraint [t] in domain [dom]. + + The explanation [ex] justifies that the constraint [t] applies, and must + be added to any domain that gets updated during propagation. *) end = struct type repr = | Band of X.r * X.r * X.r @@ -292,10 +304,10 @@ end = struct Hashtbl.hash (2, SX.fold (fun r acc -> X.hash r :: acc) xs []) | Bnot (x, y) -> Hashtbl.hash (2, X.hash x, X.hash y) - type tagged_repr = { repr : repr ; mutable tag : int } + type t = { repr : repr ; mutable tag : int } module W = Weak.Make(struct - type t = tagged_repr + type nonrec t = t let equal { repr = lhs; _ } { repr = rhs; _ } = equal_repr lhs rhs @@ -353,19 +365,15 @@ end = struct and y = X.subst rr nrr y in Bnot (x, y) - (* The explanation justifies why the constraint holds. *) - type t = { repr : tagged_repr ; ex : Ex.t } - - let pp ppf { repr; _ } = pp_repr ppf repr.repr + let pp ppf { repr; _ } = pp_repr ppf repr - let compare { repr = r1; _ } { repr = r2; _ } = - Int.compare r1.tag r2.tag + let compare { tag = t1; _ } { tag = t2; _ } = Stdlib.compare t1 t2 - let subst ex rr nrr c = - { repr = hcons @@ subst_repr rr nrr c.repr.repr ; ex = Ex.union ex c.ex } + let subst rr nrr c = + hcons @@ subst_repr rr nrr c.repr - let fold_deps f { repr; _ } acc = - match repr.repr with + let fold_args f { repr; _ } acc = + match repr with | Band (x, y, z) | Bor (x, y, z) -> let acc = f x acc in let acc = f y acc in @@ -377,16 +385,9 @@ end = struct let acc = f y acc in acc - let fold_leaves f c acc = - fold_deps (fun r acc -> - List.fold_left (fun acc r -> f r acc) acc (X.leaves r) - ) c acc - - type domain = Domains.t - - let propagate { repr; ex } dom = + let propagate ~ex { repr; _ } dom = Steps.incr CP; - match repr.repr with + match repr with | Band (x, y, z) -> let dx = Domains.get x dom and dy = Domains.get y dom @@ -446,19 +447,17 @@ end = struct let dom = Domains.update ex y dom @@ Bitlist.lognot dx in dom - let make ?(ex = Ex.empty) repr = { repr = hcons repr ; ex } - - let bvand ~ex x y z = make ~ex @@ Band (x, y, z) - let bvor ~ex x y z = make ~ex @@ Bor (x, y, z) - let bvxor ~ex x y z = + let bvand x y z = hcons @@ Band (x, y, z) + let bvor x y z = hcons @@ Bor (x, y, z) + let bvxor x y z = let xs = SX.singleton x in let xs = if SX.mem y xs then SX.remove y xs else SX.add y xs in let xs = if SX.mem z xs then SX.remove z xs else SX.add z xs in - make ~ex @@ Bxor xs - let bvnot ~ex x y = make ~ex @@ Bnot (x, y) + hcons @@ Bxor xs + let bvnot x y = hcons @@ Bnot (x, y) end -module Constraints = Rel_utils.Constraints_Make(Constraint) +module Constraints = Rel_utils.Constraints_make(Constraint) let extract_constraints bcs uf r t = match E.term_view t with @@ -466,19 +465,19 @@ let extract_constraints bcs uf r t = without needing a round-trip through Uf *) | { f = Op BVnot; xs = [ x ] ; _ } -> let rx, exx = Uf.find uf x in - Constraints.add bcs @@ Constraint.bvnot ~ex:exx r rx + Constraints.add ~ex:exx (Constraint.bvnot r rx) bcs | { f = Op BVand; xs = [ x; y ]; _ } -> let rx, exx = Uf.find uf x and ry, exy = Uf.find uf y in - Constraints.add bcs @@ Constraint.bvand ~ex:(Ex.union exx exy) r rx ry + Constraints.add ~ex:(Ex.union exx exy) (Constraint.bvand r rx ry) bcs | { f = Op BVor; xs = [ x; y ]; _ } -> let rx, exx = Uf.find uf x and ry, exy = Uf.find uf y in - Constraints.add bcs @@ Constraint.bvor ~ex:(Ex.union exx exy) r rx ry + Constraints.add ~ex:(Ex.union exx exy) (Constraint.bvor r rx ry) bcs | { f = Op BVxor; xs = [ x; y ]; _ } -> let rx, exx = Uf.find uf x and ry, exy = Uf.find uf y in - Constraints.add bcs @@ Constraint.bvxor ~ex:(Ex.union exx exy) r rx ry + Constraints.add ~ex:(Ex.union exx exy) (Constraint.bvxor r rx ry) bcs | _ -> bcs let rec mk_eq ex lhs w z = @@ -530,21 +529,26 @@ let add_eqs = (* Propagate: - - The constraints that were never propagated since they were added (this - includes constraints that changed due to substitutions) + - The constraints that were never propagated since they were added - The constraints involving variables whose domain changed since the last - propagation *) + propagation + + Iterate until fixpoint is reached. *) let propagate = let rec propagate changed bcs dom = - match Domains.choose_changed dom with - | r, dom -> - propagate (SX.add r changed) bcs @@ - Constraints.propagate bcs r dom - | exception Not_found -> changed, dom + match Constraints.next_pending bcs with + | { value; explanation = ex }, bcs -> + let dom = Constraint.propagate ~ex value dom in + propagate changed bcs dom + | exception Not_found -> + match Domains.choose_changed dom with + | r, dom -> + propagate (SX.add r changed) (Constraints.notify_leaf r bcs) dom + | exception Not_found -> + changed, bcs, dom in fun bcs dom -> - let bcs, dom = Constraints.propagate_fresh bcs dom in - let changed, dom = propagate SX.empty bcs dom in + let changed, bcs, dom = propagate SX.empty bcs dom in SX.fold (fun r acc -> add_eqs acc (Shostak.Bitv.embed r) (Domains.get r dom) ) changed [], bcs, dom @@ -580,7 +584,7 @@ let assume env uf la = match a, orig with | L.Eq (rr, nrr), Subst when is_bv_r rr -> let dom = Domains.subst ex rr nrr dom in - let bcs = Constraints.subst ex rr nrr bcs in + let bcs = Constraints.subst ~ex rr nrr bcs in ((bcs, dom), ss) | L.Distinct (false, [rr; nrr]), _ when is_1bit rr -> (* We don't (yet) support [distinct] in general, but we must @@ -595,7 +599,7 @@ let assume env uf la = let nrr, exnrr = Uf.find_r uf nrr in let ex = Ex.union ex (Ex.union exrr exnrr) in let bcs = - Constraints.add bcs @@ Constraint.bvnot ~ex rr nrr + Constraints.add ~ex (Constraint.bvnot rr nrr) bcs in ((bcs, dom), ss) | _ -> ((bcs, dom), ss) @@ -649,7 +653,7 @@ let case_split env _uf ~for_model = in let _, candidates = match - Constraints.fold_r (fun r acc -> + Constraints.fold_args (fun r acc -> List.fold_left (fun acc { Bitv.bv; _ } -> match bv with | Bitv.Cte _ -> acc diff --git a/src/lib/reasoners/rel_utils.ml b/src/lib/reasoners/rel_utils.ml index f79d48d217..4660feb3f9 100644 --- a/src/lib/reasoners/rel_utils.ml +++ b/src/lib/reasoners/rel_utils.ml @@ -197,182 +197,214 @@ module type Constraint = sig type t (** The type of constraints. - Constraints are associated with a justification as to why they are - currently valid. The justification is only used to update domains, - identical constraints with different justifications will otherwise behave - identically (and, notably, will compare equal). - - Constraints contains semantic values / term representatives of type - [X.r]. We maintain the invariant that the semantic values used inside the - constraints are *class representatives* i.e. normal forms wrt the `Uf` - module, i.e. constraints have a normalized representation. Use `subst` to - ensure normalization. *) + Constraints apply to semantic values of type [X.r] as arguments. *) val pp : t Fmt.t (** Pretty-printer for constraints. *) val compare : t -> t -> int - (** Comparison function for constraints. - - Constraints typically include explanations, which should not be included - in the comparison function: code working with constraints expects - constraints with identical representations but different explanations to - compare equal. - - {b Note}: The comparison function is arbitrary and has no semantic - meaning. You should not depend on any of its properties, other than it - defines an (arbitrary) total order on constraint representations. *) - - val subst : Explanation.t -> X.r -> X.r -> t -> t - (** [subst ex p v cs] replaces all the instances of [p] with [v] in the + (** Comparison function for constraints. The comparison function is + arbitrary and has no semantic meaning. You should not depend on any of + its properties, other than it defines an (arbitrary) total order on + constraint representations. *) + + val fold_args : (X.r -> 'a -> 'a) -> t -> 'a -> 'a + (** [fold_args f c acc] folds function [f] over the arguments of constraint + [c]. + + During propagation, the constraint {b MUST} only look at (and update) + the domains associated of its arguments; it is not allowed to look at + the domains of other semantic values. This allows efficient updates of + the pending constraints. *) + + val subst : X.r -> X.r -> t -> t + (** [subst p v cs] replaces all the instances of [p] with [v] in the constraint. - Use this to ensure that the representation is always normalized. - - The explanation [ex] justifies the equality [p = v]. *) - - val fold_leaves : (X.r -> 'a -> 'a) -> t -> 'a -> 'a - - type domain - (** The type of domains. - - This is typically a mapping from variables to their own domain, but no - expectations is made upon the actual structure of that type. *) + {b Note}: Substitution {b MUST} be inert, i.e. substitution {b MUST NOT} + depend on the result of the substitution applied to the constraint + arguments. The new constraint must have the same effect as the old one + on *any* domain where the substitution has been applied. This is allows + efficient updates of the pending constraints. *) +end - val propagate : t -> domain -> domain - (** [propagate c dom] propagates the constraints [c] in [d] and returns the - new domain. *) +type 'a explained = { value : 'a ; explanation : Explanation.t } -end +let explained ~ex value = { value ; explanation = ex } -module Constraints_Make(Constraint : Constraint) : sig +module Constraints_make(Constraint : Constraint) : sig type t - (** The type of constraint sets. A constraint sets records a set of - constraints that applies to semantic values, and remembers which - constraints are associated with each semantic values. + (** The type of constraint sets. A constraint set records a set of + constraints that applies to semantic values, and remembers the relation + between constraints and semantic values. - It is used to only propagate constraints involving semantic values whose - associated domain has changed. + The constraints associated with specific semantic values can be notified + (see [notify]), which is used to only propagate constraints involving + semantic values whose domain has changed. - The constraint sets are expected to keep track of *class representatives*, - i.e. normal forms wrt the `Uf` module, in which case we say the - constraint set is *normalized*. Use `subst` to ensure normalization. *) + The constraints that have been notified are called "pending + constraints", and the set thereof is the "pending set". These are + constraints that need to be propagated, and can be recovered using + [next_pending]. *) val pp : t Fmt.t (** Pretty-printer for constraint sets. *) val empty : t - (** Returns an empty constraint set. *) + (** The empty constraint set. *) - val subst : Explanation.t -> X.r -> X.r -> t -> t - (** [subst ex p v cs] replaces all the instances of [p] with [v] in the - constraints. + val add : ?pending:bool -> ex:Explanation.t -> Constraint.t -> t -> t + (** [add ~ex c t] adds the constraint [c] to the set [t]. - Use this to ensure that the representation is always normalized. + The explanation [ex] justifies that the constraint [c] holds. + + The constraint is only added to the pending set if it was not already + active (i.e. previously added). Setting the [pending] optional argument to + [true] forces the constraint to be marked as pending even if it is already + active. *) + + val subst : ex:Explanation.t -> X.r -> X.r -> t -> t + (** [subst ~ex p v t] replaces all instances of [p] with [v] in the + constraints. The explanation [ex] justifies the equality [p = v]. *) - val add : t -> Constraint.t -> t - (** [add c cs] adds the constraint [c] to [cs]. *) + val notify : X.r -> t -> t + (** [notify r t] marks all constraints involving [r] (i.e. all constraints + that have [r] as one of their arguments) as pending. - val propagate_fresh : t -> Constraint.domain -> t * Constraint.domain - (** [propagate_fresh cs acc] propagates the fresh constraints and returns the - new domain, as well as a copy of the constraint set with no fresh - constraints. + This function should be used when the domain of [r] is updated, if + domains are tracked for all representatives. *) + + val notify_leaf : X.r -> t -> t + (** [notify_leaf r t] marks all constraints that have [r] as a leaf (i.e. + all constraints that have at least one argument [a] such that [r] is in + [X.leaves a]) as pending. - Fresh constraints are constraints that were never propagated yet. *) + This function should be used when the domain of [r] is updated, if + domains are tracked for leaves only. *) - val fold_r : (X.r -> 'a -> 'a) -> t -> 'a -> 'a - (** [fold_r f cs acc] folds [f] over any representative [r] that is currently - associated with a constraint (i.e. at least one constraint currently - applies to [r]). *) + val fold_args : (X.r -> 'a -> 'a) -> t -> 'a -> 'a + (** [fold_args f t acc] folds [f] over all the term representatives that are + arguments of at least one constraint. *) + + val next_pending : t -> Constraint.t explained * t + (** [next_pending t] returns a pair [c, t'] where [c] was pending in [t] and + [t'] is identical to [t], except that [c] is no longer a pending + constraint. - val propagate : t -> X.r -> Constraint.domain -> Constraint.domain - (** [propagate cs r dom] propagates the constraints associated with [r] in the - constraint set [cs] and returns the new domain map after propagation. *) + @raise Not_found if there are no pending constraints. *) end = struct - module IM = Util.MI module MX = Shostak.MXH - module CS = Set.Make(Constraint) + module CS = Set.Make(struct + type t = Constraint.t explained + + let compare a b = Constraint.compare a.value b.value + end) type t = { - cs_set : CS.t ; - (*** All the constraints currently active *) - cs_map : CS.t MX.t ; - (*** Mapping from semantic values to the constraints that involves them *) - fresh : CS.t ; - (*** Fresh constraints that have never been propagated *) + args_map : CS.t MX.t ; + (** Mapping from semantic values to constraints involving them *) + + leaves_map : CS.t MX.t ; + (** Mapping from semantic values to constraints they are a leaf of *) + + active : CS.t ; + (** Set of all currently active constraints *) + + pending : CS.t ; + (** Set of active constraints that have not yet been propagated *) } - let pp ppf { cs_set; cs_map = _ ; fresh = _ } = + let pp ppf { active; _ } = Fmt.( braces @@ hvbox @@ iter ~sep:semi CS.iter @@ + using (fun { value; _ } -> value) @@ box ~indent:2 @@ braces @@ Constraint.pp - ) ppf cs_set + ) ppf active let empty = - { cs_set = CS.empty - ; cs_map = MX.empty - ; fresh = CS.empty } + { args_map = MX.empty + ; leaves_map = MX.empty + ; active = CS.empty + ; pending = CS.empty } - let cs_add cs r cs_map = + let cs_add c r cs_map = MX.update r (function - | Some css -> Some (CS.add cs css) - | None -> Some (CS.singleton cs) + | Some cs -> Some (CS.add c cs) + | None -> Some (CS.singleton c) ) cs_map - let cs_remove cs r cs_map = + let fold_leaves f c acc = + Constraint.fold_args (fun r acc -> + List.fold_left (fun acc r -> f r acc) acc (X.leaves r) + ) c acc + + let add ?(pending = false) ~ex c t = + let c = explained ~ex c in + (* Note: use [CS.find] here, not [CS.mem], to ensure we use the same + explanation for [c] in the [pending] and [active] sets. *) + match CS.find c t.active with + | c -> + if pending then { t with pending = CS.add c t.pending } else t + | exception Not_found -> + let active = CS.add c t.active in + let args_map = + Constraint.fold_args (cs_add c) c.value t.args_map + in + let leaves_map = fold_leaves (cs_add c) c.value t.leaves_map in + let pending = CS.add c t.pending in + { active; args_map; leaves_map; pending } + + let cs_remove c r cs_map = MX.update r (function - | Some css -> - let css = CS.remove cs css in - if CS.is_empty css then None else Some css - | None -> - (* Can happen if the same argument is repeated *) - None + | Some cs -> + let cs = CS.remove c cs in + if CS.is_empty cs then None else Some cs + | None -> None ) cs_map - let subst ex rr nrr bcs = - match MX.find rr bcs.cs_map with - | ids -> - let cs_map, cs_set, fresh = - CS.fold (fun cs (cs_map, cs_set, fresh) -> - let fresh = CS.remove cs fresh in - let cs_set = CS.remove cs cs_set in - let cs_map = Constraint.fold_leaves (cs_remove cs) cs cs_map in - let cs' = Constraint.subst ex rr nrr cs in - if CS.mem cs' cs_set then - cs_map, cs_set, fresh - else - let cs_set = CS.add cs' cs_set in - let cs_map = Constraint.fold_leaves (cs_add cs') cs' cs_map in - (cs_map, cs_set, CS.add cs' fresh) - ) ids (bcs.cs_map, bcs.cs_set, bcs.fresh) - in - assert (not (MX.mem rr cs_map)); - { cs_set ; cs_map ; fresh } - | exception Not_found -> bcs - - let add bcs c = - if CS.mem c bcs.cs_set then - bcs - else - let cs_set = CS.add c bcs.cs_set in - let cs_map = Constraint.fold_leaves (cs_add c) c bcs.cs_map in - let fresh = CS.add c bcs.fresh in - { cs_set ; cs_map ; fresh } - - let fold_r f bcs acc = - MX.fold (fun r _ acc -> f r acc) bcs.cs_map acc - - let propagate bcs r dom = - match MX.find r bcs.cs_map with - | cs -> CS.fold Constraint.propagate cs dom - | exception Not_found -> dom - - let propagate_fresh bcs dom = - let dom = CS.fold Constraint.propagate bcs.fresh dom in - { bcs with fresh = CS.empty }, dom + let remove c t = + let active = CS.remove c t.active in + let args_map = + Constraint.fold_args (cs_remove c) c.value t.args_map + in + let leaves_map = fold_leaves (cs_remove c) c.value t.leaves_map in + let pending = CS.remove c t.pending in + { active; args_map; leaves_map; pending } + + let subst ~ex rr nrr t = + match MX.find rr t.leaves_map with + | cs -> + CS.fold (fun c t -> + let pending = CS.mem c t.pending in + let t = remove c t in + let ex = Explanation.union ex c.explanation in + add ~pending ~ex (Constraint.subst rr nrr c.value) t + ) cs t + | exception Not_found -> t + + let notify r t = + match MX.find r t.args_map with + | cs -> + CS.fold (fun c t -> { t with pending = CS.add c t.pending }) cs t + | exception Not_found -> t + + let notify_leaf r t = + match MX.find r t.leaves_map with + | cs -> + CS.fold (fun c t -> { t with pending = CS.add c t.pending }) cs t + | exception Not_found -> t + + let fold_args f c acc = + MX.fold (fun r _ acc -> + f r acc + ) c.args_map acc + + let next_pending t = + let c = CS.choose t.pending in + c, { t with pending = CS.remove c t.pending } end