Skip to content

Commit

Permalink
Refactor AddArgsInfo into ExpandArgsInfo
Browse files Browse the repository at this point in the history
BEGIN_PUBLIC
Refactor AddArgsInfo into ExpandArgsInfo

This allows us to create a similar mechanism to the current toolchain, while maintaining type safety.
END_PUBLIC

PiperOrigin-RevId: 615939056
Change-Id: I9b6763150194f8a76dfd8da730a3e2d45accbe20
  • Loading branch information
Googler authored and copybara-github committed Mar 14, 2024
1 parent 69c9748 commit bbb0615
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 34 deletions.
27 changes: 18 additions & 9 deletions cc/toolchains/args.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""All providers for rule-based bazel toolchain config."""

load("//cc:cc_toolchain_config_lib.bzl", "flag_group")
load(
"//cc/toolchains/impl:collect.bzl",
"collect_action_types",
Expand All @@ -22,32 +23,41 @@ load(
load(
":cc_toolchain_info.bzl",
"ActionTypeSetInfo",
"AddArgsInfo",
"ArgsInfo",
"ArgsListInfo",
"ExpandArgsInfo",
"FeatureConstraintInfo",
)

visibility("public")

def _cc_args_impl(ctx):
add_args = [AddArgsInfo(
label = ctx.label,
args = tuple(ctx.attr.args),
files = depset([]),
)]
if not ctx.attr.args and not ctx.attr.env:
fail("cc_args requires at least one of args and env")

actions = collect_action_types(ctx.attr.actions)
files = collect_files(ctx.attr.data)
requires = collect_provider(ctx.attr.requires_any_of, FeatureConstraintInfo)

expand = None
if ctx.attr.args:
# TODO: This is temporary until cc_expand_args is implemented.
expand = ExpandArgsInfo(
label = ctx.label,
expand = tuple(),
iterate_over = None,
files = files,
requires_types = {},
legacy_flag_group = flag_group(flags = ctx.attr.args),
)

args = ArgsInfo(
label = ctx.label,
actions = actions,
requires_any_of = tuple(requires),
files = files,
args = add_args,
expand = expand,
env = ctx.attr.env,
files = files,
)
return [
args,
Expand All @@ -74,7 +84,6 @@ See @rules_cc//cc/toolchains/actions:all for valid options.
""",
),
"args": attr.string_list(
mandatory = True,
doc = """Arguments that should be added to the command-line.
These are evaluated in order, with earlier args appearing earlier in the
Expand Down
9 changes: 6 additions & 3 deletions cc/toolchains/cc_toolchain_info.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,16 @@ ActionTypeSetInfo = provider(
},
)

AddArgsInfo = provider(
ExpandArgsInfo = provider(
doc = "A provider representation of Args.add/add_all/add_joined parameters",
# @unsorted-dict-items
fields = {
"label": "(Label) The label defining this provider. Place in error messages to simplify debugging",
"args": "(Sequence[str]) The command-line arguments to add",
"expand": "(Sequence[ExpandArgsInfo]) The nested arg expansion. Mutually exclusive with args",
"iterate_over": "(Optional[str]) The variable to iterate over",
"files": "(depset[File]) The files required to use this variable",
"requires_types": "(dict[str, str]) A mapping from variables to their expected type name (not type). This means that we can require the generic type Option, rather than an Option[T]",
"legacy_flag_group": "(flag_group) The flag_group this corresponds to",
},
)

Expand All @@ -62,7 +65,7 @@ ArgsInfo = provider(
"label": "(Label) The label defining this provider. Place in error messages to simplify debugging",
"actions": "(depset[ActionTypeInfo]) The set of actions this is associated with",
"requires_any_of": "(Sequence[FeatureConstraintInfo]) This will be enabled if any of the listed predicates are met. Equivalent to with_features",
"args": "(Sequence[AddArgsInfo]) The command-line arguments to add.",
"expand": "(Optional[ExpandArgsInfo]) The args to expand. Equivalent to a flag group.",
"files": "(depset[File]) Files required for the args",
"env": "(dict[str, str]) Environment variables to apply",
},
Expand Down
20 changes: 10 additions & 10 deletions cc/toolchains/impl/legacy_converter.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ load(
legacy_env_set = "env_set",
legacy_feature = "feature",
legacy_feature_set = "feature_set",
legacy_flag_group = "flag_group",
legacy_flag_set = "flag_set",
legacy_tool = "tool",
legacy_with_feature_set = "with_feature_set",
Expand Down Expand Up @@ -50,25 +49,26 @@ def convert_feature_constraint(constraint):
not_features = sorted([ft.name for ft in constraint.none_of.to_list()]),
)

def _convert_add_arg(add_arg):
return [legacy_flag_group(flags = list(add_arg.args))]
def convert_args(args):
"""Converts an ArgsInfo to flag_sets and env_sets.
def _convert_args(args):
Args:
args: (ArgsInfo) The args to convert
Returns:
struct(flag_sets = List[flag_set], env_sets = List[env_sets])
"""
actions = _convert_actions(args.actions)
with_features = [
convert_feature_constraint(fc)
for fc in args.requires_any_of
]

flag_sets = []
if args.args:
flag_groups = []
for add_args in args.args:
flag_groups.extend(_convert_add_arg(add_args))
if args.expand != None:
flag_sets.append(legacy_flag_set(
actions = actions,
with_features = with_features,
flag_groups = flag_groups,
flag_groups = [args.expand.legacy_flag_group],
))

env_sets = []
Expand All @@ -93,7 +93,7 @@ def _convert_args_sequence(args_sequence):
flag_sets = []
env_sets = []
for args in args_sequence:
legacy_args = _convert_args(args)
legacy_args = convert_args(args)
flag_sets.extend(legacy_args.flag_sets)
env_sets.extend(legacy_args.env_sets)

Expand Down
11 changes: 11 additions & 0 deletions tests/rule_based_toolchain/args/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ util.helper_target(
env = {"BAR": "bar"},
)

util.helper_target(
cc_args,
name = "env_only",
actions = ["//tests/rule_based_toolchain/actions:all_compile"],
data = [
"//tests/rule_based_toolchain/testdata:file1",
"//tests/rule_based_toolchain/testdata:multiple",
],
env = {"BAR": "bar"},
)

analysis_test_suite(
name = "test_suite",
targets = TARGETS,
Expand Down
64 changes: 61 additions & 3 deletions tests/rule_based_toolchain/args/args_test.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,24 @@
# limitations under the License.
"""Tests for the cc_args rule."""

load(
"//cc:cc_toolchain_config_lib.bzl",
"env_entry",
"env_set",
"flag_group",
"flag_set",
)
load(
"//cc/toolchains:cc_toolchain_info.bzl",
"ActionTypeInfo",
"ArgsInfo",
"ArgsListInfo",
)
load(
"//cc/toolchains/impl:legacy_converter.bzl",
"convert_args",
)
load("//tests/rule_based_toolchain:subjects.bzl", "subjects")

visibility("private")

Expand All @@ -28,13 +40,17 @@ _SIMPLE_FILES = [
"tests/rule_based_toolchain/testdata/multiple2",
]

def _test_simple_args_impl(env, targets):
_CONVERTED_ARGS = subjects.struct(
flag_sets = subjects.collection,
env_sets = subjects.collection,
)

def _simple_test(env, targets):
simple = env.expect.that_target(targets.simple).provider(ArgsInfo)
simple.actions().contains_exactly([
targets.c_compile.label,
targets.cpp_compile.label,
])
simple.args().contains_exactly([targets.simple.label])
simple.env().contains_exactly({"BAR": "bar"})
simple.files().contains_exactly(_SIMPLE_FILES)

Expand All @@ -44,12 +60,54 @@ def _test_simple_args_impl(env, targets):
c_compile.args().contains_exactly([targets.simple[ArgsInfo]])
c_compile.files().contains_exactly(_SIMPLE_FILES)

converted = env.expect.that_value(
convert_args(targets.simple[ArgsInfo]),
factory = _CONVERTED_ARGS,
)
converted.env_sets().contains_exactly([env_set(
actions = ["c_compile", "cpp_compile"],
env_entries = [env_entry(key = "BAR", value = "bar")],
)])

converted.flag_sets().contains_exactly([flag_set(
actions = ["c_compile", "cpp_compile"],
flag_groups = [flag_group(flags = ["--foo", "foo"])],
)])

def _env_only_test(env, targets):
env_only = env.expect.that_target(targets.env_only).provider(ArgsInfo)
env_only.actions().contains_exactly([
targets.c_compile.label,
targets.cpp_compile.label,
])
env_only.env().contains_exactly({"BAR": "bar"})
env_only.files().contains_exactly(_SIMPLE_FILES)

c_compile = env.expect.that_target(targets.simple).provider(ArgsListInfo).by_action().get(
targets.c_compile[ActionTypeInfo],
)
c_compile.files().contains_exactly(_SIMPLE_FILES)

converted = env.expect.that_value(
convert_args(targets.env_only[ArgsInfo]),
factory = _CONVERTED_ARGS,
)
converted.env_sets().contains_exactly([env_set(
actions = ["c_compile", "cpp_compile"],
env_entries = [env_entry(key = "BAR", value = "bar")],
)])

converted.flag_sets().contains_exactly([])

TARGETS = [
":simple",
":env_only",
"//tests/rule_based_toolchain/actions:c_compile",
"//tests/rule_based_toolchain/actions:cpp_compile",
]

# @unsorted-dict-items
TESTS = {
"simple_test": _test_simple_args_impl,
"simple_test": _simple_test,
"env_only_test_test": _env_only_test,
}
38 changes: 29 additions & 9 deletions tests/rule_based_toolchain/subjects.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ load(
"ActionTypeConfigSetInfo",
"ActionTypeInfo",
"ActionTypeSetInfo",
"AddArgsInfo",
"ArgsInfo",
"ArgsListInfo",
"ExpandArgsInfo",
"FeatureConstraintInfo",
"FeatureInfo",
"FeatureSetInfo",
Expand All @@ -40,6 +40,10 @@ visibility("//tests/rule_based_toolchain/...")
# This makes it rather awkward for copybara.
runfiles_subject = lambda value, meta: _subjects.depset_file(value.files, meta = meta)

# The string type has .equals(), which is all we can really do for an unknown
# type.
unknown_subject = _subjects.str

# buildifier: disable=name-conventions
_ActionTypeFactory = generate_factory(
ActionTypeInfo,
Expand Down Expand Up @@ -102,13 +106,27 @@ _FeatureConstraintFactory = generate_factory(
),
)

_EXPAND_ARGS_FLAGS = dict(
expand = None,
files = _subjects.depset_file,
iterate_over = optional_subject(_subjects.str),
legacy_flag_group = unknown_subject,
requires_types = _subjects.dict,
)

# buildifier: disable=name-conventions
_AddArgsFactory = generate_factory(
AddArgsInfo,
"AddArgsInfo",
dict(
args = _subjects.collection,
files = _subjects.depset_file,
_FakeExpandArgsFactory = generate_factory(
ExpandArgsInfo,
"ExpandArgsInfo",
_EXPAND_ARGS_FLAGS,
)

# buildifier: disable=name-conventions
_ExpandArgsFactory = generate_factory(
ExpandArgsInfo,
"ExpandArgsInfo",
_EXPAND_ARGS_FLAGS | dict(
expand = ProviderSequence(_FakeExpandArgsFactory),
),
)

Expand All @@ -118,9 +136,10 @@ _ArgsFactory = generate_factory(
"ArgsInfo",
dict(
actions = ProviderDepset(_ActionTypeFactory),
args = ProviderSequence(_AddArgsFactory),
env = _subjects.dict,
files = _subjects.depset_file,
# Use .factory so it's not inlined.
expand = optional_subject(_ExpandArgsFactory.factory),
requires_any_of = ProviderSequence(_FeatureConstraintFactory),
),
)
Expand Down Expand Up @@ -201,7 +220,7 @@ _ToolchainConfigFactory = generate_factory(
FACTORIES = [
_ActionTypeFactory,
_ActionTypeSetFactory,
_AddArgsFactory,
_ExpandArgsFactory,
_ArgsFactory,
_ArgsListFactory,
_MutuallyExclusiveCategoryFactory,
Expand All @@ -217,6 +236,7 @@ result_fn_wrapper = _result_fn_wrapper

subjects = struct(
**(structs.to_dict(_subjects) | dict(
unknown = unknown_subject,
result = result_subject,
optional = optional_subject,
struct = struct_subject,
Expand Down

0 comments on commit bbb0615

Please sign in to comment.