Skip to content

Commit

Permalink
Merge pull request #453 from Burnleydev1/context-free
Browse files Browse the repository at this point in the history
Allowing multiple errors to be reported in one pass of the context_free phase
  • Loading branch information
NathanReb authored Jan 31, 2024
2 parents 98fb4d1 + c3f3c7a commit edc51a4
Show file tree
Hide file tree
Showing 12 changed files with 237 additions and 56 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
unreleased
-------------------

- raising an exception does no longer cancel the whole context free phase(#453, @burnleydev1)

- Sort embedded errors that are appended to the AST by location so the compiler
reports the one closer to the beginning of the file first. (#463, @NathanReb)

Expand Down
3 changes: 3 additions & 0 deletions src/common.ml
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@ let mk_named_sig ~loc ~sg_name ~handle_polymorphic_variant = function
[ Pwith_typesubst (Located.lident ~loc "t", for_subst) ]))
| _ -> None

let exn_to_loc_error exn =
match Location.Error.of_exn exn with Some error -> error | None -> raise exn

module With_errors = struct
type 'a t = 'a * Location.Error.t list

Expand Down
3 changes: 3 additions & 0 deletions src/common.mli
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ val mk_named_sig :
It will take care of giving fresh names to unnamed type parameters. *)

val exn_to_loc_error : exn -> Location.Error.t
(** Convert [exn] to a located error if possible or reraise it otherwise *)

module With_errors : sig
type 'a t = 'a * Location.Error.t list

Expand Down
89 changes: 59 additions & 30 deletions src/context_free.ml
Original file line number Diff line number Diff line change
Expand Up @@ -197,43 +197,49 @@ module Generated_code_hook = struct
| _ -> t.f context { loc with loc_start = loc.loc_end } x
end

let rec map_node_rec context ts super_call loc base_ctxt x =
let rec map_node_rec context ts super_call loc base_ctxt x ~embed_errors =
let ctxt =
Expansion_context.Extension.make ~extension_point_loc:loc ~base:base_ctxt ()
in
match EC.get_extension context x with
| None -> super_call base_ctxt x
| Some (ext, attrs) -> (
E.For_context.convert_res ts ~ctxt ext
|> With_errors.of_result ~default:None
(try
E.For_context.convert_res ts ~ctxt ext
|> With_errors.of_result ~default:None
with exn when embed_errors -> (None, [ exn_to_loc_error exn ]))
>>= fun converted ->
match converted with
| None -> super_call base_ctxt x
| Some x ->
EC.merge_attributes_res context x attrs
|> With_errors.of_result ~default:x
>>= fun x -> map_node_rec context ts super_call loc base_ctxt x)
>>= fun x ->
map_node_rec context ts super_call loc base_ctxt x ~embed_errors)

let map_node context ts super_call loc base_ctxt x ~hook =
let map_node context ts super_call loc base_ctxt x ~hook ~embed_errors =
let ctxt =
Expansion_context.Extension.make ~extension_point_loc:loc ~base:base_ctxt ()
in
match EC.get_extension context x with
| None -> super_call base_ctxt x
| Some (ext, attrs) -> (
E.For_context.convert_res ts ~ctxt ext
|> With_errors.of_result ~default:None
(try
E.For_context.convert_res ts ~ctxt ext
|> With_errors.of_result ~default:None
with exn when embed_errors -> (None, [ exn_to_loc_error exn ]))
>>= fun converted ->
match converted with
| None -> super_call base_ctxt x
| Some x ->
map_node_rec context ts super_call loc base_ctxt
(EC.merge_attributes context x attrs)
~embed_errors
>>| fun generated_code ->
Generated_code_hook.replace hook context loc (Single generated_code);
generated_code)

