pqarrow/arrowutils: Add EnsureSameSchema for records (#806)
Commit 8760e41 (1 parent: 088e3b2)
Showing 2 changed files with 186 additions and 0 deletions.
@@ -0,0 +1,95 @@
package arrowutils

import (
	"fmt"
	"sort"

	"github.com/apache/arrow/go/v16/arrow"
	"github.com/apache/arrow/go/v16/arrow/array"
)

// EnsureSameSchema ensures that all the records have the same schema. In cases
// where the schema is not equal, virtual null columns are inserted in the
// records with the missing column. When we have static schemas in the execution
// engine, steps like these should be unnecessary.
func EnsureSameSchema(records []arrow.Record) ([]arrow.Record, error) {
	if len(records) < 2 {
		return records, nil
	}

	lastSchema := records[0].Schema()
	needSchemaRecalculation := false
	for i := range records {
		if !records[i].Schema().Equal(lastSchema) {
			needSchemaRecalculation = true
			break
		}
	}
	if !needSchemaRecalculation {
		return records, nil
	}

	columns := make(map[string]arrow.Field)
	for _, r := range records {
		for j := 0; j < r.Schema().NumFields(); j++ {
			field := r.Schema().Field(j)
			if _, ok := columns[field.Name]; !ok {
				columns[field.Name] = field
			}
		}
	}

	columnNames := make([]string, 0, len(columns))
	for name := range columns {
		columnNames = append(columnNames, name)
	}
	sort.Strings(columnNames)

	mergedFields := make([]arrow.Field, 0, len(columnNames))
	for _, name := range columnNames {
		mergedFields = append(mergedFields, columns[name])
	}
	mergedSchema := arrow.NewSchema(mergedFields, nil)

	mergedRecords := make([]arrow.Record, len(records))
	var replacedRecords []arrow.Record

	for i := range records {
		recordSchema := records[i].Schema()
		if mergedSchema.Equal(recordSchema) {
			mergedRecords[i] = records[i]
			continue
		}

		mergedColumns := make([]arrow.Array, 0, len(mergedFields))
		recordNumRows := records[i].NumRows()
		for j := 0; j < mergedSchema.NumFields(); j++ {
			field := mergedSchema.Field(j)
			if otherFields := recordSchema.FieldIndices(field.Name); otherFields != nil {
				if len(otherFields) > 1 {
					fieldsFound, _ := recordSchema.FieldsByName(field.Name)
					return nil, fmt.Errorf(
						"found multiple fields %v for name %s",
						fieldsFound,
						field.Name,
					)
				}
				mergedColumns = append(mergedColumns, records[i].Column(otherFields[0]))
			} else {
				// Note that this VirtualNullArray will be read from, but the
				// merged output will be a physical null array, so there is no
				// virtual->physical conversion necessary before we return data.
				mergedColumns = append(mergedColumns, MakeVirtualNullArray(field.Type, int(recordNumRows)))
			}
		}

		replacedRecords = append(replacedRecords, records[i])
		mergedRecords[i] = array.NewRecord(mergedSchema, mergedColumns, recordNumRows)
	}

	for _, r := range replacedRecords {
		r.Release()
	}

	return mergedRecords, nil
}
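
For context, here is a minimal, self-contained usage sketch of the new helper. It is not part of the commit: the column names "a" and "b" and the plain Arrow record builders are illustrative assumptions; only EnsureSameSchema itself and the Arrow v16 APIs shown above are taken from this change and its imports.

// Hypothetical usage sketch (not part of the commit): two records with
// disjoint columns are padded to a common, sorted schema.
package main

import (
	"fmt"

	"github.com/apache/arrow/go/v16/arrow"
	"github.com/apache/arrow/go/v16/arrow/array"
	"github.com/apache/arrow/go/v16/arrow/memory"

	"github.com/polarsignals/frostdb/pqarrow/arrowutils"
)

func main() {
	mem := memory.NewGoAllocator()

	// Record with only column "a".
	schemaA := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil)
	ba := array.NewRecordBuilder(mem, schemaA)
	defer ba.Release()
	ba.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2}, nil)
	recA := ba.NewRecord()

	// Record with only column "b".
	schemaB := arrow.NewSchema([]arrow.Field{{Name: "b", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil)
	bb := array.NewRecordBuilder(mem, schemaB)
	defer bb.Release()
	bb.Field(0).(*array.Int64Builder).AppendValues([]int64{3, 4}, nil)
	recB := bb.NewRecord()

	// Both records are rebuilt (and the originals released by the helper)
	// because neither matches the merged schema {a, b}; each missing column
	// becomes a null column of the right length.
	merged, err := arrowutils.EnsureSameSchema([]arrow.Record{recA, recB})
	if err != nil {
		panic(err)
	}
	for _, r := range merged {
		fmt.Println(r.Schema()) // both records now expose columns "a" and "b"
		r.Release()
	}
}

Callers own the returned records; any input record that had to be rebuilt is released by the helper itself, which the release loop above relies on.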
@@ -0,0 +1,91 @@
package arrowutils_test

import (
	"testing"

	"github.com/apache/arrow/go/v16/arrow"
	"github.com/apache/arrow/go/v16/arrow/memory"
	"github.com/stretchr/testify/require"

	"github.com/polarsignals/frostdb/internal/records"
	"github.com/polarsignals/frostdb/pqarrow/arrowutils"
)

func TestEnsureSameSchema(t *testing.T) {
	type struct1 struct {
		Field1 int64 `frostdb:",asc(0)"`
		Field2 int64 `frostdb:",asc(1)"`
	}
	type struct2 struct {
		Field1 int64 `frostdb:",asc(0)"`
		Field3 int64 `frostdb:",asc(1)"`
	}
	type struct3 struct {
		Field1 int64 `frostdb:",asc(0)"`
		Field2 int64 `frostdb:",asc(1)"`
		Field3 int64 `frostdb:",asc(1)"`
	}

	mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
	defer mem.AssertSize(t, 0)

	build1 := records.NewBuild[struct1](mem)
	defer build1.Release()
	err := build1.Append([]struct1{
		{Field1: 1, Field2: 2},
		{Field1: 1, Field2: 3},
	}...)
	require.NoError(t, err)

	build2 := records.NewBuild[struct2](mem)
	defer build2.Release()
	err = build2.Append([]struct2{
		{Field1: 1, Field3: 2},
		{Field1: 1, Field3: 3},
	}...)
	require.NoError(t, err)

	build3 := records.NewBuild[struct3](mem)
	defer build3.Release()
	err = build3.Append([]struct3{
		{Field1: 1, Field2: 1, Field3: 1},
		{Field1: 2, Field2: 2, Field3: 2},
	}...)
	require.NoError(t, err)

	record1 := build1.NewRecord()
	record2 := build2.NewRecord()
	record3 := build3.NewRecord()

	recs := []arrow.Record{record1, record2, record3}
	defer func() {
		for _, r := range recs {
			r.Release()
		}
	}()

	recs, err = arrowutils.EnsureSameSchema(recs)
	require.NoError(t, err)

	expected := []struct3{
		// record1
		{Field1: 1, Field2: 2, Field3: 0},
		{Field1: 1, Field2: 3, Field3: 0},
		// record2
		{Field1: 1, Field2: 0, Field3: 2},
		{Field1: 1, Field2: 0, Field3: 3},
		// record3
		{Field1: 1, Field2: 1, Field3: 1},
		{Field1: 2, Field2: 2, Field3: 2},
	}

	reader := records.NewReader[struct3](recs...)
	rows := reader.NumRows()
	require.Equal(t, int64(len(expected)), rows)

	actual := make([]struct3, rows)
	for i := 0; i < int(rows); i++ {
		actual[i] = reader.Value(i)
	}
	require.Equal(t, expected, actual)
}
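
As the expected rows show, the columns a record never had come back as Go zero values when the padded records are read through records.NewReader: record1's rows carry Field3: 0 and record2's rows carry Field2: 0, since those columns were filled with null arrays by EnsureSameSchema.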