Index | Thread | Search

From:
Vitaliy Makkoveev <mvs@openbsd.org>
Subject:
Don't take solock() in soreceive() for tcp(4) sockets
To:
tech@openbsd.org
Cc:
Hrvoje Popovski <hrvoje@srce.hr>
Date:
Tue, 16 Apr 2024 13:36:40 +0300

Download raw body.

Thread
Make the reception locking fine-grained too. pru_rcvd() requires
solock(), so take it around the call. I don't know whether the unlocked
check of the connection state SS_ bits is safe, so take `sb_mtx' around
the corresponding `so_state' modifications. If this serialization is not
required, it could be reworked with another diff later.

With this, reception for all inet sockets is moved out of the netlock.

Index: sys/kern/uipc_socket.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_socket.c,v
retrieving revision 1.330
diff -u -p -r1.330 uipc_socket.c
--- sys/kern/uipc_socket.c	15 Apr 2024 21:31:29 -0000	1.330
+++ sys/kern/uipc_socket.c	16 Apr 2024 10:20:11 -0000
@@ -157,12 +157,7 @@ soalloc(const struct protosw *prp, int w
 	switch (dp->dom_family) {
 	case AF_INET:
 	case AF_INET6:
-		switch (prp->pr_type) {
-		case SOCK_DGRAM:
-		case SOCK_RAW:
-			so->so_rcv.sb_flags |= SB_MTXLOCK | SB_OWNLOCK;
-			break;
-		}
+		so->so_rcv.sb_flags |= SB_MTXLOCK | SB_OWNLOCK;
 		break;
 	case AF_UNIX:
 		so->so_rcv.sb_flags |= SB_MTXLOCK;
@@ -348,20 +343,18 @@ sofree(struct socket *so, int keep_lock)
 #endif /* SOCKET_SPLICE */
 	sbrelease(so, &so->so_snd);
 
+	if (!keep_lock)
+		sounlock(so);
+
 	/*
-	 * Regardless on '_locked' postfix, must release solock() before
-	 * call sorflush_locked() for SB_OWNLOCK marked socket. Can't
-	 * release solock() and call sorflush() because solock() release
-	 * is unwanted for tcp(4) socket. 
+	 * The socket was previously unspliced. The unlocked `so_rcv'
+	 * cleanup is safe even if we have concurrent somove() thread.
 	 */
 
-	if (so->so_rcv.sb_flags & SB_OWNLOCK)
-		sounlock(so);
-
-	sorflush_locked(so);
-
-	if (!((so->so_rcv.sb_flags & SB_OWNLOCK) || keep_lock))
-		sounlock(so);
+	if (so->so_proto->pr_flags & PR_RIGHTS &&
+	    so->so_proto->pr_domain->dom_dispose)
+		(*so->so_proto->pr_domain->dom_dispose)(so->so_rcv.sb_mb);
+	m_purge(so->so_rcv.sb_mb);
 
 #ifdef SOCKET_SPLICE
 	if (so->so_sp) {
@@ -1222,7 +1215,11 @@ dontblock:
 		SBLASTMBUFCHK(&so->so_rcv, "soreceive 4");
 		if (pr->pr_flags & PR_WANTRCVD) {
 			sb_mtx_unlock(&so->so_rcv);
+			if (!dosolock)
+				solock(so);
 			pru_rcvd(so);
+			if (!dosolock)
+				sounlock(so);
 			sb_mtx_lock(&so->so_rcv);
 		}
 	}
@@ -1358,17 +1355,9 @@ sosplice(struct socket *so, int fd, off_
 		membar_consumer();
 	}
 
-	if (so->so_rcv.sb_flags & SB_OWNLOCK) {
-		if ((error = sblock(so, &so->so_rcv, SBL_WAIT)) != 0)
-			return (error);
-		solock(so);
-	} else {
-		solock(so);
-		if ((error = sblock(so, &so->so_rcv, SBL_WAIT)) != 0) {
-			sounlock(so);
-			return (error);
-		}
-	}
+	if ((error = sblock(so, &so->so_rcv, SBL_WAIT)) != 0)
+		return (error);
+	solock(so);
 
 	if (so->so_options & SO_ACCEPTCONN) {
 		error = EOPNOTSUPP;
@@ -1446,13 +1435,8 @@ sosplice(struct socket *so, int fd, off_
  release:
 	sbunlock(sosp, &sosp->so_snd);
  out:
-	if (so->so_rcv.sb_flags & SB_OWNLOCK) {
-		sounlock(so);
-		sbunlock(so, &so->so_rcv);
-	} else {
-		sbunlock(so, &so->so_rcv);
-		sounlock(so);
-	}
+	sounlock(so);
+	sbunlock(so, &so->so_rcv);
 
 	if (fp)
 		FRELE(fp, curproc);
Index: sys/kern/uipc_socket2.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_socket2.c,v
retrieving revision 1.149
diff -u -p -r1.149 uipc_socket2.c
--- sys/kern/uipc_socket2.c	11 Apr 2024 13:32:51 -0000	1.149
+++ sys/kern/uipc_socket2.c	16 Apr 2024 10:20:11 -0000
@@ -86,8 +86,10 @@ void
 soisconnecting(struct socket *so)
 {
 	soassertlocked(so);
+	mtx_enter(&so->so_rcv.sb_mtx);
 	so->so_state &= ~(SS_ISCONNECTED|SS_ISDISCONNECTING);
 	so->so_state |= SS_ISCONNECTING;
+	mtx_leave(&so->so_rcv.sb_mtx);
 }
 
 void
@@ -96,8 +98,10 @@ soisconnected(struct socket *so)
 	struct socket *head = so->so_head;
 
 	soassertlocked(so);
+	mtx_enter(&so->so_rcv.sb_mtx);
 	so->so_state &= ~(SS_ISCONNECTING|SS_ISDISCONNECTING);
 	so->so_state |= SS_ISCONNECTED;
+	mtx_leave(&so->so_rcv.sb_mtx);
 
 	if (head != NULL && so->so_onq == &head->so_q0) {
 		int persocket = solock_persocket(so);
@@ -140,9 +144,9 @@ void
 soisdisconnecting(struct socket *so)
 {
 	soassertlocked(so);
+	mtx_enter(&so->so_rcv.sb_mtx);
 	so->so_state &= ~SS_ISCONNECTING;
 	so->so_state |= SS_ISDISCONNECTING;
-	mtx_enter(&so->so_rcv.sb_mtx);
 	so->so_rcv.sb_state |= SS_CANTRCVMORE;
 	mtx_leave(&so->so_rcv.sb_mtx);
 	so->so_snd.sb_state |= SS_CANTSENDMORE;
@@ -155,9 +159,9 @@ void
 soisdisconnected(struct socket *so)
 {
 	soassertlocked(so);
+	mtx_enter(&so->so_rcv.sb_mtx);
 	so->so_state &= ~(SS_ISCONNECTING|SS_ISCONNECTED|SS_ISDISCONNECTING);
 	so->so_state |= SS_ISDISCONNECTED;
-	mtx_enter(&so->so_rcv.sb_mtx);
 	so->so_rcv.sb_state |= SS_CANTRCVMORE;
 	mtx_leave(&so->so_rcv.sb_mtx);
 	so->so_snd.sb_state |= SS_CANTSENDMORE;
@@ -874,8 +878,7 @@ sbappend(struct socket *so, struct sockb
 void
 sbappendstream(struct socket *so, struct sockbuf *sb, struct mbuf *m)
 {
-	KASSERT(sb == &so->so_rcv || sb == &so->so_snd);
-	soassertlocked(so);
+	sbmtxassertlocked(so, sb);
 	KDASSERT(m->m_nextpkt == NULL);
 	KASSERT(sb->sb_mb == sb->sb_lastrecord);
 
Index: sys/netinet/tcp_input.c
===================================================================
RCS file: /cvs/src/sys/netinet/tcp_input.c,v
retrieving revision 1.404
diff -u -p -r1.404 tcp_input.c
--- sys/netinet/tcp_input.c	13 Apr 2024 23:44:11 -0000	1.404
+++ sys/netinet/tcp_input.c	16 Apr 2024 10:20:11 -0000
@@ -337,10 +337,12 @@ tcp_flush_queue(struct tcpcb *tp)
 		nq = TAILQ_NEXT(q, tcpqe_q);
 		TAILQ_REMOVE(&tp->t_segq, q, tcpqe_q);
 		ND6_HINT(tp);
+		mtx_enter(&so->so_rcv.sb_mtx);
 		if (so->so_rcv.sb_state & SS_CANTRCVMORE)
 			m_freem(q->tcpqe_m);
 		else
 			sbappendstream(so, &so->so_rcv, q->tcpqe_m);
+		mtx_leave(&so->so_rcv.sb_mtx);
 		pool_put(&tcpqe_pool, q);
 		q = nq;
 	} while (q != NULL && q->tcpqe_tcp->th_seq == tp->rcv_nxt);
@@ -1042,6 +1044,7 @@ findpcb:
 			 * Drop TCP, IP headers and TCP options then add data
 			 * to socket buffer.
 			 */
+			mtx_enter(&so->so_rcv.sb_mtx);
 			if (so->so_rcv.sb_state & SS_CANTRCVMORE)
 				m_freem(m);
 			else {
@@ -1057,6 +1060,7 @@ findpcb:
 				m_adj(m, iphlen + off);
 				sbappendstream(so, &so->so_rcv, m);
 			}
+			mtx_leave(&so->so_rcv.sb_mtx);
 			tp->t_flags |= TF_BLOCKOUTPUT;
 			sorwakeup(so);
 			tp->t_flags &= ~TF_BLOCKOUTPUT;
@@ -1900,10 +1904,12 @@ step6:
 		 */
 		if (SEQ_GT(th->th_seq+th->th_urp, tp->rcv_up)) {
 			tp->rcv_up = th->th_seq + th->th_urp;
+			mtx_enter(&so->so_rcv.sb_mtx);
 			so->so_oobmark = so->so_rcv.sb_cc +
 			    (tp->rcv_up - tp->rcv_nxt) - 1;
 			if (so->so_oobmark == 0)
 				so->so_rcv.sb_state |= SS_RCVATMARK;
+			mtx_leave(&so->so_rcv.sb_mtx);
 			sohasoutofband(so);
 			tp->t_oobflags &= ~(TCPOOB_HAVEDATA | TCPOOB_HADDATA);
 		}
@@ -1946,12 +1952,14 @@ dodata:							/* XXX */
 			tiflags = th->th_flags & TH_FIN;
 			tcpstat_pkt(tcps_rcvpack, tcps_rcvbyte, tlen);
 			ND6_HINT(tp);
+			mtx_enter(&so->so_rcv.sb_mtx);
 			if (so->so_rcv.sb_state & SS_CANTRCVMORE)
 				m_freem(m);
 			else {
 				m_adj(m, hdroptlen);
 				sbappendstream(so, &so->so_rcv, m);
 			}
+			mtx_leave(&so->so_rcv.sb_mtx);
 			tp->t_flags |= TF_BLOCKOUTPUT;
 			sorwakeup(so);
 			tp->t_flags &= ~TF_BLOCKOUTPUT;
Index: sys/netinet/tcp_usrreq.c
===================================================================
RCS file: /cvs/src/sys/netinet/tcp_usrreq.c,v
retrieving revision 1.231
diff -u -p -r1.231 tcp_usrreq.c
--- sys/netinet/tcp_usrreq.c	12 Apr 2024 16:07:09 -0000	1.231
+++ sys/netinet/tcp_usrreq.c	16 Apr 2024 10:20:11 -0000
@@ -907,13 +907,17 @@ tcp_rcvoob(struct socket *so, struct mbu
 	if ((error = tcp_sogetpcb(so, &inp, &tp)))
 		return (error);
 
+	mtx_enter(&so->so_rcv.sb_mtx);
 	if ((so->so_oobmark == 0 &&
 	    (so->so_rcv.sb_state & SS_RCVATMARK) == 0) ||
 	    so->so_options & SO_OOBINLINE ||
 	    tp->t_oobflags & TCPOOB_HADDATA) {
+		mtx_leave(&so->so_rcv.sb_mtx);
 		error = EINVAL;
 		goto out;
 	}
+	mtx_leave(&so->so_rcv.sb_mtx);
+
 	if ((tp->t_oobflags & TCPOOB_HAVEDATA) == 0) {
 		error = EWOULDBLOCK;
 		goto out;