Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Protect against async race conditions when parsing incoming byte chunks. #373

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions lib/src/message_window.dart
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class MessageFramer {

int? _type;
int _expectedLength = 0;
bool _isProcessing = false;
Queue<Uint8List>? _pendingChunks;

bool get _hasReadHeader => _type != null;
bool get _canReadHeader => _reader.remainingLength >= _headerByteSize;
Expand All @@ -52,6 +54,31 @@ class MessageFramer {
_expectedLength == 0 || _expectedLength <= _reader.remainingLength;

Future<void> addBytes(Uint8List bytes) async {
// Since the message framing became async, the message processing may overlap
// with the chunk arrival, and we may need to defend against race conditions
// that would defeat the expected vs remaining length checks.
// This solution queues up the pending bytes, while the alternative would require
// us to pre-frame the bytes into fixed message chunks.
if (_isProcessing) {
_pendingChunks ??= Queue<Uint8List>();
_pendingChunks!.add(bytes);
return;
}
_isProcessing = true;
try {
await _addBytes(bytes);
if (_pendingChunks != null) {
while (_pendingChunks!.isNotEmpty) {
await _addBytes(_pendingChunks!.removeFirst());
}
_pendingChunks = null;
}
} finally {
_isProcessing = false;
}
}

Future<void> _addBytes(Uint8List bytes) async {
_reader.add(bytes);

while (true) {
Expand Down
14 changes: 6 additions & 8 deletions lib/src/v3/connection.dart
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ class PgConnectionImplementation extends _PgSessionBase implements Connection {

/// Whether [_channel] is backed by a TLS connection.
final bool _channelIsSecure;
late final StreamSubscription<Message> _serverMessages;
late final StreamSubscription _serverMessages;
bool _isClosing = false;

_PendingOperation? _pending;
Expand Down Expand Up @@ -403,7 +403,8 @@ class PgConnectionImplementation extends _PgSessionBase implements Connection {
required this.info,
}) : _relationTracker = relationTracker {
_serverMessages = _channel.stream
.listen(_handleMessage, onDone: _socketClosed, onError: (e, s) {
.asyncMap(_handleMessage)
.listen((_) {}, onDone: _socketClosed, onError: (e, s) {
_close(
true,
PgException('Socket error: $e'),
Expand Down Expand Up @@ -439,10 +440,7 @@ class PgConnectionImplementation extends _PgSessionBase implements Connection {
}

Future<void> _handleMessage(Message message) async {
_serverMessages.pause();
try {
message as ServerMessage;

if (message is ServerMessage) {
if (message is XLogDataLogicalMessage) {
final embedded = message.message;
if (embedded is RelationMessage) {
Expand Down Expand Up @@ -472,8 +470,8 @@ class PgConnectionImplementation extends _PgSessionBase implements Connection {
} else if (_pending != null) {
await _pending!.handleMessage(message);
}
} finally {
_serverMessages.resume();
} else {
throw UnimplementedError('Unknown message: $message');
}
}

Expand Down
36 changes: 12 additions & 24 deletions lib/src/v3/protocol.dart
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,12 @@ StreamTransformer<Uint8List, ServerMessage> _readMessages(
return Stream.multi((listener) {
final framer = MessageFramer(codecContext);

var paused = false;

void emitFinishedMessages() {
while (framer.hasMessage) {
if (listener.isClosed) {
break;
}
listener.addSync(framer.popMessage());

if (paused) break;
}
}

Expand All @@ -77,27 +76,16 @@ StreamTransformer<Uint8List, ServerMessage> _readMessages(
// Don't cancel this subscription on error! If the listener wants that,
// they'll unsubscribe in time after we forward it synchronously.
final rawSubscription =
rawStream.listen(handleChunk, cancelOnError: false)
rawStream.asyncMap(handleChunk).listen((_) {}, cancelOnError: false)
..onError(listener.addErrorSync)
..onDone(listener.closeSync);

listener.onPause = () {
paused = true;
rawSubscription.pause();
};

listener.onResume = () {
paused = false;
emitFinishedMessages();

if (!paused) {
rawSubscription.resume();
}
};

listener.onCancel = () {
paused = true;
rawSubscription.cancel();
..onDone(() async {
await framer.addBytes(Uint8List(0));
emitFinishedMessages();
await listener.close();
});

listener.onCancel = () async {
await rawSubscription.cancel();
};
});
});
Expand Down
Loading