Skip to content

Commit

Permalink
feat(go-bindgen): represent uint8 discriminants as LEB128
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Volosatovs <[email protected]>
  • Loading branch information
rvolosatovs committed May 8, 2024
1 parent 7220e7f commit 9902393
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 85 deletions.
156 changes: 93 additions & 63 deletions crates/wit-bindgen-go/src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,82 @@ impl InterfaceGenerator<'_> {
}
}

fn print_read_discriminant(&mut self, repr: Int, reader: &str) {
match repr {
Int::U8 => {
uwrite!(
self.src,
r#"func(r {wrpc}.ByteReader) (uint8, error) {{
var x uint8
var s uint
for i := 0; i < 2; i++ {{
{slog}.Debug("reading `uint8` byte", "i", i)
b, err := r.ReadByte()
if err != nil {{
if i > 0 && err == {io}.EOF {{
err = {io}.ErrUnexpectedEOF
}}
return x, {fmt}.Errorf("failed to read `uint8` byte: %w", err)
}}
if b < 0x80 {{
if i == 2 && b > 1 {{
return x, {errors}.New("varint overflows a 8-bit integer")
}}
return x | uint8(b)<<s, nil
}}
x |= uint8(b&0x7f) << s
s += 7
}}
return x, {errors}.New("varint overflows a 8-bit integer")
}}({reader})"#,
errors = self.deps.errors(),
fmt = self.deps.fmt(),
io = self.deps.io(),
slog = self.deps.slog(),
wrpc = self.deps.wrpc(),
);
}
Int::U16 => {
self.print_read_ty(&Type::U16, reader);
}
Int::U32 => {
self.print_read_ty(&Type::U32, reader);
}
Int::U64 => {
self.print_read_ty(&Type::U64, reader);
}
}
}

fn print_write_discriminant(&mut self, repr: Int, name: &str, writer: &str) {
match repr {
Int::U8 => {
uwrite!(
self.src,
r#"func(v uint8, w {wrpc}.ByteWriter) error {{
b := make([]byte, 2)
i := {binary}.PutUvarint(b, uint64(v))
{slog}.Debug("writing u8")
_, err := w.Write(b[:i])
return err
}}(uint8({name}), {writer})"#,
binary = self.deps.binary(),
slog = self.deps.slog(),
wrpc = self.deps.wrpc(),
);
}
Int::U16 => {
self.print_write_ty(&Type::U16, &format!("uint16({name})"), writer);
}
Int::U32 => {
self.print_write_ty(&Type::U32, &format!("uint32({name})"), writer);
}
Int::U64 => {
self.print_write_ty(&Type::U64, &format!("uint64({name})"), writer);
}
}
}

