Skip to content

Commit

Permalink
Fix type conversion handling in convertToType function (#11)
Browse files Browse the repository at this point in the history
* Fix type conversion handling in convertToType function

* remove comment

* update tests, cover 100%
  • Loading branch information
LukaGiorgadze authored Dec 28, 2023
1 parent 882c001 commit 151a4a4
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 44 deletions.
32 changes: 18 additions & 14 deletions gonull.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,25 @@ func zeroValue[T any]() T {
// convertToType is a helper function that attempts to convert the given value to type T.
// This function is used by Scan to properly handle value conversion, ensuring that Nullable values are always of the correct type.
func convertToType[T any](value interface{}) (T, error) {
switch v := value.(type) {
case T:
return v, nil
case int64:
// This case handles the situation when the input value is of type int64.
// It attempts to convert the int64 value to the target numeric type T if possible.
// If the conversion is successful, it returns the converted value of type T and a nil error.
// If the conversion is not possible, the function will continue to the next case (return an error).
switch t := reflect.Zero(reflect.TypeOf((*T)(nil)).Elem()).Interface().(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
if reflect.TypeOf(t).ConvertibleTo(reflect.TypeOf((*T)(nil)).Elem()) {
return reflect.ValueOf(value).Convert(reflect.TypeOf((*T)(nil)).Elem()).Interface().(T), nil
}
var zero T
if value == nil {
return zero, nil
}

if reflect.TypeOf(value) == reflect.TypeOf(zero) {
return value.(T), nil
}

// Check if the value is a numeric type and if T is also a numeric type.
valueType := reflect.TypeOf(value)
targetType := reflect.TypeOf(zero)
if valueType.Kind() >= reflect.Int && valueType.Kind() <= reflect.Float64 &&
targetType.Kind() >= reflect.Int && targetType.Kind() <= reflect.Float64 {
if valueType.ConvertibleTo(targetType) {
convertedValue := reflect.ValueOf(value).Convert(targetType)
return convertedValue.Interface().(T), nil
}
}
var zero T

return zero, ErrUnsupportedConversion
}
161 changes: 131 additions & 30 deletions gonull_test.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
package gonull_test
package gonull

import (
"database/sql/driver"
"encoding/json"
"testing"

"github.com/LukaGiorgadze/gonull"
"github.com/stretchr/testify/assert"
)

func TestNewNullable(t *testing.T) {
value := "test"
n := gonull.NewNullable(value)
n := NewNullable(value)

assert.True(t, n.Valid)
assert.Equal(t, value, n.Val)
Expand Down Expand Up @@ -52,7 +51,7 @@ func TestNullableScan(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var n gonull.Nullable[string]
var n Nullable[string]
err := n.Scan(tt.value)

if tt.wantErr {
Expand All @@ -72,19 +71,19 @@ func TestNullableScan(t *testing.T) {
func TestNullableValue(t *testing.T) {
tests := []struct {
name string
nullable gonull.Nullable[string]
nullable Nullable[string]
wantValue driver.Value
wantErr error
}{
{
name: "valid value",
nullable: gonull.NewNullable("test"),
nullable: NewNullable("test"),
wantValue: "test",
wantErr: nil,
},
{
name: "unset value",
nullable: gonull.Nullable[string]{Valid: false},
nullable: Nullable[string]{Valid: false},
wantValue: nil,
wantErr: nil,
},
Expand Down Expand Up @@ -128,7 +127,7 @@ func TestNullableUnmarshalJSON(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var nullable gonull.Nullable[int]
var nullable Nullable[int]

err := nullable.UnmarshalJSON(tc.jsonData)

Expand All @@ -143,7 +142,7 @@ func TestNullableUnmarshalJSON(t *testing.T) {
func TestNullableUnmarshalJSON_Error(t *testing.T) {
jsonData := []byte(`"invalid_number"`)

var nullable gonull.Nullable[int]
var nullable Nullable[int]
err := nullable.UnmarshalJSON(jsonData)

assert.Error(t, err)
Expand All @@ -153,19 +152,19 @@ func TestNullableUnmarshalJSON_Error(t *testing.T) {
func TestNullableMarshalJSON(t *testing.T) {
type testCase struct {
name string
nullable gonull.Nullable[int]
nullable Nullable[int]
expectedJSON []byte
}

testCases := []testCase{
{
name: "ValuePresent",
nullable: gonull.NewNullable[int](123),
nullable: NewNullable[int](123),
expectedJSON: []byte(`123`),
},
{
name: "ValueNull",
nullable: gonull.Nullable[int]{Val: 0, Valid: false},
nullable: Nullable[int]{Val: 0, Valid: false},
expectedJSON: []byte(`null`),
},
}
Expand All @@ -182,7 +181,7 @@ func TestNullableMarshalJSON(t *testing.T) {
func TestNullableScan_UnconvertibleFromInt64(t *testing.T) {
value := int64(123456789012345)

var n gonull.Nullable[string]
var n Nullable[string]
err := n.Scan(value)

assert.Error(t, err)
Expand All @@ -205,39 +204,39 @@ func TestConvertToTypeFromInt64(t *testing.T) {
{name: "Convert int64 to uint16", targetType: "uint16", value: int64(7), expectedError: nil},
{name: "Convert int64 to uint32", targetType: "uint32", value: int64(8), expectedError: nil},
// Add more tests as necessary
{name: "Convert int64 to string (expected to fail)", targetType: "string", value: int64(9), expectedError: gonull.ErrUnsupportedConversion},
{name: "Convert int64 to string (expected to fail)", targetType: "string", value: int64(9), expectedError: ErrUnsupportedConversion},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
switch tt.targetType {
case "int":
n := gonull.Nullable[int]{}
n := Nullable[int]{}
err = n.Scan(tt.value)
case "int8":
n := gonull.Nullable[int8]{}
n := Nullable[int8]{}
err = n.Scan(tt.value)
case "int16":
n := gonull.Nullable[int16]{}
n := Nullable[int16]{}
err = n.Scan(tt.value)
case "int32":
n := gonull.Nullable[int32]{}
n := Nullable[int32]{}
err = n.Scan(tt.value)
case "uint":
n := gonull.Nullable[uint]{}
n := Nullable[uint]{}
err = n.Scan(tt.value)
case "uint8":
n := gonull.Nullable[uint8]{}
n := Nullable[uint8]{}
err = n.Scan(tt.value)
case "uint16":
n := gonull.Nullable[uint16]{}
n := Nullable[uint16]{}
err = n.Scan(tt.value)
case "uint32":
n := gonull.Nullable[uint32]{}
n := Nullable[uint32]{}
err = n.Scan(tt.value)
case "string":
n := gonull.Nullable[string]{}
n := Nullable[string]{}
err = n.Scan(tt.value)
default:
t.Fatalf("Unsupported type: %s", tt.targetType)
Expand All @@ -263,7 +262,7 @@ func TestNullableScanWithCustomEnum(t *testing.T) {

type TestModel struct {
ID int
Field gonull.Nullable[TestEnum]
Field Nullable[TestEnum]
}

// Simulate the scenario where the SQL driver returns an int64
Expand All @@ -273,19 +272,121 @@ func TestNullableScanWithCustomEnum(t *testing.T) {
// The converted value 0 (as float32) matches TestEnumA, which is also 0 when converted to float32.
sqlReturnedValue := int64(0)

model := TestModel{ID: 1, Field: gonull.NewNullable(TestEnumA)}
model := TestModel{ID: 1, Field: NewNullable(TestEnumA)}

err := model.Field.Scan(sqlReturnedValue)
if err != nil {
assert.Error(t, err, "Scan failed with unsupported type conversion")
} else {
assert.Equal(t, TestEnumA, model.Field.Val, "Scanned value does not match expected enum value")
assert.NoError(t, err, "Scan failed with unsupported type conversion")
assert.Equal(t, TestEnumA, model.Field.Val, "Scanned value does not match expected enum value")

}

func TestConvertToTypeWithNilValue(t *testing.T) {
tests := []struct {
name string
expected interface{}
}{
{
name: "Nil to int",
expected: int(0),
},
{
name: "Nil to int8",
expected: int8(0),
},
{
name: "Nil to int16",
expected: int16(0),
},
{
name: "Nil to int32",
expected: int32(0),
},
{
name: "Nil to int64",
expected: int64(0),
},
{
name: "Nil to uint",
expected: uint(0),
},
{
name: "Nil to uint8 (byte)",
expected: uint8(0),
},
{
name: "Nil to uint16",
expected: uint16(0),
},
{
name: "Nil to uint32",
expected: uint32(0),
},
{
name: "Nil to uint64",
expected: uint64(0),
},
{
name: "Nil to float32",
expected: float32(0),
},
{
name: "Nil to float64",
expected: float64(0),
},
{
name: "Nil to bool",
expected: bool(false),
},
{
name: "Nil to string",
expected: "",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var result interface{}
var err error

switch tc.expected.(type) {
case int:
result, err = convertToType[int](nil)
case int8:
result, err = convertToType[int8](nil)
case int16:
result, err = convertToType[int16](nil)
case int32:
result, err = convertToType[int32](nil)
case int64:
result, err = convertToType[int64](nil)
case uint:
result, err = convertToType[uint](nil)
case uint8:
result, err = convertToType[uint8](nil)
case uint16:
result, err = convertToType[uint16](nil)
case uint32:
result, err = convertToType[uint32](nil)
case uint64:
result, err = convertToType[uint64](nil)
case float32:
result, err = convertToType[float32](nil)
case float64:
result, err = convertToType[float64](nil)
case bool:
result, err = convertToType[bool](nil)
case string:
result, err = convertToType[string](nil)
}

assert.NoError(t, err)
assert.Equal(t, tc.expected, result)
})
}
}

type testStruct struct {
Foo gonull.Nullable[*string] `json:"foo"`
Foo Nullable[*string] `json:"foo"`
}

func TestPresent(t *testing.T) {
Expand Down

0 comments on commit 151a4a4

Please sign in to comment.