Help Using Python to Extract Minimal "Join Tree" from Substrait Plan #64

Open
nathanwilk7 opened this issue Nov 16, 2024 · 0 comments

I'm working on a tool that needs to extract a skeleton "join tree" from a given Substrait query plan. For now, it assumes the plan was generated by SQL queries of the shape: SELECT * FROM r1, r2, ..., rN [WHERE ...].

Could I ask for some help figuring out the TODOs I've left in the code below? It is a simple script that creates some dummy tables using DuckDB, converts a cross join query on those tables into a Substrait plan, and then extracts a dict-based join tree from that plan.

For example, for a three-relation cross join query SELECT * FROM r1, r2, r3, where r1 has ten rows, r2 has one hundred rows, and r3 has one row, the DuckDB/Substrait plan is converted into this dict: {'left': 'r2', 'right': {'left': 'r1', 'right': 'r3'}} (note that the order of r2, r1, r3 changes if you change the cardinalities of the relations). The simplified script below defaults to two tables (DEFAULT_TABLE_SIZES); pass a different table_sizes to duckdb_substrait_plan to try the three-relation case.
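
To make the target shape concrete, here is a minimal sketch of two helpers over that dict representation. The names `leaves` and `render` are hypothetical and exist only for illustration; the sketch assumes a leaf is a table-name string and an inner node is a dict with 'left' and 'right' keys.

```python
def leaves(tree):
    """Return the base relation names of a join tree, left to right."""
    if isinstance(tree, str):
        return [tree]
    return leaves(tree['left']) + leaves(tree['right'])


def render(tree):
    """Render a join tree as a parenthesized expression."""
    if isinstance(tree, str):
        return tree
    return f"({render(tree['left'])} x {render(tree['right'])})"


example_tree = {'left': 'r2', 'right': {'left': 'r1', 'right': 'r3'}}
print(leaves(example_tree))  # ['r2', 'r1', 'r3']
print(render(example_tree))  # (r2 x (r1 x r3))
```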

Questions

  1. For queries of the "shape" above (e.g. SELECT * FROM r1, r2, ..., rN [WHERE ...]), can I assume the root plan.relations will always have a length of 1 (i.e. the "result relation")?
  2. How can I implement the recursion over fields and iterable nodes correctly (instead of just taking the first non-None result, as I do now)? Perhaps I just need to spend more time reading the spec.
  3. Any other feedback/ideas to make this more robust/simpler?
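
One possible direction for question 2, offered as a sketch rather than a definitive answer: instead of walking every protobuf field, dispatch on the `rel_type` oneof of each Rel and follow only its relational inputs (`input` for single-input relations, `left`/`right` for joins). The sets `SINGLE_INPUT_RELS` and `BINARY_RELS` and the function `rel_to_join_tree` are hypothetical names, and the list of relation kinds is an assumption that is likely incomplete. `FakeMsg` is a stand-in so the sketch runs without substrait installed; it only mimics `WhichOneof` and attribute access.

```python
SINGLE_INPUT_RELS = {'filter', 'project', 'fetch', 'sort', 'aggregate'}  # assumed, not exhaustive
BINARY_RELS = {'join', 'cross'}


def rel_to_join_tree(rel):
    """Build a join tree from a Rel-like message by following only relational inputs."""
    kind = rel.WhichOneof('rel_type')
    if kind is None:
        raise ValueError('Rel has no rel_type set')
    body = getattr(rel, kind)
    if kind == 'read':
        read_type = body.WhichOneof('read_type')
        if read_type != 'named_table':
            raise NotImplementedError(f'unimplemented readrel type: {read_type}')
        return '.'.join(body.named_table.names)
    if kind in BINARY_RELS:
        return {
            'left': rel_to_join_tree(body.left),
            'right': rel_to_join_tree(body.right),
        }
    if kind in SINGLE_INPUT_RELS:
        return rel_to_join_tree(body.input)
    raise NotImplementedError(f'unhandled rel type: {kind}')


class FakeMsg:
    """Stand-in for a protobuf message (illustration only)."""

    def __init__(self, which=None, **fields):
        self._which = which
        self.__dict__.update(fields)

    def WhichOneof(self, _group):
        return self._which


def fake_read(name):
    table = FakeMsg(names=[name])
    return FakeMsg('read', read=FakeMsg('named_table', named_table=table))


def fake_cross(left, right):
    return FakeMsg('cross', cross=FakeMsg(left=left, right=right))


# A projection over r2 x (r1 x r3), mirroring the example dict above.
plan_input = FakeMsg('project', project=FakeMsg(
    input=fake_cross(fake_read('r2'), fake_cross(fake_read('r1'), fake_read('r3')))
))
demo_tree = rel_to_join_tree(plan_input)
print(demo_tree)  # {'left': 'r2', 'right': {'left': 'r1', 'right': 'r3'}}
```

Because every relation kind is handled explicitly, there is never more than one candidate result per node, and an unexpected relation kind fails loudly instead of being silently skipped.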

Simplified/runnable example:

import duckdb
# print(duckdb.__version__)
# 1.1.2

# import substrait
# print(substrait.__version__)
# 0.23.0

from substrait.proto import (
    CrossRel,
    JoinRel,
    Plan,
    ReadRel,
)

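BOOL_TYPE = type(True)
INT_TYPE = type(0)
FLOAT_TYPE = type(0.0)
STRING_TYPE = type('')
BYTES_TYPE = type(b'')
# Python types that protobuf scalar fields can hold; these never contain relations.
BASE_TYPES = {
    BOOL_TYPE,
    INT_TYPE,
    FLOAT_TYPE,
    STRING_TYPE,
    BYTES_TYPE,
}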

READ_REL_TYPE = ReadRel
BASE_RELATION_TYPES = {
    READ_REL_TYPE,
}

CROSS_REL_TYPE = CrossRel
JOIN_REL_TYPE = JoinRel
JOIN_TYPES = {
    CROSS_REL_TYPE,
    JOIN_REL_TYPE,
}

def plan_to_join_tree(plan: Plan) -> dict:
    # TODO can i assume the length of the plan root's relations is 1 for queries like 'select * from r1, r2, ..., rN where ...'?
    assert len(plan.relations) == 1
    input = plan.relations[0].root.input
    def recur(node):
        node_type = type(node)
        if node_type in BASE_TYPES:
            return None
        
        if node_type in BASE_RELATION_TYPES:
            read_type = node.WhichOneof('read_type')
            if read_type == 'named_table':
                return '.'.join(node.named_table.names)
            else:
                raise NotImplementedError(f'unimplemented readrel type: {read_type}')
        
        if not hasattr(node, 'ListFields'):
            # Repeated (list-like) fields are not messages and have no ListFields(),
            # so iterate over their elements here instead of raising.
            if hasattr(node, '__len__'):
                for el in node:
                    res = recur(el)
                    if res is not None:
                        return res
                return None
            raise Exception(f'UNEXPECTED TYPE, {node}, {type(node)}')

        fields = node.ListFields()
        if node_type in JOIN_TYPES:
            field_names = set(desc.name for desc, _ in fields)
            if not ('left' in field_names and 'right' in field_names):
                raise Exception(f'bad join type: {node}, {type(node)}, {field_names}')
            return {
                'left': recur(node.left),
                'right': recur(node.right),
            }

        # TODO how to handle multiple fields/iterable returning not None
        for _, field in fields:
            res = recur(field)
            if res is not None:
                return res
        return None


    return recur(input)

DEFAULT_TABLE_SIZES = {
    'r1': 1,
    'r2': 10,
}

def duckdb_substrait_plan(table_sizes=DEFAULT_TABLE_SIZES):
    table_names = list(table_sizes.keys())
    con = duckdb.connect("TwoRelCross.duckdb")
    con.install_extension("substrait")
    con.load_extension("substrait")
    # TODO avoid sql string injection
    for table_name in table_names:
        # "create or replace" keeps the script re-runnable against the persistent database file
        con.execute(query=f"create or replace table {table_name} (c1 integer)")
    for table_name, table_size in table_sizes.items():
        con.execute(query=f"insert into {table_name} values ({'),('.join(map(str, range(table_size)))})")
    for table_name in table_names:
        con.execute(query=f"vacuum {table_name}")
        con.execute(query=f"vacuum analyze {table_name}")
        con.execute(query=f"analyze {table_name}")
    substrait_proto_bytes = con.get_substrait(query=f"select * from {','.join(table_names)}").fetchone()[0]
    p = Plan()
    p.ParseFromString(substrait_proto_bytes)
    return p

def main():
    substrait_plan = duckdb_substrait_plan()
    join_tree = plan_to_join_tree(substrait_plan)
    print(join_tree)

if __name__ == '__main__':
    main()
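
Regarding question 1 and the assertion in plan_to_join_tree: rather than assuming plan.relations has length 1, a more defensive variant could select the entries whose `rel_type` oneof is `root` and require exactly one of them. This is a sketch under that assumption, not a statement about what DuckDB guarantees; `root_input` is a hypothetical helper, and the `SimpleNamespace` objects are stand-ins so the sketch runs without substrait installed.

```python
from types import SimpleNamespace


def root_input(plan):
    """Return the input Rel of the plan's single root relation."""
    roots = [
        plan_rel.root
        for plan_rel in plan.relations
        if plan_rel.WhichOneof('rel_type') == 'root'
    ]
    if len(roots) != 1:
        raise ValueError(f'expected exactly one root relation, found {len(roots)}')
    return roots[0].input


# Stand-in plan: one non-root relation and one root relation (illustration only).
non_root = SimpleNamespace(WhichOneof=lambda group: 'rel')
root = SimpleNamespace(
    WhichOneof=lambda group: 'root',
    root=SimpleNamespace(input='the-input-rel'),
)
fake_plan = SimpleNamespace(relations=[non_root, root])
print(root_input(fake_plan))  # the-input-rel
```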