fn print_write_tyid(&mut self, id: TypeId, name: &str, writer: &str) {
let ty = &self.resolve.types[id];
if ty.name.is_some() {
Expand Down Expand Up @@ -1078,8 +1154,13 @@ impl InterfaceGenerator<'_> {
let wrpc = self.deps.wrpc();
uwriteln!(
self.src,
r#"stop{i}, err := c.Serve("{instance}", "{name}", func(ctx {context}.Context, buffer []byte, tx {wrpc}.Transmitter, inv {wrpc}.IncomingInvocation) error {{
{slog}.DebugContext(ctx, "subscribing for `{instance}.{name}` parameters")
r#"stop{i}, err := c.Serve("{instance}", "{name}", func(ctx {context}.Context, buffer []byte, tx {wrpc}.Transmitter, inv {wrpc}.IncomingInvocation) error {{"#,
);
if !params.is_empty() {
// TODO: Handle async parameters
uwriteln!(
self.src,
r#"{slog}.DebugContext(ctx, "subscribing for `{instance}.{name}` parameters")
payload := make(chan []byte)
stop, err := inv.Subscribe(func(ctx {context}.Context, buf []byte) {{
Expand All @@ -1092,11 +1173,12 @@ impl InterfaceGenerator<'_> {
if err := stop(); err != nil {{
{slog}.ErrorContext(ctx, "failed to stop parameter subscription", "err", err)
}}
}}()
// TODO: Handle async parameters
{slog}.DebugContext(ctx, "accepting handshake")
}}()"#,
);
}
uwriteln!(
self.src,
r#"{slog}.DebugContext(ctx, "accepting handshake")
if err := inv.Accept(ctx, nil); err != nil {{
return {fmt}.Errorf("failed to complete handshake: %w", err)
}}"#,
Expand Down Expand Up @@ -1930,20 +2012,7 @@ func (v *{name}) WriteTo(w {wrpc}.ByteWriter) error {{"#
r#"func (v *{name}) WriteTo(w {wrpc}.ByteWriter) error {{"#,
);
self.push_str("if err := ");
match variant.tag() {
Int::U8 => {
self.print_write_ty(&Type::U8, "uint8(v.discriminant)", "w");
}
Int::U16 => {
self.print_write_ty(&Type::U16, "uint16(v.discriminant)", "w");
}
Int::U32 => {
self.print_write_ty(&Type::U32, "uint32(v.discriminant)", "w");
}
Int::U64 => {
self.print_write_ty(&Type::U64, "uint64(v.discriminant)", "w");
}
}
self.print_write_discriminant(variant.tag(), "v.discriminant", "w");
self.push_str("; err != nil { return ");
self.push_str(fmt);
self.push_str(".Errorf(\"failed to write discriminant: %w\", err)\n}\n");
Expand Down Expand Up @@ -1983,20 +2052,7 @@ func (v *{name}) WriteTo(w {wrpc}.ByteWriter) error {{"#
r#"func Read{name}(r {wrpc}.ByteReader) (*{name}, error) {{
disc, err := "#,
);
match variant.tag() {
Int::U8 => {
self.print_read_ty(&Type::U8, "r");
}
Int::U16 => {
self.print_read_ty(&Type::U16, "r");
}
Int::U32 => {
self.print_read_ty(&Type::U32, "r");
}
Int::U64 => {
self.print_read_ty(&Type::U64, "r");
}
}
self.print_read_discriminant(variant.tag(), "r");
self.push_str("\n");
self.push_str("if err != nil {\n");
self.push_str("return nil, ");
Expand Down Expand Up @@ -2125,20 +2181,7 @@ func (v *{name}) WriteTo(w {wrpc}.ByteWriter) error {{"#
r#"func (v {name}) WriteTo(w {wrpc}.ByteWriter) error {{"#,
);
self.push_str("if err := ");
match enum_.tag() {
Int::U8 => {
self.print_write_ty(&Type::U8, "uint8(v)", "w");
}
Int::U16 => {
self.print_write_ty(&Type::U16, "uint16(v)", "w");
}
Int::U32 => {
self.print_write_ty(&Type::U32, "uint32(v)", "w");
}
Int::U64 => {
self.print_write_ty(&Type::U64, "uint64(v)", "w");
}
}
self.print_write_discriminant(enum_.tag(), "v", "w");
self.push_str("; err != nil { return ");
self.push_str(fmt);
self.push_str(".Errorf(\"failed to write discriminant: %w\", err)\n}\n");
Expand All @@ -2150,20 +2193,7 @@ func (v *{name}) WriteTo(w {wrpc}.ByteWriter) error {{"#
r#"func Read{name}(r {wrpc}.ByteReader) (v {name}, err error) {{
disc, err := "#,
);
match enum_.tag() {
Int::U8 => {
self.print_read_ty(&Type::U8, "r");
}
Int::U16 => {
self.print_read_ty(&Type::U16, "r");
}
Int::U32 => {
self.print_read_ty(&Type::U32, "r");
}
Int::U64 => {
self.print_read_ty(&Type::U64, "r");
}
}
self.print_read_discriminant(enum_.tag(), "r");
self.push_str("\n");
self.push_str("if err != nil {\n");
self.push_str("return v, ");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type Error struct {

func (v *Error) Discriminant() ErrorDiscriminant { return v.discriminant }

type ErrorDiscriminant uint32
type ErrorDiscriminant uint8

const (
// The host does not recognize the store identifier requested.
Expand All @@ -43,7 +43,7 @@ func (v *Error) String() string {
case ErrorDiscriminant_Other:
return "other"
default:
panic("unreachable")
panic("invalid variant")
}
}

Expand Down Expand Up @@ -93,13 +93,13 @@ func NewError_Other(payload string) *Error {
}
func (v *Error) Error() string { return v.String() }
func (v *Error) WriteTo(w wrpc.ByteWriter) error {
if err := func(v uint32, w wrpc.ByteWriter) error {
b := make([]byte, binary.MaxVarintLen32)
if err := func(v uint8, w wrpc.ByteWriter) error {
b := make([]byte, 2)
i := binary.PutUvarint(b, uint64(v))
slog.Debug("writing u32")
slog.Debug("writing u8")
_, err := w.Write(b[:i])
return err
}(uint32(v.discriminant), w); err != nil {
}(uint8(v.discriminant), w); err != nil {
return fmt.Errorf("failed to write discriminant: %w", err)
}
switch v.discriminant {
Expand Down Expand Up @@ -135,10 +135,84 @@ func (v *Error) WriteTo(w wrpc.ByteWriter) error {
return fmt.Errorf("failed to write payload: %w", err)
}
default:
panic("unreachable")
panic("invalid variant")
}
return nil
}
func ReadError(r wrpc.ByteReader) (*Error, error) {
disc, err := func(r wrpc.ByteReader) (uint8, error) {
var x uint8
var s uint
for i := 0; i < 2; i++ {
slog.Debug("reading `uint8` byte", "i", i)
b, err := r.ReadByte()
if err != nil {
if i > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
return x, fmt.Errorf("failed to read `uint8` byte: %w", err)
}
if b < 0x80 {
if i == 2 && b > 1 {
return x, errors.New("varint overflows a 8-bit integer")
}
return x | uint8(b)<<s, nil
}
x |= uint8(b&0x7f) << s
s += 7
}
return x, errors.New("varint overflows a 8-bit integer")
}(r)
if err != nil {
return nil, fmt.Errorf("failed to read discriminant: %w", err)
}
switch ErrorDiscriminant(disc) {
case ErrorDiscriminant_NoSuchStore:
return NewError_NoSuchStore(), nil
case ErrorDiscriminant_AccessDenied:
return NewError_AccessDenied(), nil
case ErrorDiscriminant_Other:
payload, err := func(r wrpc.ByteReader) (string, error) {
var x uint32
var s uint
for i := 0; i < 5; i++ {
slog.Debug("reading string length byte", "i", i)
b, err := r.ReadByte()
if err != nil {
if i > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
return "", fmt.Errorf("failed to read string length byte: %w", err)
}
if b < 0x80 {
if i == 4 && b > 1 {
return "", errors.New("string length overflows a 32-bit integer")
}
x = x | uint32(b)<<s
buf := make([]byte, x)
slog.Debug("reading string bytes", "len", x)
_, err = r.Read(buf)
if err != nil {
return "", fmt.Errorf("failed to read string bytes: %w", err)
}
if !utf8.Valid(buf) {
return string(buf), errors.New("string is not valid UTF-8")
}
return string(buf), nil
}
x |= uint32(b&0x7f) << s
s += 7
}
return "", errors.New("string length overflows a 32-bit integer")
}(r)
if err != nil {
return nil, fmt.Errorf("failed to read `other` payload: %w", err)
}
return NewError_Other(payload), nil
default:
return nil, fmt.Errorf("unknown discriminant value %d", disc)
}
}

// A response to a `list-keys` operation.
type KeyResponse struct {
Expand Down Expand Up @@ -440,9 +514,6 @@ func ServeInterface(c wrpc.Client, h Handler) (stop func() error, err error) {
slog.ErrorContext(ctx, "failed to stop parameter subscription", "err", err)
}
}()

// TODO: Handle async parameters

slog.DebugContext(ctx, "accepting handshake")
if err := inv.Accept(ctx, nil); err != nil {
return fmt.Errorf("failed to complete handshake: %w", err)
Expand Down Expand Up @@ -626,9 +697,6 @@ func ServeInterface(c wrpc.Client, h Handler) (stop func() error, err error) {
slog.ErrorContext(ctx, "failed to stop parameter subscription", "err", err)
}
}()

// TODO: Handle async parameters

slog.DebugContext(ctx, "accepting handshake")
if err := inv.Accept(ctx, nil); err != nil {
return fmt.Errorf("failed to complete handshake: %w", err)
Expand Down Expand Up @@ -809,9 +877,6 @@ func ServeInterface(c wrpc.Client, h Handler) (stop func() error, err error) {
slog.ErrorContext(ctx, "failed to stop parameter subscription", "err", err)
}
}()

// TODO: Handle async parameters

slog.DebugContext(ctx, "accepting handshake")
if err := inv.Accept(ctx, nil); err != nil {
return fmt.Errorf("failed to complete handshake: %w", err)
Expand Down Expand Up @@ -949,9 +1014,6 @@ func ServeInterface(c wrpc.Client, h Handler) (stop func() error, err error) {
slog.ErrorContext(ctx, "failed to stop parameter subscription", "err", err)
}
}()

// TODO: Handle async parameters

slog.DebugContext(ctx, "accepting handshake")
if err := inv.Accept(ctx, nil); err != nil {
return fmt.Errorf("failed to complete handshake: %w", err)
Expand Down Expand Up @@ -1100,9 +1162,6 @@ func ServeInterface(c wrpc.Client, h Handler) (stop func() error, err error) {
slog.ErrorContext(ctx, "failed to stop parameter subscription", "err", err)
}
}()

// TODO: Handle async parameters

slog.DebugContext(ctx, "accepting handshake")
if err := inv.Accept(ctx, nil); err != nil {
return fmt.Errorf("failed to complete handshake: %w", err)
Expand Down

0 comments on commit 9902393

Please sign in to comment.