Index | Thread | Search

From:
Vitaliy Makkoveev <mvs@openbsd.org>
Subject:
Re: Don't take solock() in soreceive() for tcp(4) sockets
To:
tech@openbsd.org
Cc:
Hrvoje Popovski <hrvoje@srce.hr>
Date:
Tue, 16 Apr 2024 14:33:34 +0300

Download raw body.

Thread
On Tue, Apr 16, 2024 at 01:36:40PM +0300, Vitaliy Makkoveev wrote:
> Make the reception locking fine-grained too. pru_rcvd() requires
> solock(), so take solock() around that call. I don't know whether
> checking the connection state SS_ bits without the lock is safe, so
> take `sb_mtx' around the corresponding `so_state' modifications. If
> this serialization turns out to be unnecessary, it can be reworked
> in a later diff.
> 
> So, reception for all inet sockets is now moved out of the netlock.
> 

Sorry, this is the correct diff.

Index: sys/kern/uipc_socket.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_socket.c,v
retrieving revision 1.330
diff -u -p -r1.330 uipc_socket.c
--- sys/kern/uipc_socket.c	15 Apr 2024 21:31:29 -0000	1.330
+++ sys/kern/uipc_socket.c	16 Apr 2024 11:32:17 -0000
@@ -157,12 +157,7 @@ soalloc(const struct protosw *prp, int w
 	switch (dp->dom_family) {
 	case AF_INET:
 	case AF_INET6:
-		switch (prp->pr_type) {
-		case SOCK_DGRAM:
-		case SOCK_RAW:
-			so->so_rcv.sb_flags |= SB_MTXLOCK | SB_OWNLOCK;
-			break;
-		}
+		so->so_rcv.sb_flags |= SB_MTXLOCK | SB_OWNLOCK;
 		break;
 	case AF_UNIX:
 		so->so_rcv.sb_flags |= SB_MTXLOCK;
@@ -348,20 +343,18 @@ sofree(struct socket *so, int keep_lock)
 #endif /* SOCKET_SPLICE */
 	sbrelease(so, &so->so_snd);
 
+	if (!keep_lock)
+		sounlock(so);
+
 	/*
-	 * Regardless on '_locked' postfix, must release solock() before
-	 * call sorflush_locked() for SB_OWNLOCK marked socket. Can't
-	 * release solock() and call sorflush() because solock() release
-	 * is unwanted for tcp(4) socket. 
+	 * The socket was previously unspliced. The unlocked `so_rcv'
+	 * cleanup is safe even with a concurrent somove() thread.
 	 */
 
-	if (so->so_rcv.sb_flags & SB_OWNLOCK)
-		sounlock(so);
-
-	sorflush_locked(so);
-
-	if (!((so->so_rcv.sb_flags & SB_OWNLOCK) || keep_lock))
-		sounlock(so);
+	if (so->so_proto->pr_flags & PR_RIGHTS &&
+	    so->so_proto->pr_domain->dom_dispose)
+		(*so->so_proto->pr_domain->dom_dispose)(so->so_rcv.sb_mb);
+	m_purge(so->so_rcv.sb_mb);
 
 #ifdef SOCKET_SPLICE
 	if (so->so_sp) {
@@ -1222,7 +1215,11 @@ dontblock:
 		SBLASTMBUFCHK(&so->so_rcv, "soreceive 4");
 		if (pr->pr_flags & PR_WANTRCVD) {
 			sb_mtx_unlock(&so->so_rcv);
+			if (!dosolock)
+				solock(so);
 			pru_rcvd(so);
+			if (!dosolock)
+				sounlock(so);
 			sb_mtx_lock(&so->so_rcv);
 		}
 	}
@@ -1358,17 +1355,9 @@ sosplice(struct socket *so, int fd, off_
 		membar_consumer();
 	}
 
-	if (so->so_rcv.sb_flags & SB_OWNLOCK) {
-		if ((error = sblock(so, &so->so_rcv, SBL_WAIT)) != 0)
-			return (error);
-		solock(so);
-	} else {
-		solock(so);
-		if ((error = sblock(so, &so->so_rcv, SBL_WAIT)) != 0) {
-			sounlock(so);
-			return (error);
-		}
-	}
+	if ((error = sblock(so, &so->so_rcv, SBL_WAIT)) != 0)
+		return (error);
+	solock(so);
 
 	if (so->so_options & SO_ACCEPTCONN) {
 		error = EOPNOTSUPP;
@@ -1446,13 +1435,8 @@ sosplice(struct socket *so, int fd, off_
  release:
 	sbunlock(sosp, &sosp->so_snd);
  out:
-	if (so->so_rcv.sb_flags & SB_OWNLOCK) {
-		sounlock(so);
-		sbunlock(so, &so->so_rcv);
-	} else {
-		sbunlock(so, &so->so_rcv);
-		sounlock(so);
-	}
+	sounlock(so);
+	sbunlock(so, &so->so_rcv);
 
 	if (fp)
 		FRELE(fp, curproc);
Index: sys/kern/uipc_socket2.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_socket2.c,v
retrieving revision 1.149
diff -u -p -r1.149 uipc_socket2.c
--- sys/kern/uipc_socket2.c	11 Apr 2024 13:32:51 -0000	1.149
+++ sys/kern/uipc_socket2.c	16 Apr 2024 11:32:17 -0000
@@ -86,8 +86,10 @@ void
 soisconnecting(struct socket *so)
 {
 	soassertlocked(so);
+	mtx_enter(&so->so_rcv.sb_mtx);
 	so->so_state &= ~(SS_ISCONNECTED|SS_ISDISCONNECTING);
 	so->so_state |= SS_ISCONNECTING;
+	mtx_leave(&so->so_rcv.sb_mtx);
 }
 
 void
@@ -96,8 +98,10 @@ soisconnected(struct socket *so)
 	struct socket *head = so->so_head;
 
 	soassertlocked(so);
+	mtx_enter(&so->so_rcv.sb_mtx);
 	so->so_state &= ~(SS_ISCONNECTING|SS_ISDISCONNECTING);
 	so->so_state |= SS_ISCONNECTED;
+	mtx_leave(&so->so_rcv.sb_mtx);
 
 	if (head != NULL && so->so_onq == &head->so_q0) {
 		int persocket = solock_persocket(so);
@@ -140,9 +144,9 @@ void
 soisdisconnecting(struct socket *so)
 {
 	soassertlocked(so);
+	mtx_enter(&so->so_rcv.sb_mtx);
 	so->so_state &= ~SS_ISCONNECTING;
 	so->so_state |= SS_ISDISCONNECTING;
-	mtx_enter(&so->so_rcv.sb_mtx);
 	so->so_rcv.sb_state |= SS_CANTRCVMORE;
 	mtx_leave(&so->so_rcv.sb_mtx);
 	so->so_snd.sb_state |= SS_CANTSENDMORE;
@@ -155,9 +159,9 @@ void
 soisdisconnected(struct socket *so)
 {
 	soassertlocked(so);
+	mtx_enter(&so->so_rcv.sb_mtx);
 	so->so_state &= ~(SS_ISCONNECTING|SS_ISCONNECTED|SS_ISDISCONNECTING);
 	so->so_state |= SS_ISDISCONNECTED;
-	mtx_enter(&so->so_rcv.sb_mtx);
 	so->so_rcv.sb_state |= SS_CANTRCVMORE;
 	mtx_leave(&so->so_rcv.sb_mtx);
 	so->so_snd.sb_state |= SS_CANTSENDMORE;
@@ -874,8 +878,7 @@ sbappend(struct socket *so, struct sockb
 void
 sbappendstream(struct socket *so, struct sockbuf *sb, struct mbuf *m)
 {
-	KASSERT(sb == &so->so_rcv || sb == &so->so_snd);
-	soassertlocked(so);
+	sbmtxassertlocked(so, sb);
 	KDASSERT(m->m_nextpkt == NULL);
 	KASSERT(sb->sb_mb == sb->sb_lastrecord);
 
Index: sys/netinet/tcp_input.c
===================================================================
RCS file: /cvs/src/sys/netinet/tcp_input.c,v
retrieving revision 1.404
diff -u -p -r1.404 tcp_input.c
--- sys/netinet/tcp_input.c	13 Apr 2024 23:44:11 -0000	1.404
+++ sys/netinet/tcp_input.c	16 Apr 2024 11:32:17 -0000
@@ -337,10 +337,12 @@ tcp_flush_queue(struct tcpcb *tp)
 		nq = TAILQ_NEXT(q, tcpqe_q);
 		TAILQ_REMOVE(&tp->t_segq, q, tcpqe_q);
 		ND6_HINT(tp);
+		mtx_enter(&so->so_rcv.sb_mtx);
 		if (so->so_rcv.sb_state & SS_CANTRCVMORE)
 			m_freem(q->tcpqe_m);
 		else
 			sbappendstream(so, &so->so_rcv, q->tcpqe_m);
+		mtx_leave(&so->so_rcv.sb_mtx);
 		pool_put(&tcpqe_pool, q);
 		q = nq;
 	} while (q != NULL && q->tcpqe_tcp->th_seq == tp->rcv_nxt);
@@ -380,6 +382,7 @@ tcp_input(struct mbuf **mp, int *offp, i
 	tcp_seq iss, *reuse = NULL;
 	uint64_t now;
 	u_long tiwin;
+	long win;
 	struct tcp_opt_info opti;
 	struct tcphdr *th;
 #ifdef INET6
@@ -939,6 +942,10 @@ findpcb:
 			tp->ts_recent = opti.ts_val;
 		}
 
+		mtx_enter(&so->so_rcv.sb_mtx);
+		win = sbspace(so, &so->so_rcv);
+		mtx_leave(&so->so_rcv.sb_mtx);
+
 		if (tlen == 0) {
 			if (SEQ_GT(th->th_ack, tp->snd_una) &&
 			    SEQ_LEQ(th->th_ack, tp->snd_max) &&
@@ -1018,8 +1025,7 @@ findpcb:
 				return IPPROTO_DONE;
 			}
 		} else if (th->th_ack == tp->snd_una &&
-		    TAILQ_EMPTY(&tp->t_segq) &&
-		    tlen <= sbspace(so, &so->so_rcv)) {
+		    TAILQ_EMPTY(&tp->t_segq) && tlen <= win) {
 			/*
 			 * This is a pure, in-sequence data packet
 			 * with nothing on the reassembly queue and
@@ -1042,6 +1048,7 @@ findpcb:
 			 * Drop TCP, IP headers and TCP options then add data
 			 * to socket buffer.
 			 */
+			mtx_enter(&so->so_rcv.sb_mtx);
 			if (so->so_rcv.sb_state & SS_CANTRCVMORE)
 				m_freem(m);
 			else {
@@ -1057,6 +1064,7 @@ findpcb:
 				m_adj(m, iphlen + off);
 				sbappendstream(so, &so->so_rcv, m);
 			}
+			mtx_leave(&so->so_rcv.sb_mtx);
 			tp->t_flags |= TF_BLOCKOUTPUT;
 			sorwakeup(so);
 			tp->t_flags &= ~TF_BLOCKOUTPUT;
@@ -1081,7 +1089,10 @@ findpcb:
 	{
 		int win;
 
+		mtx_enter(&so->so_rcv.sb_mtx);
 		win = sbspace(so, &so->so_rcv);
+		mtx_leave(&so->so_rcv.sb_mtx);
+
 		if (win < 0)
 			win = 0;
 		tp->rcv_wnd = imax(win, (int)(tp->rcv_adv - tp->rcv_nxt));
@@ -1900,10 +1911,12 @@ step6:
 		 */
 		if (SEQ_GT(th->th_seq+th->th_urp, tp->rcv_up)) {
 			tp->rcv_up = th->th_seq + th->th_urp;
+			mtx_enter(&so->so_rcv.sb_mtx);
 			so->so_oobmark = so->so_rcv.sb_cc +
 			    (tp->rcv_up - tp->rcv_nxt) - 1;
 			if (so->so_oobmark == 0)
 				so->so_rcv.sb_state |= SS_RCVATMARK;
+			mtx_leave(&so->so_rcv.sb_mtx);
 			sohasoutofband(so);
 			tp->t_oobflags &= ~(TCPOOB_HAVEDATA | TCPOOB_HADDATA);
 		}
@@ -1946,12 +1959,14 @@ dodata:							/* XXX */
 			tiflags = th->th_flags & TH_FIN;
 			tcpstat_pkt(tcps_rcvpack, tcps_rcvbyte, tlen);
 			ND6_HINT(tp);
+			mtx_enter(&so->so_rcv.sb_mtx);
 			if (so->so_rcv.sb_state & SS_CANTRCVMORE)
 				m_freem(m);
 			else {
 				m_adj(m, hdroptlen);
 				sbappendstream(so, &so->so_rcv, m);
 			}
+			mtx_leave(&so->so_rcv.sb_mtx);
 			tp->t_flags |= TF_BLOCKOUTPUT;
 			sorwakeup(so);
 			tp->t_flags &= ~TF_BLOCKOUTPUT;
@@ -3000,6 +3015,7 @@ tcp_mss_update(struct tcpcb *tp)
 		(void)sbreserve(so, &so->so_snd, bufsize);
 	}
 
+	mtx_enter(&so->so_rcv.sb_mtx);
 	bufsize = so->so_rcv.sb_hiwat;
 	if (bufsize > mss) {
 		bufsize = roundup(bufsize, mss);
@@ -3007,6 +3023,7 @@ tcp_mss_update(struct tcpcb *tp)
 			bufsize = sb_max;
 		(void)sbreserve(so, &so->so_rcv, bufsize);
 	}
+	mtx_leave(&so->so_rcv.sb_mtx);
 
 }
 
@@ -3784,7 +3801,10 @@ syn_cache_add(struct sockaddr *src, stru
 	/*
 	 * Initialize some local state.
 	 */
+	mtx_enter(&so->so_rcv.sb_mtx);
 	win = sbspace(so, &so->so_rcv);
+	mtx_leave(&so->so_rcv.sb_mtx);
+
 	if (win > TCP_MAXWIN)
 		win = TCP_MAXWIN;
 
Index: sys/netinet/tcp_output.c
===================================================================
RCS file: /cvs/src/sys/netinet/tcp_output.c,v
retrieving revision 1.143
diff -u -p -r1.143 tcp_output.c
--- sys/netinet/tcp_output.c	13 Feb 2024 12:22:09 -0000	1.143
+++ sys/netinet/tcp_output.c	16 Apr 2024 11:32:17 -0000
@@ -373,7 +373,9 @@ again:
 	if (off + len < so->so_snd.sb_cc)
 		flags &= ~TH_FIN;
 
+	mtx_enter(&so->so_rcv.sb_mtx);
 	win = sbspace(so, &so->so_rcv);
+	mtx_leave(&so->so_rcv.sb_mtx);
 
 	/*
 	 * Sender silly window avoidance.  If connection is idle
Index: sys/netinet/tcp_subr.c
===================================================================
RCS file: /cvs/src/sys/netinet/tcp_subr.c,v
retrieving revision 1.200
diff -u -p -r1.200 tcp_subr.c
--- sys/netinet/tcp_subr.c	12 Apr 2024 16:07:09 -0000	1.200
+++ sys/netinet/tcp_subr.c	16 Apr 2024 11:32:17 -0000
@@ -311,7 +311,9 @@ tcp_respond(struct tcpcb *tp, caddr_t te
 
 	if (tp) {
 		struct socket *so = tp->t_inpcb->inp_socket;
+		mtx_enter(&so->so_rcv.sb_mtx);
 		win = sbspace(so, &so->so_rcv);
+		mtx_leave(&so->so_rcv.sb_mtx);
 		/*
 		 * If this is called with an unconnected
 		 * socket/tp/pcb (tp->pf is 0), we lose.
Index: sys/netinet/tcp_usrreq.c
===================================================================
RCS file: /cvs/src/sys/netinet/tcp_usrreq.c,v
retrieving revision 1.231
diff -u -p -r1.231 tcp_usrreq.c
--- sys/netinet/tcp_usrreq.c	12 Apr 2024 16:07:09 -0000	1.231
+++ sys/netinet/tcp_usrreq.c	16 Apr 2024 11:32:17 -0000
@@ -907,13 +907,17 @@ tcp_rcvoob(struct socket *so, struct mbu
 	if ((error = tcp_sogetpcb(so, &inp, &tp)))
 		return (error);
 
+	mtx_enter(&so->so_rcv.sb_mtx);
 	if ((so->so_oobmark == 0 &&
 	    (so->so_rcv.sb_state & SS_RCVATMARK) == 0) ||
 	    so->so_options & SO_OOBINLINE ||
 	    tp->t_oobflags & TCPOOB_HADDATA) {
+		mtx_leave(&so->so_rcv.sb_mtx);
 		error = EINVAL;
 		goto out;
 	}
+	mtx_leave(&so->so_rcv.sb_mtx);
+
 	if ((tp->t_oobflags & TCPOOB_HAVEDATA) == 0) {
 		error = EWOULDBLOCK;
 		goto out;