let rec map_nodes context ts super_call get_loc base_ctxt l ~hook
let rec map_nodes context ts super_call get_loc base_ctxt l ~hook ~embed_errors
~in_generated_code =
match l with
| [] -> return []
Expand All @@ -244,32 +250,34 @@ let rec map_nodes context ts super_call get_loc base_ctxt l ~hook
same order as they appear in the source file. *)
super_call base_ctxt x >>= fun x ->
map_nodes context ts super_call get_loc base_ctxt l ~hook
~in_generated_code
~embed_errors ~in_generated_code
>>| fun l -> x :: l
| Some (ext, attrs) -> (
let extension_point_loc = get_loc x in
let ctxt =
Expansion_context.Extension.make ~extension_point_loc
~base:base_ctxt ()
in
E.For_context.convert_inline_res ts ~ctxt ext
|> With_errors.of_result ~default:None
(try
E.For_context.convert_inline_res ts ~ctxt ext
|> With_errors.of_result ~default:None
with exn when embed_errors -> (None, [ exn_to_loc_error exn ]))
>>= function
| None ->
super_call base_ctxt x >>= fun x ->
map_nodes context ts super_call get_loc base_ctxt l ~hook
~in_generated_code
~embed_errors ~in_generated_code
>>| fun l -> x :: l
| Some converted ->
((), attributes_errors attrs) >>= fun () ->
map_nodes context ts super_call get_loc base_ctxt converted ~hook
~in_generated_code:true
~embed_errors ~in_generated_code:true
>>= fun generated_code ->
if not in_generated_code then
Generated_code_hook.replace hook context extension_point_loc
(Many generated_code);
map_nodes context ts super_call get_loc base_ctxt l ~hook
~in_generated_code
~embed_errors ~in_generated_code
>>| fun code -> generated_code @ code))

let map_nodes = map_nodes ~in_generated_code:false
Expand Down Expand Up @@ -341,7 +349,8 @@ let context_free_attribute_modification ~loc =
This complexity is horrible, but in practice we don't care as [attrs] is always a list
of one element; it only has [@@deriving].
*)
let handle_attr_group_inline attrs rf ~items ~expanded_items ~loc ~base_ctxt =
let handle_attr_group_inline attrs rf ~items ~expanded_items ~loc ~base_ctxt
~embed_errors =
List.fold_left attrs ~init:(return [])
~f:(fun acc (Rule.Attr_group_inline.T group) ->
acc >>= fun acc ->
Expand All @@ -351,15 +360,18 @@ let handle_attr_group_inline attrs rf ~items ~expanded_items ~loc ~base_ctxt =
| None, None -> return acc
| None, Some _ | Some _, None ->
context_free_attribute_modification ~loc |> of_result ~default:acc
| Some values, Some _ ->
| Some values, Some _ -> (
let ctxt =
Expansion_context.Deriver.make ~derived_item_loc:loc
~inline:group.expect ~base:base_ctxt ()
in
let expect_items = group.expand ~ctxt rf expanded_items values in
return (expect_items :: acc))
try
let expect_items = group.expand ~ctxt rf expanded_items values in
return (expect_items :: acc)
with exn when embed_errors -> (acc, [ exn_to_loc_error exn ])))

let handle_attr_inline attrs ~item ~expanded_item ~loc ~base_ctxt =
let handle_attr_inline attrs ~item ~expanded_item ~loc ~base_ctxt ~embed_errors
=
List.fold_left attrs ~init:(return []) ~f:(fun acc (Rule.Attr_inline.T a) ->
acc >>= fun acc ->
Attribute.get_res a.attribute item |> of_result ~default:None
Expand All @@ -370,13 +382,15 @@ let handle_attr_inline attrs ~item ~expanded_item ~loc ~base_ctxt =
| None, None -> return acc
| None, Some _ | Some _, None ->
context_free_attribute_modification ~loc |> of_result ~default:acc
| Some value, Some _ ->
| Some value, Some _ -> (
let ctxt =
Expansion_context.Deriver.make ~derived_item_loc:loc
~inline:a.expect ~base:base_ctxt ()
in
let expect_items = a.expand ~ctxt expanded_item value in
return (expect_items :: acc))
try
let expect_items = a.expand ~ctxt expanded_item value in
return (expect_items :: acc)
with exn when embed_errors -> (acc, [ exn_to_loc_error exn ])))

