Skip to content

Commit

Permalink
Fix code to use Expressions fields in VirtualTable instead of Literal
Browse files Browse the repository at this point in the history
  • Loading branch information
anshuldata committed Nov 5, 2024
1 parent 5ba7f5e commit 55ff889
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 34 deletions.
37 changes: 37 additions & 0 deletions expr/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ import (
//
// NewScalarFunc(reg, id, nil, MustExpr(NewRootFieldRef(...)),
// MustExpr(NewScalarFunc(...)))

type (
// VirtualTableExpressionValue is a slice of other expression where each
// element in the slice is a different field in the struct
VirtualTableExpressionValue []Expression
)

func MustExpr(e Expression, err error) Expression {
if err != nil {
panic(err)
Expand Down Expand Up @@ -1603,3 +1610,33 @@ func (ex *Extended) ToProto() *proto.ExtendedExpression {
ReferredExpr: refs,
}
}

func VirtualTableExpressionFromProto(s *proto.Expression_Nested_Struct, reg ExtensionRegistry) (VirtualTableExpressionValue, error) {
fields := make(VirtualTableExpressionValue, len(s.Fields))
for i, f := range s.Fields {
val, err := ExprFromProto(f, nil, reg)
if err != nil {
return nil, err
}

Check warning on line 1620 in expr/expression.go

View check run for this annotation

Codecov / codecov/patch

expr/expression.go#L1619-L1620

Added lines #L1619 - L1620 were not covered by tests
fields[i] = val
}
return fields, nil
}

func VirtualTableExprFromLiteralProto(s *proto.Expression_Literal_Struct) VirtualTableExpressionValue {
fields := make(VirtualTableExpressionValue, len(s.Fields))
for i, f := range s.Fields {
fields[i] = LiteralFromProto(f)
}
return fields

Check warning on line 1631 in expr/expression.go

View check run for this annotation

Codecov / codecov/patch

expr/expression.go#L1626-L1631

Added lines #L1626 - L1631 were not covered by tests
}

func (s VirtualTableExpressionValue) ToProto() *proto.Expression_Nested_Struct {
fields := make([]*proto.Expression, len(s))
for i, f := range s {
fields[i] = f.ToProto()
}
return &proto.Expression_Nested_Struct{
Fields: fields,
}

Check warning on line 1641 in expr/expression.go

View check run for this annotation

Codecov / codecov/patch

expr/expression.go#L1634-L1641

Added lines #L1634 - L1641 were not covered by tests
}
40 changes: 40 additions & 0 deletions expr/expressions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,43 @@ func TestRoundTripExtendedExpression(t *testing.T) {
assert.Truef(t, pb.Equal(&ex, out), "expected: %s\ngot: %s", &ex, out)
}
}

func TestVirtualTableExpressionFromProto(t *testing.T) {
// define extensions with no plan for now
const planExt = `{
"extensionUris": [
{
"extensionUriAnchor": 1,
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
}
],
"extensions": [
{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 2,
"name": "add:i32_i32"
}
}
],
"relations": []
}`

var plan proto.Plan
if err := protojson.Unmarshal([]byte(planExt), &plan); err != nil {
panic(err)
}

// get the extension set
extSet := ext.GetExtensionSet(&plan)
literal1 := expr.NewPrimitiveLiteral(int32(1), false)
expr1 := literal1.ToProto()

reg := expr.NewExtensionRegistry(extSet, &ext.DefaultCollection)
rows := &proto.Expression_Nested_Struct{Fields: []*proto.Expression{
expr1,
}}
exprRows, err := expr.VirtualTableExpressionFromProto(rows, reg)
require.NoError(t, err)
require.Len(t, exprRows, 1)
}
19 changes: 0 additions & 19 deletions expr/literals.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,6 @@ type (
Null types.Type
)

func StructLiteralFromProto(s *proto.Expression_Literal_Struct) StructLiteralValue {
fields := make(StructLiteralValue, len(s.Fields))
for i, f := range s.Fields {
fields[i] = LiteralFromProto(f)
}
return fields
}

func (s StructLiteralValue) ToProto() *proto.Expression_Literal_Struct {
fields := make([]*proto.Expression_Literal, len(s))
for i, f := range s {
fields[i] = f.ToProtoLiteral()
}

return &proto.Expression_Literal_Struct{
Fields: fields,
}
}

// Literal represents a specific literal of some type which could also
// be a typed null or a nested type like a struct/map/list.
//
Expand Down
13 changes: 13 additions & 0 deletions plan/builders.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,19 @@ func (b *builder) NamedScan(tableName []string, schema types.NamedStruct) *Named
}

func (b *builder) VirtualTableRemap(fieldNames []string, remap []int32, values ...expr.StructLiteralValue) (*VirtualTableReadRel, error) {
// convert Literal to Expression
exprs := make([]expr.VirtualTableExpressionValue, 0)
for _, row := range values {
rowExpr := make(expr.VirtualTableExpressionValue, 0)
for _, col := range row {
rowExpr = append(rowExpr, col)
}
exprs = append(exprs, rowExpr)
}
return b.VirtualTableFromExprRemap(fieldNames, remap, exprs...)
}

