diff --git a/projects/Apigen/apigen/Apigen.cs b/projects/Apigen/apigen/Apigen.cs index be534bfd34..63b1bdbe46 100644 --- a/projects/Apigen/apigen/Apigen.cs +++ b/projects/Apigen/apigen/Apigen.cs @@ -952,14 +952,14 @@ public void EmitClassMethodImplementations(AmqpClass c) public void EmitMethodArgumentReader() { - EmitLine(" internal override Client.Impl.MethodBase DecodeMethodFrom(ReadOnlyMemory memory)"); + EmitLine(" internal override Client.Impl.MethodBase DecodeMethodFrom(ReadOnlySpan span)"); EmitLine(" {"); - EmitLine(" ushort classId = Util.NetworkOrderDeserializer.ReadUInt16(memory.Span);"); - EmitLine(" ushort methodId = Util.NetworkOrderDeserializer.ReadUInt16(memory.Slice(2).Span);"); + EmitLine(" ushort classId = Util.NetworkOrderDeserializer.ReadUInt16(span);"); + EmitLine(" ushort methodId = Util.NetworkOrderDeserializer.ReadUInt16(span.Slice(2));"); EmitLine(" Client.Impl.MethodBase result = DecodeMethodFrom(classId, methodId);"); EmitLine(" if(result != null)"); EmitLine(" {"); - EmitLine(" Client.Impl.MethodArgumentReader reader = new Client.Impl.MethodArgumentReader(memory.Slice(4));"); + EmitLine(" Client.Impl.MethodArgumentReader reader = new Client.Impl.MethodArgumentReader(span.Slice(4));"); EmitLine(" result.ReadArgumentsFrom(ref reader);"); EmitLine(" return result;"); EmitLine(" }"); diff --git a/projects/RabbitMQ.Client/client/impl/CommandAssembler.cs b/projects/RabbitMQ.Client/client/impl/CommandAssembler.cs index bfca721d89..cc7c1a11a8 100644 --- a/projects/RabbitMQ.Client/client/impl/CommandAssembler.cs +++ b/projects/RabbitMQ.Client/client/impl/CommandAssembler.cs @@ -81,7 +81,7 @@ public Command HandleFrame(in InboundFrame f) { throw new UnexpectedFrameException(f.Type); } - m_method = m_protocol.DecodeMethodFrom(f.Payload); + m_method = m_protocol.DecodeMethodFrom(f.Payload.Span); m_state = m_method.HasContent ? AssemblyState.ExpectingContentHeader : AssemblyState.Complete; return CompletedCommand(); case AssemblyState.ExpectingContentHeader: @@ -89,8 +89,10 @@ public Command HandleFrame(in InboundFrame f) { throw new UnexpectedFrameException(f.Type); } - m_header = m_protocol.DecodeContentHeaderFrom(NetworkOrderDeserializer.ReadUInt16(f.Payload.Span)); - ulong totalBodyBytes = m_header.ReadFrom(f.Payload.Slice(2)); + + ReadOnlySpan span = f.Payload.Span; + m_header = m_protocol.DecodeContentHeaderFrom(NetworkOrderDeserializer.ReadUInt16(span)); + ulong totalBodyBytes = m_header.ReadFrom(span.Slice(2)); if (totalBodyBytes > MaxArrayOfBytesSize) { throw new UnexpectedFrameException(f.Type); diff --git a/projects/RabbitMQ.Client/client/impl/ContentHeaderBase.cs b/projects/RabbitMQ.Client/client/impl/ContentHeaderBase.cs index 0f34827041..937f2f51e0 100644 --- a/projects/RabbitMQ.Client/client/impl/ContentHeaderBase.cs +++ b/projects/RabbitMQ.Client/client/impl/ContentHeaderBase.cs @@ -67,11 +67,11 @@ public virtual object Clone() /// /// Fill this instance from the given byte buffer stream. /// - internal ulong ReadFrom(ReadOnlyMemory memory) + internal ulong ReadFrom(ReadOnlySpan span) { // Skipping the first two bytes since they arent used (weight - not currently used) - ulong bodySize = NetworkOrderDeserializer.ReadUInt64(memory.Slice(2).Span); - ContentHeaderPropertyReader reader = new ContentHeaderPropertyReader(memory.Slice(10)); + ulong bodySize = NetworkOrderDeserializer.ReadUInt64(span.Slice(2)); + ContentHeaderPropertyReader reader = new ContentHeaderPropertyReader(span.Slice(10)); ReadPropertiesFrom(ref reader); return bodySize; } @@ -81,13 +81,12 @@ internal ulong ReadFrom(ReadOnlyMemory memory) private const ushort ZERO = 0; - internal int WriteTo(Memory memory, ulong bodySize) + internal int WriteTo(Span span, ulong bodySize) { - var span = memory.Span; NetworkOrderSerializer.WriteUInt16(span, ZERO); // Weight - not used NetworkOrderSerializer.WriteUInt64(span.Slice(2), bodySize); - ContentHeaderPropertyWriter writer = new ContentHeaderPropertyWriter(memory.Slice(10)); + ContentHeaderPropertyWriter writer = new ContentHeaderPropertyWriter(span.Slice(10)); WritePropertiesTo(ref writer); return 10 + writer.Offset; } diff --git a/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyReader.cs b/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyReader.cs index e8b4dafea4..d0547f7336 100644 --- a/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyReader.cs +++ b/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyReader.cs @@ -50,19 +50,16 @@ internal ref struct ContentHeaderPropertyReader private const int StartBitMask = 0b1000_0000_0000_0000; private const int EndBitMask = 0b0000_0000_0000_0001; - private readonly ReadOnlyMemory _memory; private readonly ReadOnlySpan _span; private int _offset; private int _bitMask; private int _bits; private ReadOnlySpan Span => _span.Slice(_offset); - private ReadOnlyMemory Memory => _memory.Slice(_offset); - public ContentHeaderPropertyReader(ReadOnlyMemory memory) + public ContentHeaderPropertyReader(ReadOnlySpan span) { - _memory = memory; - _span = memory.Span; + _span = span; _offset = 0; _bitMask = EndBitMask; // force a flag read _bits = 1; // just the continuation bit @@ -141,7 +138,7 @@ public ushort ReadShort() public string ReadShortstr() { - string result = WireFormatting.ReadShortstr(Memory, out int bytesRead); + string result = WireFormatting.ReadShortstr(Span, out int bytesRead); _offset += bytesRead; return result; } @@ -149,7 +146,7 @@ public string ReadShortstr() /// A type of . public Dictionary ReadTable() { - Dictionary result = WireFormatting.ReadTable(Memory, out int bytesRead); + Dictionary result = WireFormatting.ReadTable(Span, out int bytesRead); _offset += bytesRead; return result; } diff --git a/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyWriter.cs b/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyWriter.cs index 9dbdc2b22a..47e8a9ef5f 100644 --- a/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyWriter.cs +++ b/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyWriter.cs @@ -50,7 +50,6 @@ internal ref struct ContentHeaderPropertyWriter private const ushort StartBitMask = 0b1000_0000_0000_0000; private const ushort EndBitMask = 0b0000_0000_0000_0001; - private readonly Memory _memory; private readonly Span _span; private int _offset; private ushort _bitAccumulator; @@ -59,12 +58,10 @@ internal ref struct ContentHeaderPropertyWriter public int Offset => _offset; private Span Span => _span.Slice(_offset); - private Memory Memory => _memory.Slice(_offset); - public ContentHeaderPropertyWriter(Memory memory) + public ContentHeaderPropertyWriter(Span span) { - _memory = memory; - _span = _memory.Span; + _span = span; _offset = 0; _bitAccumulator = 0; _bitMask = StartBitMask; @@ -124,12 +121,12 @@ public void WriteShort(ushort val) public void WriteShortstr(string val) { - _offset += WireFormatting.WriteShortstr(Memory, val); + _offset += WireFormatting.WriteShortstr(Span, val); } public void WriteTable(IDictionary val) { - _offset += WireFormatting.WriteTable(Memory, val); + _offset += WireFormatting.WriteTable(Span, val); } public void WriteTimestamp(AmqpTimestamp val) diff --git a/projects/RabbitMQ.Client/client/impl/Frame.cs b/projects/RabbitMQ.Client/client/impl/Frame.cs index 75c99d24c3..1bddfbaa2b 100644 --- a/projects/RabbitMQ.Client/client/impl/Frame.cs +++ b/projects/RabbitMQ.Client/client/impl/Frame.cs @@ -66,12 +66,12 @@ internal override int GetMinimumPayloadBufferSize() return 2 + _header.GetRequiredBufferSize(); } - internal override int WritePayload(Memory memory) + internal override int WritePayload(Span span) { // write protocol class id (2 bytes) - NetworkOrderSerializer.WriteUInt16(memory.Span, _header.ProtocolClassId); + NetworkOrderSerializer.WriteUInt16(span, _header.ProtocolClassId); // write header (X bytes) - int bytesWritten = _header.WriteTo(memory.Slice(2), (ulong)_bodyLength); + int bytesWritten = _header.WriteTo(span.Slice(2), (ulong)_bodyLength); return bytesWritten + 2; } } @@ -90,9 +90,9 @@ internal override int GetMinimumPayloadBufferSize() return _body.Length; } - internal override int WritePayload(Memory memory) + internal override int WritePayload(Span span) { - _body.CopyTo(memory); + _body.Span.CopyTo(span); return _body.Length; } } @@ -112,12 +112,11 @@ internal override int GetMinimumPayloadBufferSize() return 4 + _method.GetRequiredBufferSize(); } - internal override int WritePayload(Memory memory) + internal override int WritePayload(Span span) { - var span = memory.Span; NetworkOrderSerializer.WriteUInt16(span, _method.ProtocolClassId); NetworkOrderSerializer.WriteUInt16(span.Slice(2), _method.ProtocolMethodId); - var argWriter = new MethodArgumentWriter(memory.Slice(4)); + var argWriter = new MethodArgumentWriter(span.Slice(4)); _method.WriteArgumentsTo(ref argWriter); return 4 + argWriter.Offset; } @@ -134,7 +133,7 @@ internal override int GetMinimumPayloadBufferSize() return 0; } - internal override int WritePayload(Memory memory) + internal override int WritePayload(Span span) { return 0; } @@ -151,17 +150,16 @@ protected OutboundFrame(FrameType type, int channel) Channel = channel; } - internal void WriteTo(Memory memory) + internal void WriteTo(Span span) { - var span = memory.Span; span[0] = (byte)Type; NetworkOrderSerializer.WriteUInt16(span.Slice(1), (ushort)Channel); - int bytesWritten = WritePayload(memory.Slice(7)); + int bytesWritten = WritePayload(span.Slice(7)); NetworkOrderSerializer.WriteUInt32(span.Slice(3), (uint)bytesWritten); span[bytesWritten + 7] = Constants.FrameEnd; } - internal abstract int WritePayload(Memory memory); + internal abstract int WritePayload(Span span); internal abstract int GetMinimumPayloadBufferSize(); internal int GetMinimumBufferSize() { diff --git a/projects/RabbitMQ.Client/client/impl/MainSession.cs b/projects/RabbitMQ.Client/client/impl/MainSession.cs index e440a04d3f..6b863643c5 100644 --- a/projects/RabbitMQ.Client/client/impl/MainSession.cs +++ b/projects/RabbitMQ.Client/client/impl/MainSession.cs @@ -84,7 +84,7 @@ public override void HandleFrame(in InboundFrame frame) if (!_closeServerInitiated && frame.IsMethod()) { - MethodBase method = Connection.Protocol.DecodeMethodFrom(frame.Payload); + MethodBase method = Connection.Protocol.DecodeMethodFrom(frame.Payload.Span); if ((method.ProtocolClassId == _closeClassId) && (method.ProtocolMethodId == _closeMethodId)) { diff --git a/projects/RabbitMQ.Client/client/impl/MethodArgumentReader.cs b/projects/RabbitMQ.Client/client/impl/MethodArgumentReader.cs index df7ce038e7..ed49fbcbc2 100644 --- a/projects/RabbitMQ.Client/client/impl/MethodArgumentReader.cs +++ b/projects/RabbitMQ.Client/client/impl/MethodArgumentReader.cs @@ -47,19 +47,16 @@ namespace RabbitMQ.Client.Impl { internal ref struct MethodArgumentReader { - private readonly ReadOnlyMemory _memory; private readonly ReadOnlySpan _span; private int _offset; private int _bitMask; private int _bits; private ReadOnlySpan Span => _span.Slice(_offset); - private ReadOnlyMemory Memory => _memory.Slice(_offset); - public MethodArgumentReader(ReadOnlyMemory memory) + public MethodArgumentReader(ReadOnlySpan span) { - _memory = memory; - _span = memory.Span; + _span = span; _offset = 0; _bitMask = 0; _bits = 0; @@ -119,14 +116,14 @@ public ushort ReadShort() public string ReadShortstr() { - string result = WireFormatting.ReadShortstr(Memory, out int bytesRead); + string result = WireFormatting.ReadShortstr(Span, out int bytesRead); _offset += bytesRead; return result; } public Dictionary ReadTable() { - Dictionary result = WireFormatting.ReadTable(Memory, out int bytesRead); + Dictionary result = WireFormatting.ReadTable(Span, out int bytesRead); _offset += bytesRead; return result; } diff --git a/projects/RabbitMQ.Client/client/impl/MethodArgumentWriter.cs b/projects/RabbitMQ.Client/client/impl/MethodArgumentWriter.cs index 151522765e..3e15f03558 100644 --- a/projects/RabbitMQ.Client/client/impl/MethodArgumentWriter.cs +++ b/projects/RabbitMQ.Client/client/impl/MethodArgumentWriter.cs @@ -46,7 +46,6 @@ namespace RabbitMQ.Client.Impl { internal ref struct MethodArgumentWriter { - private readonly Memory _memory; private readonly Span _span; private int _offset; private int _bitAccumulator; @@ -55,12 +54,10 @@ internal ref struct MethodArgumentWriter public int Offset => _offset; private Span Span => _span.Slice(_offset); - private Memory Memory => _memory.Slice(_offset); - public MethodArgumentWriter(Memory memory) + public MethodArgumentWriter(Span span) { - _memory = memory; - _span = memory.Span; + _span = span; _offset = 0; _bitAccumulator = 0; _bitMask = 1; @@ -117,17 +114,17 @@ public void WriteShort(ushort val) public void WriteShortstr(string val) { - _offset += WireFormatting.WriteShortstr(Memory, val); + _offset += WireFormatting.WriteShortstr(Span, val); } public void WriteTable(IDictionary val) { - _offset += WireFormatting.WriteTable(Memory, val); + _offset += WireFormatting.WriteTable(Span, val); } public void WriteTable(IDictionary val) { - _offset += WireFormatting.WriteTable(Memory, val); + _offset += WireFormatting.WriteTable(Span, val); } public void WriteTimestamp(AmqpTimestamp val) diff --git a/projects/RabbitMQ.Client/client/impl/ProtocolBase.cs b/projects/RabbitMQ.Client/client/impl/ProtocolBase.cs index f4be4336aa..6ae4ed75c5 100644 --- a/projects/RabbitMQ.Client/client/impl/ProtocolBase.cs +++ b/projects/RabbitMQ.Client/client/impl/ProtocolBase.cs @@ -106,7 +106,7 @@ public void CreateConnectionClose(ushort reasonCode, } internal abstract ContentHeaderBase DecodeContentHeaderFrom(ushort classId); - internal abstract MethodBase DecodeMethodFrom(ReadOnlyMemory reader); + internal abstract MethodBase DecodeMethodFrom(ReadOnlySpan reader); public override bool Equals(object obj) { diff --git a/projects/RabbitMQ.Client/client/impl/QuiescingSession.cs b/projects/RabbitMQ.Client/client/impl/QuiescingSession.cs index 633426b133..ca5a1209f3 100644 --- a/projects/RabbitMQ.Client/client/impl/QuiescingSession.cs +++ b/projects/RabbitMQ.Client/client/impl/QuiescingSession.cs @@ -59,7 +59,7 @@ public override void HandleFrame(in InboundFrame frame) { if (frame.IsMethod()) { - MethodBase method = Connection.Protocol.DecodeMethodFrom(frame.Payload); + MethodBase method = Connection.Protocol.DecodeMethodFrom(frame.Payload.Span); if ((method.ProtocolClassId == ClassConstants.Channel) && (method.ProtocolMethodId == ChannelMethodConstants.CloseOk)) { diff --git a/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs b/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs index ed21562aa7..33c93413ee 100644 --- a/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs +++ b/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs @@ -252,8 +252,7 @@ public async Task WriteFrameImpl() { int bufferSize = frame.GetMinimumBufferSize(); byte[] memoryArray = ArrayPool.Shared.Rent(bufferSize); - Memory slice = new Memory(memoryArray, 0, bufferSize); - frame.WriteTo(slice); + frame.WriteTo(new Span(memoryArray, 0, bufferSize)); _writer.Write(memoryArray, 0, bufferSize); ArrayPool.Shared.Return(memoryArray); } diff --git a/projects/RabbitMQ.Client/client/impl/WireFormatting.cs b/projects/RabbitMQ.Client/client/impl/WireFormatting.cs index bc2fb803b2..4df31883f0 100644 --- a/projects/RabbitMQ.Client/client/impl/WireFormatting.cs +++ b/projects/RabbitMQ.Client/client/impl/WireFormatting.cs @@ -41,7 +41,6 @@ using System; using System.Collections; using System.Collections.Generic; -using System.Runtime.InteropServices; using System.Text; using RabbitMQ.Client.Exceptions; @@ -93,14 +92,14 @@ public static void DecimalToAmqp(decimal value, out byte scale, out int mantissa (((uint)bitRepresentation[0]) & 0x7FFFFFFF)); } - public static IList ReadArray(ReadOnlyMemory memory, out int bytesRead) + public static IList ReadArray(ReadOnlySpan span, out int bytesRead) { List array = new List(); - long arrayLength = NetworkOrderDeserializer.ReadUInt32(memory.Span); + long arrayLength = NetworkOrderDeserializer.ReadUInt32(span); bytesRead = 4; while (bytesRead - 4 < arrayLength) { - object value = ReadFieldValue(memory.Slice(bytesRead), out int fieldValueBytesRead); + object value = ReadFieldValue(span.Slice(bytesRead), out int fieldValueBytesRead); bytesRead += fieldValueBytesRead; array.Add(value); } @@ -115,65 +114,64 @@ public static decimal ReadDecimal(ReadOnlySpan span) return AmqpToDecimal(scale, unsignedMantissa); } - public static object ReadFieldValue(ReadOnlyMemory memory, out int bytesRead) + public static object ReadFieldValue(ReadOnlySpan span, out int bytesRead) { bytesRead = 1; - ReadOnlyMemory slice = memory.Slice(1); - switch ((char)memory.Span[0]) + switch ((char)span[0]) { case 'S': - byte[] result = ReadLongstr(slice.Span); + byte[] result = ReadLongstr(span.Slice(1)); bytesRead += result.Length + 4; return result; case 'I': bytesRead += 4; - return NetworkOrderDeserializer.ReadInt32(slice.Span); + return NetworkOrderDeserializer.ReadInt32(span.Slice(1)); case 'i': bytesRead += 4; - return NetworkOrderDeserializer.ReadUInt32(slice.Span); + return NetworkOrderDeserializer.ReadUInt32(span.Slice(1)); case 'D': bytesRead += 5; - return ReadDecimal(slice.Span); + return ReadDecimal(span.Slice(1)); case 'T': bytesRead += 8; - return ReadTimestamp(slice.Span); + return ReadTimestamp(span.Slice(1)); case 'F': - Dictionary tableResult = ReadTable(slice, out int tableBytesRead); + Dictionary tableResult = ReadTable(span.Slice(1), out int tableBytesRead); bytesRead += tableBytesRead; return tableResult; case 'A': - IList arrayResult = ReadArray(slice, out int arrayBytesRead); + IList arrayResult = ReadArray(span.Slice(1), out int arrayBytesRead); bytesRead += arrayBytesRead; return arrayResult; case 'B': bytesRead += 1; - return slice.Span[0]; + return span[1]; case 'b': bytesRead += 1; - return (sbyte)slice.Span[0]; + return (sbyte)span[1]; case 'd': bytesRead += 8; - return NetworkOrderDeserializer.ReadDouble(slice.Span); + return NetworkOrderDeserializer.ReadDouble(span.Slice(1)); case 'f': bytesRead += 4; - return NetworkOrderDeserializer.ReadSingle(slice.Span); + return NetworkOrderDeserializer.ReadSingle(span.Slice(1)); case 'l': bytesRead += 8; - return NetworkOrderDeserializer.ReadInt64(slice.Span); + return NetworkOrderDeserializer.ReadInt64(span.Slice(1)); case 's': bytesRead += 2; - return NetworkOrderDeserializer.ReadInt16(slice.Span); + return NetworkOrderDeserializer.ReadInt16(span.Slice(1)); case 't': bytesRead += 1; - return slice.Span[0] != 0; + return span[1] != 0; case 'x': - byte[] binaryTableResult = ReadLongstr(slice.Span); + byte[] binaryTableResult = ReadLongstr(span.Slice(1)); bytesRead += binaryTableResult.Length + 4; return new BinaryTableValue(binaryTableResult); case 'V': return null; default: - throw new SyntaxErrorException($"Unrecognised type in table: {(char)memory.Span[0]}"); + throw new SyntaxErrorException($"Unrecognised type in table: {(char)span[0]}"); } } @@ -188,17 +186,19 @@ public static byte[] ReadLongstr(ReadOnlySpan span) return span.Slice(4, (int)byteCount).ToArray(); } - public static string ReadShortstr(ReadOnlyMemory memory, out int bytesRead) + public static unsafe string ReadShortstr(ReadOnlySpan span, out int bytesRead) { - int byteCount = memory.Span[0]; - ReadOnlyMemory stringSlice = memory.Slice(1, byteCount); - if (MemoryMarshal.TryGetArray(stringSlice, out ArraySegment segment)) + int byteCount = span[0]; + if (span.Length >= byteCount + 1) { bytesRead = 1 + byteCount; - return Encoding.UTF8.GetString(segment.Array, segment.Offset, segment.Count); + fixed (byte* bytes = &span.GetPinnableReference()) + { + return Encoding.UTF8.GetString(bytes, byteCount); + } } - throw new InvalidOperationException("Unable to get ArraySegment from memory"); + throw new ArgumentOutOfRangeException(nameof(span), $"Span has not enough space ({span.Length} instead of {byteCount + 1})"); } ///Reads an AMQP "table" definition from the reader. @@ -208,10 +208,10 @@ public static string ReadShortstr(ReadOnlyMemory memory, out int bytesRead /// x and V types and the AMQP 0-9-1 A type. /// /// A . - public static Dictionary ReadTable(ReadOnlyMemory memory, out int bytesRead) + public static Dictionary ReadTable(ReadOnlySpan span, out int bytesRead) { bytesRead = 4; - long tableLength = NetworkOrderDeserializer.ReadUInt32(memory.Span); + long tableLength = NetworkOrderDeserializer.ReadUInt32(span); if (tableLength == 0) { return null; @@ -220,9 +220,9 @@ public static Dictionary ReadTable(ReadOnlyMemory memory, Dictionary table = new Dictionary(); while ((bytesRead - 4) < tableLength) { - string key = ReadShortstr(memory.Slice(bytesRead), out int keyBytesRead); + string key = ReadShortstr(span.Slice(bytesRead), out int keyBytesRead); bytesRead += keyBytesRead; - object value = ReadFieldValue(memory.Slice(bytesRead), out int valueBytesRead); + object value = ReadFieldValue(span.Slice(bytesRead), out int valueBytesRead); bytesRead += valueBytesRead; if (!table.ContainsKey(key)) @@ -242,11 +242,11 @@ public static AmqpTimestamp ReadTimestamp(ReadOnlySpan span) return new AmqpTimestamp((long)stamp); } - public static int WriteArray(Memory memory, IList val) + public static int WriteArray(Span span, IList val) { if (val == null) { - NetworkOrderSerializer.WriteUInt32(memory.Span, 0); + NetworkOrderSerializer.WriteUInt32(span, 0); return 4; } else @@ -254,10 +254,10 @@ public static int WriteArray(Memory memory, IList val) int bytesWritten = 0; for (int index = 0; index < val.Count; index++) { - bytesWritten += WriteFieldValue(memory.Slice(4 + bytesWritten), val[index]); + bytesWritten += WriteFieldValue(span.Slice(4 + bytesWritten), val[index]); } - NetworkOrderSerializer.WriteUInt32(memory.Span, (uint)bytesWritten); + NetworkOrderSerializer.WriteUInt32(span, (uint)bytesWritten); return 4 + bytesWritten; } } @@ -285,81 +285,74 @@ public static int WriteDecimal(Span span, decimal value) return 1 + WriteLong(span.Slice(1), (uint)mantissa); } - public static int WriteFieldValue(Memory memory, object value) + public static int WriteFieldValue(Span span, object value) { if (value == null) { - memory.Span[0] = (byte)'V'; + span[0] = (byte)'V'; return 1; } - Memory slice = memory.Slice(1); + Span slice = span.Slice(1); switch (value) { case string val: - memory.Span[0] = (byte)'S'; - if (MemoryMarshal.TryGetArray(memory, out ArraySegment segment)) - { - int bytesWritten = Encoding.UTF8.GetBytes(val, 0, val.Length, segment.Array, segment.Offset + 5); - NetworkOrderSerializer.WriteUInt32(slice.Span, (uint)bytesWritten); - return 5 + bytesWritten; - } - - throw new WireFormattingException("Unable to get array segment from memory."); + span[0] = (byte)'S'; + return 1 + WriteLongstr(slice, val); case byte[] val: - memory.Span[0] = (byte)'S'; - return 1 + WriteLongstr(slice.Span, val); + span[0] = (byte)'S'; + return 1 + WriteLongstr(slice, val); case int val: - memory.Span[0] = (byte)'I'; - NetworkOrderSerializer.WriteInt32(slice.Span, val); + span[0] = (byte)'I'; + NetworkOrderSerializer.WriteInt32(slice, val); return 5; case uint val: - memory.Span[0] = (byte)'i'; - NetworkOrderSerializer.WriteUInt32(slice.Span, val); + span[0] = (byte)'i'; + NetworkOrderSerializer.WriteUInt32(slice, val); return 5; case decimal val: - memory.Span[0] = (byte)'D'; - return 1 + WriteDecimal(slice.Span, val); + span[0] = (byte)'D'; + return 1 + WriteDecimal(slice, val); case AmqpTimestamp val: - memory.Span[0] = (byte)'T'; - return 1 + WriteTimestamp(slice.Span, val); + span[0] = (byte)'T'; + return 1 + WriteTimestamp(slice, val); case IDictionary val: - memory.Span[0] = (byte)'F'; + span[0] = (byte)'F'; return 1 + WriteTable(slice, val); case IList val: - memory.Span[0] = (byte)'A'; + span[0] = (byte)'A'; return 1 + WriteArray(slice, val); case byte val: - memory.Span[0] = (byte)'B'; - memory.Span[1] = val; + span[0] = (byte)'B'; + span[1] = val; return 2; case sbyte val: - memory.Span[0] = (byte)'b'; - memory.Span[1] = (byte)val; + span[0] = (byte)'b'; + span[1] = (byte)val; return 2; case double val: - memory.Span[0] = (byte)'d'; - NetworkOrderSerializer.WriteDouble(slice.Span, val); + span[0] = (byte)'d'; + NetworkOrderSerializer.WriteDouble(slice, val); return 9; case float val: - memory.Span[0] = (byte)'f'; - NetworkOrderSerializer.WriteSingle(slice.Span, val); + span[0] = (byte)'f'; + NetworkOrderSerializer.WriteSingle(slice, val); return 5; case long val: - memory.Span[0] = (byte)'l'; - NetworkOrderSerializer.WriteInt64(slice.Span, val); + span[0] = (byte)'l'; + NetworkOrderSerializer.WriteInt64(slice, val); return 9; case short val: - memory.Span[0] = (byte)'s'; - NetworkOrderSerializer.WriteInt16(slice.Span, val); + span[0] = (byte)'s'; + NetworkOrderSerializer.WriteInt16(slice, val); return 3; case bool val: - memory.Span[0] = (byte)'t'; - memory.Span[1] = (byte)(val ? 1 : 0); + span[0] = (byte)'t'; + span[1] = (byte)(val ? 1 : 0); return 2; case BinaryTableValue val: - memory.Span[0] = (byte)'x'; - return 1 + WriteLongstr(slice.Span, val.Bytes); + span[0] = (byte)'x'; + return 1 + WriteLongstr(slice, val.Bytes); default: throw new WireFormattingException($"Value of type '{value.GetType().Name}' cannot appear as table value", value); } @@ -427,34 +420,44 @@ public static int WriteShort(Span span, ushort val) return 2; } - public static int WriteShortstr(Memory memory, string val) + public static unsafe int WriteShortstr(Span span, string val) { - if (MemoryMarshal.TryGetArray(memory, out ArraySegment segment)) + int maxLength = span.Length - 1; + if (maxLength > byte.MaxValue) { - int bytesWritten = Encoding.UTF8.GetBytes(val, 0, val.Length, segment.Array, segment.Offset + 1); - if (bytesWritten <= byte.MaxValue) - { - segment.Array[segment.Offset] = (byte)bytesWritten; - return bytesWritten + 1; - } - - throw new ArgumentOutOfRangeException(nameof(val), val, "Value exceeds the maximum allowed length of 255 bytes."); + maxLength = byte.MaxValue; } + fixed (char* chars = val) + fixed (byte* bytes = &span.Slice(1).GetPinnableReference()) + { + int bytesWritten = Encoding.UTF8.GetBytes(chars, val.Length, bytes, maxLength); + span[0] = (byte)bytesWritten; + return bytesWritten + 1; + } + } - throw new WireFormattingException("Unable to get array segment from memory."); + public static unsafe int WriteLongstr(Span span, string val) + { + fixed (char* chars = val) + fixed (byte* bytes = &span.Slice(4).GetPinnableReference()) + { + int bytesWritten = Encoding.UTF8.GetBytes(chars, val.Length, bytes, span.Length); + NetworkOrderSerializer.WriteUInt32(span, (uint)bytesWritten); + return bytesWritten + 4; + } } - public static int WriteTable(Memory memory, IDictionary val) + public static int WriteTable(Span span, IDictionary val) { if (val == null || val.Count == 0) { - NetworkOrderSerializer.WriteUInt32(memory.Span, 0); + NetworkOrderSerializer.WriteUInt32(span, 0); return 4; } else { // Let's only write after the length header. - Memory slice = memory.Slice(4); + Span slice = span.Slice(4); int bytesWritten = 0; foreach (DictionaryEntry entry in val) { @@ -462,22 +465,22 @@ public static int WriteTable(Memory memory, IDictionary val) bytesWritten += WriteFieldValue(slice.Slice(bytesWritten), entry.Value); } - NetworkOrderSerializer.WriteUInt32(memory.Span, (uint)bytesWritten); + NetworkOrderSerializer.WriteUInt32(span, (uint)bytesWritten); return 4 + bytesWritten; } } - public static int WriteTable(Memory memory, IDictionary val) + public static int WriteTable(Span span, IDictionary val) { if (val == null || val.Count == 0) { - NetworkOrderSerializer.WriteUInt32(memory.Span, 0); + NetworkOrderSerializer.WriteUInt32(span, 0); return 4; } else { // Let's only write after the length header. - Memory slice = memory.Slice(4); + Span slice = span.Slice(4); int bytesWritten = 0; if (val is Dictionary dict) { @@ -496,7 +499,7 @@ public static int WriteTable(Memory memory, IDictionary va } } - NetworkOrderSerializer.WriteUInt32(memory.Span, (uint)bytesWritten); + NetworkOrderSerializer.WriteUInt32(span, (uint)bytesWritten); return 4 + bytesWritten; } } diff --git a/projects/Unit/TestBasicProperties.cs b/projects/Unit/TestBasicProperties.cs index 1e05af18da..619c2d29a8 100644 --- a/projects/Unit/TestBasicProperties.cs +++ b/projects/Unit/TestBasicProperties.cs @@ -105,13 +105,13 @@ public void TestNullableProperties_CanWrite( bool isMessageIdPresent = messageId != null; Assert.AreEqual(isMessageIdPresent, subject.IsMessageIdPresent()); - Memory memory = new byte[1024]; - var writer = new Impl.ContentHeaderPropertyWriter(memory); + Span span = new byte[1024]; + var writer = new Impl.ContentHeaderPropertyWriter(span); subject.WritePropertiesTo(ref writer); // Read from Stream var propertiesFromStream = new Framing.BasicProperties(); - var reader = new Impl.ContentHeaderPropertyReader(memory.Slice(0, writer.Offset)); + var reader = new Impl.ContentHeaderPropertyReader(span.Slice(0, writer.Offset)); propertiesFromStream.ReadPropertiesFrom(ref reader); Assert.AreEqual(clusterId, propertiesFromStream.ClusterId); @@ -139,13 +139,13 @@ public void TestProperties_ReplyTo([Values(null, "foo_1", "fanout://name/key")] string replyToAddress = result?.ToString(); Assert.AreEqual(isReplyToPresent, subject.IsReplyToPresent()); - Memory memory = new byte[1024]; - var writer = new Impl.ContentHeaderPropertyWriter(memory); + Span span = new byte[1024]; + var writer = new Impl.ContentHeaderPropertyWriter(span); subject.WritePropertiesTo(ref writer); // Read from Stream var propertiesFromStream = new Framing.BasicProperties(); - var reader = new Impl.ContentHeaderPropertyReader(memory.Slice(0, writer.Offset)); + var reader = new Impl.ContentHeaderPropertyReader(span.Slice(0, writer.Offset)); propertiesFromStream.ReadPropertiesFrom(ref reader); Assert.AreEqual(replyTo, propertiesFromStream.ReplyTo); diff --git a/projects/Unit/TestFieldTableFormatting.cs b/projects/Unit/TestFieldTableFormatting.cs index 1025adb7a4..5da8a035b3 100644 --- a/projects/Unit/TestFieldTableFormatting.cs +++ b/projects/Unit/TestFieldTableFormatting.cs @@ -139,7 +139,7 @@ [new string('A', TooLarge)] = null int bytesNeeded = WireFormatting.GetTableByteCount(t); byte[] bytes = new byte[bytesNeeded]; - Assert.Throws(() => WireFormatting.WriteTable(bytes, t)); + Assert.Throws(() => WireFormatting.WriteTable(bytes, t)); } [Test]