module Expect_mismatch_handler = struct
type t = {
Expand All @@ -387,7 +401,7 @@ module Expect_mismatch_handler = struct
end

class map_top_down ?(expect_mismatch_handler = Expect_mismatch_handler.nop)
?(generated_code_hook = Generated_code_hook.nop) rules =
?(generated_code_hook = Generated_code_hook.nop) rules ~embed_errors =
let hook = generated_code_hook in

let special_functions =
Expand Down Expand Up @@ -448,8 +462,10 @@ class map_top_down ?(expect_mismatch_handler = Expect_mismatch_handler.nop)
|> sort_attr_inline |> Rule.Attr_inline.split_normal_and_expect
in

let map_node = map_node ~hook in
let map_nodes = map_nodes ~hook in
let map_node = map_node ~hook ~embed_errors in
let map_nodes = map_nodes ~hook ~embed_errors in
let handle_attr_group_inline = handle_attr_group_inline ~embed_errors in
let handle_attr_inline = handle_attr_inline ~embed_errors in

object (self)
inherit Ast_traverse.map_with_expansion_context_and_errors as super
Expand Down Expand Up @@ -499,7 +515,12 @@ class map_top_down ?(expect_mismatch_handler = Expect_mismatch_handler.nop)
| None ->
self#pexp_apply_without_traversing_function base_ctxt e func args
| Some pattern -> (
match pattern e with
let generated_code =
try return (pattern e)
with exn when embed_errors -> (None, [ exn_to_loc_error exn ])
in
generated_code >>= fun expr ->
match expr with
| None ->
self#pexp_apply_without_traversing_function base_ctxt e func
args
Expand All @@ -508,12 +529,20 @@ class map_top_down ?(expect_mismatch_handler = Expect_mismatch_handler.nop)
match Hashtbl.find_opt special_functions id.txt with
| None -> super#expression base_ctxt e
| Some pattern -> (
match pattern e with
let generated_code =
try return (pattern e)
with exn when embed_errors -> (None, [ exn_to_loc_error exn ])
in
generated_code >>= fun expr ->
match expr with
| None -> super#expression base_ctxt e
| Some e -> self#expression base_ctxt e))
| Pexp_constant (Pconst_integer (s, Some c)) ->
expand_constant Integer c s
| Pexp_constant (Pconst_float (s, Some c)) -> expand_constant Float c s
| Pexp_constant (Pconst_integer (s, Some c)) -> (
try expand_constant Integer c s
with exn when embed_errors -> (e, [ exn_to_loc_error exn ]))
| Pexp_constant (Pconst_float (s, Some c)) -> (
try expand_constant Float c s
with exn when embed_errors -> (e, [ exn_to_loc_error exn ]))
| _ -> super#expression base_ctxt e

(* Pre-conditions:
Expand Down
1 change: 1 addition & 0 deletions src/context_free.mli
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class map_top_down :
-> ?generated_code_hook:
Generated_code_hook.t (* default: Generated_code_hook.nop *)
-> Rule.t list
-> embed_errors:bool
-> object
inherit Ast_traverse.map_with_expansion_context_and_errors
end
36 changes: 15 additions & 21 deletions src/driver.ml
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,12 @@ module Transform = struct
let last = get_loc (last x l) in
Some { first with loc_end = last.loc_end }

let merge_into_generic_mappers t ~hook ~expect_mismatch_handler ~tool_name
~input_name =
let merge_into_generic_mappers t ~embed_errors ~hook ~expect_mismatch_handler
~tool_name ~input_name =
let { rules; enclose_impl; enclose_intf; impl; intf; _ } = t in
let map =
new Context_free.map_top_down
rules ~generated_code_hook:hook ~expect_mismatch_handler
rules ~embed_errors ~generated_code_hook:hook ~expect_mismatch_handler
in
let gen_header_and_footer context whole_loc f =
let header, footer = f whole_loc in
Expand Down Expand Up @@ -455,7 +455,8 @@ let debug_dropped_attribute name ~old_dropped ~new_dropped =
print_diff "disappeared" new_dropped old_dropped;
print_diff "reappeared" old_dropped new_dropped

let get_whole_ast_passes ~hook ~expect_mismatch_handler ~tool_name ~input_name =
let get_whole_ast_passes ~embed_errors ~hook ~expect_mismatch_handler ~tool_name
~input_name =
let cts =
match !apply_list with
| None -> List.rev !Transform.all
Expand Down Expand Up @@ -484,7 +485,7 @@ let get_whole_ast_passes ~hook ~expect_mismatch_handler ~tool_name ~input_name =
if !no_merge then
List.map transforms
~f:
(Transform.merge_into_generic_mappers ~hook ~tool_name
(Transform.merge_into_generic_mappers ~embed_errors ~hook ~tool_name
~expect_mismatch_handler ~input_name)
else
(let get_enclosers ~f =
Expand Down Expand Up @@ -515,8 +516,8 @@ let get_whole_ast_passes ~hook ~expect_mismatch_handler ~tool_name ~input_name =
let footers = List.concat (List.rev footers) in
(headers, footers))
in
Transform.builtin_of_context_free_rewriters ~rules ~hook
~expect_mismatch_handler
Transform.builtin_of_context_free_rewriters ~rules ~embed_errors
~hook ~expect_mismatch_handler
~enclose_impl:(merge_encloser impl_enclosers)
~enclose_intf:(merge_encloser intf_enclosers)
~tool_name ~input_name
Expand All @@ -529,21 +530,15 @@ let get_whole_ast_passes ~hook ~expect_mismatch_handler ~tool_name ~input_name =
let apply_transforms ~tool_name ~file_path ~field ~lint_field ~dropped_so_far
~hook ~expect_mismatch_handler ~input_name ~embed_errors ast =
let cts =
get_whole_ast_passes ~tool_name ~hook ~expect_mismatch_handler ~input_name
get_whole_ast_passes ~tool_name ~embed_errors ~hook ~expect_mismatch_handler
~input_name
in
let finish (ast, _dropped, lint_errors, errors) =
( ast,
List.map lint_errors ~f:(fun (loc, s) ->
Common.attribute_of_warning loc s),
errors )
in

let exn_to_error exn =
match Location.Error.of_exn exn with
| None -> raise exn
| Some error -> error
in

let acc =
List.fold_left cts ~init:(ast, [], [], [])
~f:(fun (ast, dropped, (lint_errors : _ list), errors) (ct : Transform.t)
Expand All @@ -563,15 +558,15 @@ let apply_transforms ~tool_name ~file_path ~field ~lint_field ~dropped_so_far
| Some f -> (
try (lint_errors @ f ctxt ast, errors)
with exn when embed_errors ->
(lint_errors, exn_to_error exn :: errors))
(lint_errors, exn_to_loc_error exn :: errors))
in
match field ct with
| None -> (ast, dropped, lint_errors, errors)
| Some f ->
let (ast, more_errors), errors =
try (f ctxt ast, errors)
with exn when embed_errors ->
((ast, []), exn_to_error exn :: errors)
((ast, []), exn_to_loc_error exn :: errors)
in
let dropped =
if !debug_attribute_drop then (
Expand Down Expand Up @@ -607,20 +602,19 @@ let error_to_extension error ~(kind : Kind.t) =
| Impl -> Intf_or_impl.Impl [ error_to_str_extension error ]

let exn_to_extension exn ~(kind : Kind.t) =
match Location.Error.of_exn exn with
| None -> raise exn
| Some error -> error_to_extension error ~kind
exn_to_loc_error exn |> error_to_extension ~kind

(* +-----------------------------------------------------------------+
| Actual rewriting of structure/signatures |
+-----------------------------------------------------------------+ *)

let print_passes () =
let tool_name = "ppxlib_driver" in
let embed_errors = false in
let hook = Context_free.Generated_code_hook.nop in
let expect_mismatch_handler = Context_free.Expect_mismatch_handler.nop in
let cts =
get_whole_ast_passes ~hook ~expect_mismatch_handler ~tool_name
get_whole_ast_passes ~embed_errors ~hook ~expect_mismatch_handler ~tool_name
~input_name:None
in
if !perform_checks then
Expand Down
13 changes: 13 additions & 0 deletions test/driver/exception_handling/constant_type.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
open Ppxlib

let kind = Context_free.Rule.Constant_kind.Integer

let rewriter loc s =
Location.raise_errorf ~loc
"A raised located error in the constant rewriting transformation." s

let rule = Context_free.Rule.constant kind 'g' rewriter;;

Driver.register_transformation ~rules:[ rule ] "constant"

let () = Driver.standalone ()
Loading

0 comments on commit edc51a4

Please sign in to comment.