diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index 0ad507ac6bc725..9a9f8acd979e20 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -55,28 +55,15 @@ u64 mptcp_wnd_end(const struct mptcp_sock *msk) return READ_ONCE(msk->wnd_end); } -static bool mptcp_is_tcpsk(struct sock *sk) +static const struct proto_ops *mptcp_get_fallback_tcp_ops(const struct sock *sk) { - struct socket *sock = sk->sk_socket; - - if (unlikely(sk->sk_prot == &tcp_prot)) { - /* we are being invoked after mptcp_accept() has - * accepted a non-mp-capable flow: sk is a tcp_sk, - * not an mptcp one. - * - * Hand the socket over to tcp so all further socket ops - * bypass mptcp. - */ - WRITE_ONCE(sock->ops, &inet_stream_ops); - return true; + if (unlikely(sk->sk_prot == &tcp_prot)) + return &inet_stream_ops; #if IS_ENABLED(CONFIG_MPTCP_IPV6) - } else if (unlikely(sk->sk_prot == &tcpv6_prot)) { - WRITE_ONCE(sock->ops, &inet6_stream_ops); - return true; + else if (unlikely(sk->sk_prot == &tcpv6_prot)) + return &inet6_stream_ops; #endif - } - - return false; + return NULL; } static int __mptcp_socket_create(struct mptcp_sock *msk) @@ -3832,6 +3819,7 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, int flags, bool kern) { struct mptcp_sock *msk = mptcp_sk(sock->sk); + const struct proto_ops *fallback_ops; struct sock *ssk, *newsk; int err; @@ -3851,7 +3839,8 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, lock_sock(newsk); __inet_accept(sock, newsock, newsk); - if (!mptcp_is_tcpsk(newsock->sk)) { + fallback_ops = mptcp_get_fallback_tcp_ops(newsock->sk); + if (!fallback_ops) { struct mptcp_sock *msk = mptcp_sk(newsk); struct mptcp_subflow_context *subflow; @@ -3877,6 +3866,15 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, if (unlikely(list_is_singular(&msk->conn_list))) inet_sk_state_store(newsk, TCP_CLOSE); } + } else { + /* we are being invoked after mptcp_accept() has + * accepted a non-mp-capable flow: sk is a tcp_sk, + * not an mptcp one. + * + * Hand the socket over to tcp so all further socket ops + * bypass mptcp. + */ + WRITE_ONCE(newsock->sk->sk_socket->ops, fallback_ops); } release_sock(newsk);