diff --git a/crates/wit-bindgen-go/src/interface.rs b/crates/wit-bindgen-go/src/interface.rs index f617a7bd..f708eb0e 100644 --- a/crates/wit-bindgen-go/src/interface.rs +++ b/crates/wit-bindgen-go/src/interface.rs @@ -690,6 +690,82 @@ impl InterfaceGenerator<'_> { } } + fn print_read_discriminant(&mut self, repr: Int, reader: &str) { + match repr { + Int::U8 => { + uwrite!( + self.src, + r#"func(r {wrpc}.ByteReader) (uint8, error) {{ + var x uint8 + var s uint + for i := 0; i < 2; i++ {{ + {slog}.Debug("reading `uint8` byte", "i", i) + b, err := r.ReadByte() + if err != nil {{ + if i > 0 && err == {io}.EOF {{ + err = {io}.ErrUnexpectedEOF + }} + return x, {fmt}.Errorf("failed to read `uint8` byte: %w", err) + }} + if b < 0x80 {{ + if i == 2 && b > 1 {{ + return x, {errors}.New("varint overflows a 8-bit integer") + }} + return x | uint8(b)< { + self.print_read_ty(&Type::U16, reader); + } + Int::U32 => { + self.print_read_ty(&Type::U32, reader); + } + Int::U64 => { + self.print_read_ty(&Type::U64, reader); + } + } + } + + fn print_write_discriminant(&mut self, repr: Int, name: &str, writer: &str) { + match repr { + Int::U8 => { + uwrite!( + self.src, + r#"func(v uint8, w {wrpc}.ByteWriter) error {{ + b := make([]byte, 2) + i := {binary}.PutUvarint(b, uint64(v)) + {slog}.Debug("writing u8") + _, err := w.Write(b[:i]) + return err + }}(uint8({name}), {writer})"#, + binary = self.deps.binary(), + slog = self.deps.slog(), + wrpc = self.deps.wrpc(), + ); + } + Int::U16 => { + self.print_write_ty(&Type::U16, &format!("uint16({name})"), writer); + } + Int::U32 => { + self.print_write_ty(&Type::U32, &format!("uint32({name})"), writer); + } + Int::U64 => { + self.print_write_ty(&Type::U64, &format!("uint64({name})"), writer); + } + } + } + fn print_write_tyid(&mut self, id: TypeId, name: &str, writer: &str) { let ty = &self.resolve.types[id]; if ty.name.is_some() { @@ -1078,8 +1154,13 @@ impl InterfaceGenerator<'_> { let wrpc = self.deps.wrpc(); uwriteln!( self.src, - r#"stop{i}, err := c.Serve("{instance}", "{name}", func(ctx {context}.Context, buffer []byte, tx {wrpc}.Transmitter, inv {wrpc}.IncomingInvocation) error {{ - {slog}.DebugContext(ctx, "subscribing for `{instance}.{name}` parameters") + r#"stop{i}, err := c.Serve("{instance}", "{name}", func(ctx {context}.Context, buffer []byte, tx {wrpc}.Transmitter, inv {wrpc}.IncomingInvocation) error {{"#, + ); + if !params.is_empty() { + // TODO: Handle async parameters + uwriteln!( + self.src, + r#"{slog}.DebugContext(ctx, "subscribing for `{instance}.{name}` parameters") payload := make(chan []byte) stop, err := inv.Subscribe(func(ctx {context}.Context, buf []byte) {{ @@ -1092,11 +1173,12 @@ impl InterfaceGenerator<'_> { if err := stop(); err != nil {{ {slog}.ErrorContext(ctx, "failed to stop parameter subscription", "err", err) }} - }}() - - // TODO: Handle async parameters - - {slog}.DebugContext(ctx, "accepting handshake") + }}()"#, + ); + } + uwriteln!( + self.src, + r#"{slog}.DebugContext(ctx, "accepting handshake") if err := inv.Accept(ctx, nil); err != nil {{ return {fmt}.Errorf("failed to complete handshake: %w", err) }}"#, @@ -1930,20 +2012,7 @@ func (v *{name}) WriteTo(w {wrpc}.ByteWriter) error {{"# r#"func (v *{name}) WriteTo(w {wrpc}.ByteWriter) error {{"#, ); self.push_str("if err := "); - match variant.tag() { - Int::U8 => { - self.print_write_ty(&Type::U8, "uint8(v.discriminant)", "w"); - } - Int::U16 => { - self.print_write_ty(&Type::U16, "uint16(v.discriminant)", "w"); - } - Int::U32 => { - self.print_write_ty(&Type::U32, "uint32(v.discriminant)", "w"); - } - Int::U64 => { - self.print_write_ty(&Type::U64, "uint64(v.discriminant)", "w"); - } - } + self.print_write_discriminant(variant.tag(), "v.discriminant", "w"); self.push_str("; err != nil { return "); self.push_str(fmt); self.push_str(".Errorf(\"failed to write discriminant: %w\", err)\n}\n"); @@ -1983,20 +2052,7 @@ func (v *{name}) WriteTo(w {wrpc}.ByteWriter) error {{"# r#"func Read{name}(r {wrpc}.ByteReader) (*{name}, error) {{ disc, err := "#, ); - match variant.tag() { - Int::U8 => { - self.print_read_ty(&Type::U8, "r"); - } - Int::U16 => { - self.print_read_ty(&Type::U16, "r"); - } - Int::U32 => { - self.print_read_ty(&Type::U32, "r"); - } - Int::U64 => { - self.print_read_ty(&Type::U64, "r"); - } - } + self.print_read_discriminant(variant.tag(), "r"); self.push_str("\n"); self.push_str("if err != nil {\n"); self.push_str("return nil, "); @@ -2125,20 +2181,7 @@ func (v *{name}) WriteTo(w {wrpc}.ByteWriter) error {{"# r#"func (v {name}) WriteTo(w {wrpc}.ByteWriter) error {{"#, ); self.push_str("if err := "); - match enum_.tag() { - Int::U8 => { - self.print_write_ty(&Type::U8, "uint8(v)", "w"); - } - Int::U16 => { - self.print_write_ty(&Type::U16, "uint16(v)", "w"); - } - Int::U32 => { - self.print_write_ty(&Type::U32, "uint32(v)", "w"); - } - Int::U64 => { - self.print_write_ty(&Type::U64, "uint64(v)", "w"); - } - } + self.print_write_discriminant(enum_.tag(), "v", "w"); self.push_str("; err != nil { return "); self.push_str(fmt); self.push_str(".Errorf(\"failed to write discriminant: %w\", err)\n}\n"); @@ -2150,20 +2193,7 @@ func (v *{name}) WriteTo(w {wrpc}.ByteWriter) error {{"# r#"func Read{name}(r {wrpc}.ByteReader) (v {name}, err error) {{ disc, err := "#, ); - match enum_.tag() { - Int::U8 => { - self.print_read_ty(&Type::U8, "r"); - } - Int::U16 => { - self.print_read_ty(&Type::U16, "r"); - } - Int::U32 => { - self.print_read_ty(&Type::U32, "r"); - } - Int::U64 => { - self.print_read_ty(&Type::U64, "r"); - } - } + self.print_read_discriminant(enum_.tag(), "r"); self.push_str("\n"); self.push_str("if err != nil {\n"); self.push_str("return v, "); diff --git a/examples/go/keyvalue-server/bindings/exports/wrpc/keyvalue/store/bindings.wrpc.go b/examples/go/keyvalue-server/bindings/exports/wrpc/keyvalue/store/bindings.wrpc.go index bb0503f3..12ac8b29 100644 --- a/examples/go/keyvalue-server/bindings/exports/wrpc/keyvalue/store/bindings.wrpc.go +++ b/examples/go/keyvalue-server/bindings/exports/wrpc/keyvalue/store/bindings.wrpc.go @@ -22,7 +22,7 @@ type Error struct { func (v *Error) Discriminant() ErrorDiscriminant { return v.discriminant } -type ErrorDiscriminant uint32 +type ErrorDiscriminant uint8 const ( // The host does not recognize the store identifier requested. @@ -43,7 +43,7 @@ func (v *Error) String() string { case ErrorDiscriminant_Other: return "other" default: - panic("unreachable") + panic("invalid variant") } } @@ -93,13 +93,13 @@ func NewError_Other(payload string) *Error { } func (v *Error) Error() string { return v.String() } func (v *Error) WriteTo(w wrpc.ByteWriter) error { - if err := func(v uint32, w wrpc.ByteWriter) error { - b := make([]byte, binary.MaxVarintLen32) + if err := func(v uint8, w wrpc.ByteWriter) error { + b := make([]byte, 2) i := binary.PutUvarint(b, uint64(v)) - slog.Debug("writing u32") + slog.Debug("writing u8") _, err := w.Write(b[:i]) return err - }(uint32(v.discriminant), w); err != nil { + }(uint8(v.discriminant), w); err != nil { return fmt.Errorf("failed to write discriminant: %w", err) } switch v.discriminant { @@ -135,10 +135,84 @@ func (v *Error) WriteTo(w wrpc.ByteWriter) error { return fmt.Errorf("failed to write payload: %w", err) } default: - panic("unreachable") + panic("invalid variant") } return nil } +func ReadError(r wrpc.ByteReader) (*Error, error) { + disc, err := func(r wrpc.ByteReader) (uint8, error) { + var x uint8 + var s uint + for i := 0; i < 2; i++ { + slog.Debug("reading `uint8` byte", "i", i) + b, err := r.ReadByte() + if err != nil { + if i > 0 && err == io.EOF { + err = io.ErrUnexpectedEOF + } + return x, fmt.Errorf("failed to read `uint8` byte: %w", err) + } + if b < 0x80 { + if i == 2 && b > 1 { + return x, errors.New("varint overflows a 8-bit integer") + } + return x | uint8(b)< 0 && err == io.EOF { + err = io.ErrUnexpectedEOF + } + return "", fmt.Errorf("failed to read string length byte: %w", err) + } + if b < 0x80 { + if i == 4 && b > 1 { + return "", errors.New("string length overflows a 32-bit integer") + } + x = x | uint32(b)<