From 713b6cda2d6d97c9c1a72f9abd001c155ebc0582 Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Fri, 27 Oct 2023 15:27:34 -0700 Subject: [PATCH] client/tailsql: add a convenience type for string-wrapped JSON text --- client/tailsql/client.go | 13 ++++++++++ client/tailsql/client_test.go | 47 +++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/client/tailsql/client.go b/client/tailsql/client.go index eb3af6e..dbd4208 100644 --- a/client/tailsql/client.go +++ b/client/tailsql/client.go @@ -80,6 +80,19 @@ func QueryJSON[T any](ctx context.Context, c Client, dataSrc, sql string) ([]T, } } +// JSONString is a wrapper type that decodes JSON text encoded as a string +// value to be decoded into plain JSON text. +type JSONString []byte + +// UnmarshalText implements the encoding.TextUnmarshaler interface for JSON +// text encoded inside a JSON string value. +func (js *JSONString) UnmarshalText(data []byte) error { + return json.Unmarshal(data, (*json.RawMessage)(js)) +} + +// MarshalText encodes a JSON text into a JSON string value. +func (js JSONString) MarshalText() ([]byte, error) { return []byte(js), nil } + // Rows is the result of a successful Query call. type Rows struct { Columns []string // column names diff --git a/client/tailsql/client_test.go b/client/tailsql/client_test.go index 277df0f..a272f96 100644 --- a/client/tailsql/client_test.go +++ b/client/tailsql/client_test.go @@ -1,8 +1,10 @@ package tailsql_test import ( + "bytes" "context" "database/sql" + "encoding/json" "net/http/httptest" "path/filepath" "testing" @@ -143,3 +145,48 @@ func TestClient(t *testing.T) { } }) } + +func TestJSONString(t *testing.T) { + var tdata = struct { + S string `json:"foo"` + Z int `json:"bar"` + B bool `json:"baz"` + }{S: "hello", Z: 1337, B: true} + + tjson, err := json.Marshal(tdata) + if err != nil { + t.Fatalf("Encode test data: %v", err) + } + + t.Run("Encode", func(t *testing.T) { + const want = `"{\"foo\":\"hello\",\"bar\":1337,\"baz\":true}"` + enc, err := json.Marshal(tailsql.JSONString(tjson)) + if err != nil { + t.Fatalf("Encode failed: %v", err) + } + if got := string(enc); got != want { + t.Errorf("Encode: got %#q, want %#q", got, want) + } + }) + + // Verify that we can round-trip through a string. + t.Run("RoundTrip", func(t *testing.T) { + enc, err := json.Marshal(struct { + V tailsql.JSONString + }{V: tjson}) + if err != nil { + t.Fatalf("Encode wrapper: %v", err) + } + + var dec struct { + V tailsql.JSONString + } + if err := json.Unmarshal(enc, &dec); err != nil { + t.Fatalf("Decode wrapper: %v", err) + } + + if !bytes.Equal(dec.V, tjson) { + t.Fatalf("Decoded string: got %#q, want %#q", dec.V, tjson) + } + }) +}