func (b *builder) VirtualTableFromExprRemap(fieldNames []string, remap []int32, values ...expr.VirtualTableExpressionValue) (*VirtualTableReadRel, error) {
if len(values) == 0 {
return nil, fmt.Errorf("%w: must provide at least one set of values for virtual table", substraitgo.ErrInvalidRel)
}
Expand Down
16 changes: 13 additions & 3 deletions plan/plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,19 @@ func RelFromProto(rel *proto.Rel, reg expr.ExtensionRegistry) (Rel, error) {
advExtension: readType.NamedTable.AdvancedExtension,
}
case *proto.ReadRel_VirtualTable_:
values := make([]expr.StructLiteralValue, len(readType.VirtualTable.Values))
for i, v := range readType.VirtualTable.Values {
values[i] = expr.StructLiteralFromProto(v)
if len(readType.VirtualTable.Values) > 0 && len(readType.VirtualTable.Expressions) > 0 {
return nil, fmt.Errorf("VirtualTable Value can't have both liternal and expression")
}

Check warning on line 321 in plan/plan.go

View check run for this annotation

Codecov / codecov/patch

plan/plan.go#L320-L321

Added lines #L320 - L321 were not covered by tests
var values []expr.VirtualTableExpressionValue
for _, v := range readType.VirtualTable.Values {
values = append(values, expr.VirtualTableExprFromLiteralProto(v))
}

Check warning on line 325 in plan/plan.go

View check run for this annotation

Codecov / codecov/patch

plan/plan.go#L324-L325

Added lines #L324 - L325 were not covered by tests
for _, v := range readType.VirtualTable.Expressions {
row, err := expr.VirtualTableExpressionFromProto(v, reg)
if err != nil {
return nil, err
}

Check warning on line 330 in plan/plan.go

View check run for this annotation

Codecov / codecov/patch

plan/plan.go#L329-L330

Added lines #L329 - L330 were not covered by tests
values = append(values, row)
}

out = &VirtualTableReadRel{
Expand Down
12 changes: 6 additions & 6 deletions plan/plan_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1348,17 +1348,17 @@ func TestSetRelations(t *testing.T) {
}
},
"virtualTable": {
"values": [
"expressions": [
{
"fields": [
{ "string": "foo", "nullable": false },
{ "fp32": 1.5, "nullable": false }
{"literal": { "string": "foo", "nullable": false }},
{"literal": { "fp32": 1.5, "nullable": false }}
]
},
{
"fields": [
{ "string": "bar", "nullable": false },
{ "fp32": 3.5, "nullable": false }
{"literal": { "string": "bar", "nullable": false }},
{"literal": { "fp32": 3.5, "nullable": false }}
]
}
]
Expand Down Expand Up @@ -1429,7 +1429,7 @@ func TestEmptyVirtualTable(t *testing.T) {
}
},
"virtualTable": {
"values": [
"expressions": [
{},
{},
{},
Expand Down
8 changes: 4 additions & 4 deletions plan/relations.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,23 +208,23 @@ func (n *NamedTableReadRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, _
type VirtualTableReadRel struct {
baseReadRel

values []expr.StructLiteralValue
values []expr.VirtualTableExpressionValue
}

func (v *VirtualTableReadRel) Values() []expr.StructLiteralValue {
func (v *VirtualTableReadRel) Values() []expr.VirtualTableExpressionValue {

Check warning on line 214 in plan/relations.go

View check run for this annotation

Codecov / codecov/patch

plan/relations.go#L214

Added line #L214 was not covered by tests
return v.values
}

func (v *VirtualTableReadRel) ToProto() *proto.Rel {
readRel := v.toReadRelProto()
values := make([]*proto.Expression_Literal_Struct, len(v.values))
values := make([]*proto.Expression_Nested_Struct, len(v.values))
for i, v := range v.values {
values[i] = v.ToProto()
}

readRel.ReadType = &proto.ReadRel_VirtualTable_{
VirtualTable: &proto.ReadRel_VirtualTable{
Values: values,
Expressions: values,
},
}
return &proto.Rel{
Expand Down
4 changes: 2 additions & 2 deletions plan/relations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func noOpRewrite(e expr.Expression) (expr.Expression, error) {
}

func createVirtualTableReadRel(value int64) *VirtualTableReadRel {
return &VirtualTableReadRel{values: []expr.StructLiteralValue{[]expr.Literal{&expr.PrimitiveLiteral[int64]{Value: value}}}}
return &VirtualTableReadRel{values: []expr.VirtualTableExpressionValue{[]expr.Expression{&expr.PrimitiveLiteral[int64]{Value: value}}}}
}

func createPrimitiveFloat(value float64) expr.Expression {
Expand All @@ -40,7 +40,7 @@ func TestRelations_Copy(t *testing.T) {
projectRel := &ProjectRel{input: createVirtualTableReadRel(1), exprs: []expr.Expression{createPrimitiveFloat(1.0), createPrimitiveFloat(2.0)}}
setRel := &SetRel{inputs: []Rel{createVirtualTableReadRel(1), createVirtualTableReadRel(2), createVirtualTableReadRel(3)}, op: SetOpUnionAll}
sortRel := &SortRel{input: createVirtualTableReadRel(1), sorts: []expr.SortField{{Expr: createPrimitiveFloat(1.0), Kind: types.SortAscNullsFirst}}}
virtualTableReadRel := &VirtualTableReadRel{values: []expr.StructLiteralValue{[]expr.Literal{&expr.PrimitiveLiteral[int64]{Value: 1}}}}
virtualTableReadRel := &VirtualTableReadRel{values: []expr.VirtualTableExpressionValue{[]expr.Expression{&expr.PrimitiveLiteral[int64]{Value: 1}}}}
namedTableWriteRel := &NamedTableWriteRel{input: namedTableReadRel}

type relationTestCase struct {
Expand Down

0 comments on commit 55ff889

Please sign in to comment.