diff --git a/client/tailsql/client.go b/client/tailsql/client.go index eb3af6e..8a34fc0 100644 --- a/client/tailsql/client.go +++ b/client/tailsql/client.go @@ -80,6 +80,22 @@ func QueryJSON[T any](ctx context.Context, c Client, dataSrc, sql string) ([]T, } } +// JSONString is a wrapper type that decodes JSON text encoded as a string +// value to be decoded into plain JSON text. +type JSONString []byte + +// UnmarshalText implements the encoding.TextUnmarshaler interface for JSON +// text encoded inside a JSON string value. +func (js *JSONString) UnmarshalText(data []byte) error { + return json.Unmarshal(data, (*json.RawMessage)(js)) +} + +// MarshalText encodes a JSON text into a JSON string value. +func (js JSONString) MarshalText() ([]byte, error) { return []byte(js), nil } + +// Unmarshal unmarshals the JSON encoded in js into v. +func (js JSONString) Unmarshal(v any) error { return json.Unmarshal([]byte(js), v) } + // Rows is the result of a successful Query call. type Rows struct { Columns []string // column names diff --git a/client/tailsql/client_test.go b/client/tailsql/client_test.go index 277df0f..ed89ed7 100644 --- a/client/tailsql/client_test.go +++ b/client/tailsql/client_test.go @@ -1,8 +1,10 @@ package tailsql_test import ( + "bytes" "context" "database/sql" + "encoding/json" "net/http/httptest" "path/filepath" "testing" @@ -143,3 +145,66 @@ func TestClient(t *testing.T) { } }) } + +func TestJSONString(t *testing.T) { + var tdata = struct { + S string `json:"foo"` + Z int `json:"bar"` + B bool `json:"baz"` + }{S: "hello", Z: 1337, B: true} + + tjson, err := json.Marshal(tdata) + if err != nil { + t.Fatalf("Encode test data: %v", err) + } + + t.Run("Encode", func(t *testing.T) { + const want = `"{\"foo\":\"hello\",\"bar\":1337,\"baz\":true}"` + enc, err := json.Marshal(tailsql.JSONString(tjson)) + if err != nil { + t.Fatalf("Encode failed: %v", err) + } + if got := string(enc); got != want { + t.Errorf("Encode: got %#q, want %#q", got, want) + } + }) + + // Verify that we can round-trip through a string. + t.Run("RoundTrip", func(t *testing.T) { + enc, err := json.Marshal(struct { + V tailsql.JSONString + }{V: tjson}) + if err != nil { + t.Fatalf("Encode wrapper: %v", err) + } + + var dec struct { + V tailsql.JSONString + } + if err := json.Unmarshal(enc, &dec); err != nil { + t.Fatalf("Decode wrapper: %v", err) + } + + if !bytes.Equal(dec.V, tjson) { + t.Fatalf("Decoded string: got %#q, want %#q", dec.V, tjson) + } + }) + + t.Run("Unmarshal", func(t *testing.T) { + const input = `" [2 , 3 , 5] "` + + var js tailsql.JSONString + if err := json.Unmarshal([]byte(input), &js); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + var zs []int + if err := js.Unmarshal(&zs); err != nil { + t.Fatalf("Unmarshal message: %v", err) + } + + if diff := cmp.Diff(zs, []int{2, 3, 5}); diff != "" { + t.Fatalf("Result (-got, +want):\n%s", diff) + } + }) +}