pqarrow/arrowutils: Add EnsureSameSchema for records (#806)
metalmatze authored Apr 25, 2024
1 parent 088e3b2 commit 8760e41
Showing 2 changed files with 186 additions and 0 deletions.
95 changes: 95 additions & 0 deletions pqarrow/arrowutils/schema.go
@@ -0,0 +1,95 @@
package arrowutils

import (
	"fmt"
	"sort"

	"github.com/apache/arrow/go/v16/arrow"
	"github.com/apache/arrow/go/v16/arrow/array"
)

// EnsureSameSchema ensures that all the records have the same schema. Where
// the schemas are not equal, virtual null columns are inserted into the
// records that are missing a column. Once the execution engine has static
// schemas, steps like this should be unnecessary.
func EnsureSameSchema(records []arrow.Record) ([]arrow.Record, error) {
	if len(records) < 2 {
		return records, nil
	}

	lastSchema := records[0].Schema()
	needSchemaRecalculation := false
	for i := range records {
		if !records[i].Schema().Equal(lastSchema) {
			needSchemaRecalculation = true
			break
		}
	}
	if !needSchemaRecalculation {
		return records, nil
	}

	columns := make(map[string]arrow.Field)
	for _, r := range records {
		for j := 0; j < r.Schema().NumFields(); j++ {
			field := r.Schema().Field(j)
			if _, ok := columns[field.Name]; !ok {
				columns[field.Name] = field
			}
		}
	}

	columnNames := make([]string, 0, len(columns))
	for name := range columns {
		columnNames = append(columnNames, name)
	}
	sort.Strings(columnNames)

	mergedFields := make([]arrow.Field, 0, len(columnNames))
	for _, name := range columnNames {
		mergedFields = append(mergedFields, columns[name])
	}
	mergedSchema := arrow.NewSchema(mergedFields, nil)

	mergedRecords := make([]arrow.Record, len(records))
	var replacedRecords []arrow.Record

	for i := range records {
		recordSchema := records[i].Schema()
		if mergedSchema.Equal(recordSchema) {
			mergedRecords[i] = records[i]
			continue
		}

		mergedColumns := make([]arrow.Array, 0, len(mergedFields))
		recordNumRows := records[i].NumRows()
		for j := 0; j < mergedSchema.NumFields(); j++ {
			field := mergedSchema.Field(j)
			if otherFields := recordSchema.FieldIndices(field.Name); otherFields != nil {
				if len(otherFields) > 1 {
					fieldsFound, _ := recordSchema.FieldsByName(field.Name)
					return nil, fmt.Errorf(
						"found multiple fields %v for name %s",
						fieldsFound,
						field.Name,
					)
				}
				mergedColumns = append(mergedColumns, records[i].Column(otherFields[0]))
			} else {
				// Note that this VirtualNullArray will be read from, but the
				// merged output will be a physical null array, so there is no
				// virtual->physical conversion necessary before we return data.
				mergedColumns = append(mergedColumns, MakeVirtualNullArray(field.Type, int(recordNumRows)))
			}
		}

		replacedRecords = append(replacedRecords, records[i])
		mergedRecords[i] = array.NewRecord(mergedSchema, mergedColumns, recordNumRows)
	}

	for _, r := range replacedRecords {
		r.Release()
	}

	return mergedRecords, nil
}
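
For readers outside the commit, here is a minimal, hypothetical usage sketch of EnsureSameSchema (not part of this change). It assumes the standard Arrow Go v16 record builders; the column names "a" and "b", the row values, and the schema printout at the end are illustrative only. It also relies on the release contract visible above: EnsureSameSchema releases any input record it had to rebuild, so the caller releases only the returned records.

package main

import (
	"fmt"

	"github.com/apache/arrow/go/v16/arrow"
	"github.com/apache/arrow/go/v16/arrow/array"
	"github.com/apache/arrow/go/v16/arrow/memory"

	"github.com/polarsignals/frostdb/pqarrow/arrowutils"
)

func main() {
	mem := memory.NewGoAllocator()

	// Record with columns "a" and "b".
	schemaAB := arrow.NewSchema([]arrow.Field{
		{Name: "a", Type: arrow.PrimitiveTypes.Int64},
		{Name: "b", Type: arrow.PrimitiveTypes.Int64},
	}, nil)
	b1 := array.NewRecordBuilder(mem, schemaAB)
	b1.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2}, nil)
	b1.Field(1).(*array.Int64Builder).AppendValues([]int64{3, 4}, nil)
	r1 := b1.NewRecord()
	b1.Release()

	// Record with column "a" only.
	schemaA := arrow.NewSchema([]arrow.Field{
		{Name: "a", Type: arrow.PrimitiveTypes.Int64},
	}, nil)
	b2 := array.NewRecordBuilder(mem, schemaA)
	b2.Field(0).(*array.Int64Builder).AppendValues([]int64{5}, nil)
	r2 := b2.NewRecord()
	b2.Release()

	// Align both records on the union of their columns. The second record
	// is rebuilt with a virtual null column "b" and its original is
	// released by EnsureSameSchema, so only the returned records are
	// released here.
	aligned, err := arrowutils.EnsureSameSchema([]arrow.Record{r1, r2})
	if err != nil {
		panic(err)
	}
	defer func() {
		for _, r := range aligned {
			r.Release()
		}
	}()

	for _, r := range aligned {
		fmt.Println(r.Schema())
	}
}

In this sketch both returned records report the same two-field schema, because the second record gains a virtual null column "b".
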
91 changes: 91 additions & 0 deletions pqarrow/arrowutils/schema_test.go
@@ -0,0 +1,91 @@
package arrowutils_test

import (
	"testing"

	"github.com/apache/arrow/go/v16/arrow"
	"github.com/apache/arrow/go/v16/arrow/memory"
	"github.com/stretchr/testify/require"

	"github.com/polarsignals/frostdb/internal/records"
	"github.com/polarsignals/frostdb/pqarrow/arrowutils"
)

func TestEnsureSameSchema(t *testing.T) {
	type struct1 struct {
		Field1 int64 `frostdb:",asc(0)"`
		Field2 int64 `frostdb:",asc(1)"`
	}
	type struct2 struct {
		Field1 int64 `frostdb:",asc(0)"`
		Field3 int64 `frostdb:",asc(1)"`
	}
	type struct3 struct {
		Field1 int64 `frostdb:",asc(0)"`
		Field2 int64 `frostdb:",asc(1)"`
		Field3 int64 `frostdb:",asc(1)"`
	}

	mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
	defer mem.AssertSize(t, 0)

	build1 := records.NewBuild[struct1](mem)
	defer build1.Release()
	err := build1.Append([]struct1{
		{Field1: 1, Field2: 2},
		{Field1: 1, Field2: 3},
	}...)
	require.NoError(t, err)

	build2 := records.NewBuild[struct2](mem)
	defer build2.Release()
	err = build2.Append([]struct2{
		{Field1: 1, Field3: 2},
		{Field1: 1, Field3: 3},
	}...)
	require.NoError(t, err)

	build3 := records.NewBuild[struct3](mem)
	defer build3.Release()
	err = build3.Append([]struct3{
		{Field1: 1, Field2: 1, Field3: 1},
		{Field1: 2, Field2: 2, Field3: 2},
	}...)
	require.NoError(t, err)

	record1 := build1.NewRecord()
	record2 := build2.NewRecord()
	record3 := build3.NewRecord()

	recs := []arrow.Record{record1, record2, record3}
	defer func() {
		for _, r := range recs {
			r.Release()
		}
	}()

	recs, err = arrowutils.EnsureSameSchema(recs)
	require.NoError(t, err)

	expected := []struct3{
		// record1
		{Field1: 1, Field2: 2, Field3: 0},
		{Field1: 1, Field2: 3, Field3: 0},
		// record2
		{Field1: 1, Field2: 0, Field3: 2},
		{Field1: 1, Field2: 0, Field3: 3},
		// record3
		{Field1: 1, Field2: 1, Field3: 1},
		{Field1: 2, Field2: 2, Field3: 2},
	}

	reader := records.NewReader[struct3](recs...)
	rows := reader.NumRows()
	require.Equal(t, int64(len(expected)), rows)

	actual := make([]struct3, rows)
	for i := 0; i < int(rows); i++ {
		actual[i] = reader.Value(i)
	}
	require.Equal(t, expected, actual)
}
