diff --git a/src/kcas/kcas.ml b/src/kcas/kcas.ml index 799e12cc..fb163659 100644 --- a/src/kcas/kcas.ml +++ b/src/kcas/kcas.ml @@ -685,83 +685,100 @@ module Xt = struct (* Fenceless is safe as we are accessing a private location. *) xt_r.mode == `Obstruction_free && 0 <= loc.id - let[@inline] update_new loc f xt lt gt = - (* Fenceless is safe inside transactions as each log update has a fence. *) + type (_, _) up = + | Get : (unit, 'a) up + | Fetch_and_add : (int, int) up + | Exchange : ('a, 'a) up + | Compare_and_swap : ('a * 'a, 'a) up + | Fn : ('a -> 'a, 'a) up + + let[@inline] update : + type c a. + 'x t -> a loc -> c -> (c, a) up -> tree -> tree -> a state -> a -> a = + fun xt loc c up lt gt state before -> + let after = + match up with + | Get -> before + | Fetch_and_add -> before + c + | Exchange -> c + | Compare_and_swap -> if fst c == before then snd c else before + | Fn -> begin + try c before + with exn -> + tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] }); + raise exn + end + in + let state = + if before == after && is_obstruction_free xt loc then state + else { before; after; which = W xt; awaiters = [] } + in + tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] }); + before + + let[@inline] update_new : + type c a. 'x t -> a loc -> c -> (c, a) up -> tree -> tree -> a = + fun xt loc c up lt gt -> let state = fenceless_get (as_atomic loc) in let before = eval state in - match f before with - | after -> - let state = - if before == after && is_obstruction_free xt loc then state - else { before; after; which = W xt; awaiters = [] } - in - tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] }); - before - | exception exn -> - tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] }); - raise exn - - let[@inline] update_top loc f xt state' lt gt = - let state = Obj.magic state' in - if is_cmp xt state then begin - let before = eval state in - let after = f before in - let state = - if before == after then state - else { before; after; which = W xt; awaiters = [] } + update xt loc c up lt gt state before + + let[@inline] update_top : + type c a. 'x t -> a loc -> c -> (c, a) up -> 'b state -> tree -> tree -> a + = + fun xt loc c up state' lt gt -> + let state : a state = Obj.magic state' in + if is_cmp xt state then update xt loc c up lt gt state (eval state) + else + let before = state.after in + let after = + match up with + | Get -> before + | Fetch_and_add -> before + c + | Exchange -> c + | Compare_and_swap -> if fst c == before then snd c else before + | Fn -> c before in + let state = if before == after then state else { state with after } in tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] }); before - end - else - let current = state.after in - let state = { state with after = f current } in - tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] }); - current - let[@inline] unsafe_update ~xt loc f = + let unsafe_update ~xt loc c up = let loc = Loc.to_loc loc in maybe_validate_log xt; let x = loc.id in match !(tree_as_ref xt) with - | T Leaf -> update_new loc f xt (T Leaf) (T Leaf) + | T Leaf -> update_new xt loc c up (T Leaf) (T Leaf) | T (Node { loc = a; lt = T Leaf; _ }) as tree when x < a.id -> - update_new loc f xt (T Leaf) tree + update_new xt loc c up (T Leaf) tree | T (Node { loc = a; gt = T Leaf; _ }) as tree when a.id < x -> - update_new loc f xt tree (T Leaf) + update_new xt loc c up tree (T Leaf) | T (Node { loc = a; state; lt; gt; _ }) when Obj.magic a == loc -> - update_top loc f xt state lt gt + update_top xt loc c up state lt gt | tree -> begin match splay ~hit_parent:false x tree with - | l, T Leaf, r -> update_new loc f xt l r - | l, T (Node node_r), r -> update_top loc f xt node_r.state l r + | l, T Leaf, r -> update_new xt loc c up l r + | l, T (Node node_r), r -> update_top xt loc c up node_r.state l r end - let[@inline] protect xt f x = - let tree = !(tree_as_ref xt) in - let y = f x in - assert (!(tree_as_ref xt) == tree); - y - - let get ~xt loc = unsafe_update ~xt loc Fun.id - let set ~xt loc after = unsafe_update ~xt loc (fun _ -> after) |> ignore - let modify ~xt loc f = unsafe_update ~xt loc (protect xt f) |> ignore + let get ~xt loc = unsafe_update ~xt loc () Get + let set ~xt loc after = unsafe_update ~xt loc after Exchange |> ignore + let modify ~xt loc f = unsafe_update ~xt loc f Fn |> ignore let compare_and_swap ~xt loc before after = - unsafe_update ~xt loc (fun actual -> - if actual == before then after else actual) + unsafe_update ~xt loc (before, after) Compare_and_swap let compare_and_set ~xt loc before after = compare_and_swap ~xt loc before after == before - let exchange ~xt loc after = unsafe_update ~xt loc (fun _ -> after) - let fetch_and_add ~xt loc n = unsafe_update ~xt loc (( + ) n) - let incr ~xt loc = unsafe_update ~xt loc inc |> ignore - let decr ~xt loc = unsafe_update ~xt loc dec |> ignore - let update ~xt loc f = unsafe_update ~xt loc (protect xt f) + let exchange ~xt loc after = unsafe_update ~xt loc after Exchange + let fetch_and_add ~xt loc n = unsafe_update ~xt loc n Fetch_and_add + let incr ~xt loc = unsafe_update ~xt loc 1 Fetch_and_add |> ignore + let decr ~xt loc = unsafe_update ~xt loc (-1) Fetch_and_add |> ignore + let update ~xt loc f = unsafe_update ~xt loc f Fn let swap ~xt l1 l2 = set ~xt l1 @@ exchange ~xt l2 @@ get ~xt l1 - let unsafe_modify ~xt loc f = unsafe_update ~xt loc f |> ignore - let unsafe_update ~xt loc f = unsafe_update ~xt loc f + let unsafe_modify ~xt loc f = unsafe_update ~xt loc f Fn |> ignore + let unsafe_update ~xt loc f = unsafe_update ~xt loc f Fn let[@inline] to_blocking ~xt tx = match tx ~xt with None -> Retry.later () | Some value -> value