From: Vitaliy Makkoveev <mvs@openbsd.org>
Subject: tcp(4): use per-sockbuf mutex to protect `so_rcv' socket buffer
To: Alexander Bluhm <bluhm@openbsd.org>, tech@openbsd.org
Date: Sun, 8 Dec 2024 01:04:42 +0300

Let's make the tcp(4) reception path a little bit more parallel. This
diff only unlocks the soreceive() path; the somove() path stays locked
exclusively. Also, the exclusive socket lock will still be taken before
each pru_rcvd() call in the soreceive() path.
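
A rough sketch of the resulting soreceive() locking discipline
(simplified, not the actual kernel code; the real loop also handles
MSG_PEEK, record boundaries and the error paths):

	mtx_enter(&so->so_rcv.sb_mtx);
	/* dequeue mbufs from `so_rcv' under the per-buffer mutex */
	mtx_leave(&so->so_rcv.sb_mtx);

	if (so->so_proto->pr_flags & PR_WANTRCVD) {
		/* tcp(4) may want to send a window update */
		solock(so);		/* exclusive socket lock */
		pru_rcvd(so);
		sounlock(so);
	}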

We always hold both the socket lock and `sb_mtx' while modifying the
SS_CANTRCVMORE bit, so the socket lock alone is enough to check it in
the protocol input path.
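
For reference, the write side of this invariant looks roughly like
the following (a sketch in the spirit of socantrcvmore(); details may
differ):

	solock(so);			/* or the caller holds it */
	mtx_enter(&so->so_rcv.sb_mtx);
	so->so_rcv.sb_state |= SS_CANTRCVMORE;
	mtx_leave(&so->so_rcv.sb_mtx);
	sounlock(so);

Since both locks are held across the transition, a reader holding
either one sees a consistent value, and tcp_input() can keep checking
the bit under the socket lock alone.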

To keep this diff small, I left the sbmtx*() wrappers as is.
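
For reference, the assertion wrapper used by sbappendstream() below
behaves roughly like this (a sketch, assuming SB_MTXLOCK selects which
lock protects the buffer):

	void
	sbmtxassertlocked(struct socket *so, struct sockbuf *sb)
	{
		if (sb->sb_flags & SB_MTXLOCK)
			MUTEX_ASSERT_LOCKED(&sb->sb_mtx);
		else
			soassertlocked(so);
	}

With SB_MTXLOCK now set on `so_rcv' of SOCK_STREAM sockets, the callers
in tcp_input() must hold `sb_mtx', which is what the new
mtx_enter()/mtx_leave() pairs provide.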

Index: sys/kern/uipc_socket.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_socket.c,v
diff -u -p -r1.345 uipc_socket.c
--- sys/kern/uipc_socket.c	8 Nov 2024 21:47:03 -0000	1.345
+++ sys/kern/uipc_socket.c	7 Dec 2024 21:55:36 -0000
@@ -160,6 +160,8 @@ soalloc(const struct protosw *prp, int w
 		case SOCK_RAW:
 		case SOCK_DGRAM:
 			so->so_snd.sb_flags |= SB_MTXLOCK;
+			/* FALLTHROUGH */
+		case SOCK_STREAM:
 			so->so_rcv.sb_flags |= SB_MTXLOCK;
 			break;
 		}
Index: sys/kern/uipc_socket2.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_socket2.c,v
diff -u -p -r1.159 uipc_socket2.c
--- sys/kern/uipc_socket2.c	6 Nov 2024 14:37:45 -0000	1.159
+++ sys/kern/uipc_socket2.c	7 Dec 2024 21:55:36 -0000
@@ -839,8 +839,7 @@ sbappend(struct socket *so, struct sockb
 void
 sbappendstream(struct socket *so, struct sockbuf *sb, struct mbuf *m)
 {
-	KASSERT(sb == &so->so_rcv || sb == &so->so_snd);
-	soassertlocked(so);
+	sbmtxassertlocked(so, sb);
 	KDASSERT(m->m_nextpkt == NULL);
 	KASSERT(sb->sb_mb == sb->sb_lastrecord);
 
Index: sys/netinet/tcp_input.c
===================================================================
RCS file: /cvs/src/sys/netinet/tcp_input.c,v
diff -u -p -r1.408 tcp_input.c
--- sys/netinet/tcp_input.c	8 Nov 2024 21:40:39 -0000	1.408
+++ sys/netinet/tcp_input.c	7 Dec 2024 21:55:36 -0000
@@ -334,8 +334,11 @@ tcp_flush_queue(struct tcpcb *tp)
 		ND6_HINT(tp);
 		if (so->so_rcv.sb_state & SS_CANTRCVMORE)
 			m_freem(q->tcpqe_m);
-		else
+		else {
+			mtx_enter(&so->so_rcv.sb_mtx);
 			sbappendstream(so, &so->so_rcv, q->tcpqe_m);
+			mtx_leave(&so->so_rcv.sb_mtx);
+		}
 		pool_put(&tcpqe_pool, q);
 		q = nq;
 	} while (q != NULL && q->tcpqe_tcp->th_seq == tp->rcv_nxt);
@@ -1051,7 +1054,9 @@ findpcb:
 				} else
 					tp->rfbuf_cnt += tlen;
 				m_adj(m, iphlen + off);
+				mtx_enter(&so->so_rcv.sb_mtx);
 				sbappendstream(so, &so->so_rcv, m);
+				mtx_leave(&so->so_rcv.sb_mtx);
 			}
 			tp->t_flags |= TF_BLOCKOUTPUT;
 			sorwakeup(so);
@@ -1869,13 +1874,19 @@ step6:
 	 */
 	if ((tiflags & TH_URG) && th->th_urp &&
 	    TCPS_HAVERCVDFIN(tp->t_state) == 0) {
+		u_long usage;
+
 		/*
 		 * This is a kludge, but if we receive and accept
 		 * random urgent pointers, we'll crash in
 		 * soreceive.  It's hard to imagine someone
 		 * actually wanting to send this much urgent data.
 		 */
-		if (th->th_urp + so->so_rcv.sb_cc > sb_max) {
+		mtx_enter(&so->so_rcv.sb_mtx);
+		usage = th->th_urp + so->so_rcv.sb_cc;
+		mtx_leave(&so->so_rcv.sb_mtx);
+
+		if (usage > sb_max) {
 			th->th_urp = 0;			/* XXX */
 			tiflags &= ~TH_URG;		/* XXX */
 			goto dodata;			/* XXX */
@@ -1948,7 +1959,9 @@ dodata:							/* XXX */
 				m_freem(m);
 			else {
 				m_adj(m, hdroptlen);
+				mtx_enter(&so->so_rcv.sb_mtx);
 				sbappendstream(so, &so->so_rcv, m);
+				mtx_leave(&so->so_rcv.sb_mtx);
 			}
 			tp->t_flags |= TF_BLOCKOUTPUT;
 			sorwakeup(so);
@@ -2998,6 +3011,7 @@ tcp_mss_update(struct tcpcb *tp)
 		(void)sbreserve(so, &so->so_snd, bufsize);
 	}
 
+	mtx_enter(&so->so_rcv.sb_mtx);
 	bufsize = so->so_rcv.sb_hiwat;
 	if (bufsize > mss) {
 		bufsize = roundup(bufsize, mss);
@@ -3005,6 +3019,7 @@ tcp_mss_update(struct tcpcb *tp)
 			bufsize = sb_max;
 		(void)sbreserve(so, &so->so_rcv, bufsize);
 	}
+	mtx_leave(&so->so_rcv.sb_mtx);
 
 }
 
Index: sys/netinet/tcp_output.c
===================================================================
RCS file: /cvs/src/sys/netinet/tcp_output.c,v
diff -u -p -r1.145 tcp_output.c
--- sys/netinet/tcp_output.c	14 May 2024 09:39:02 -0000	1.145
+++ sys/netinet/tcp_output.c	7 Dec 2024 21:55:36 -0000
@@ -195,7 +195,7 @@ int
 tcp_output(struct tcpcb *tp)
 {
 	struct socket *so = tp->t_inpcb->inp_socket;
-	long len, win, txmaxseg;
+	long len, win, sb_hiwat, txmaxseg;
 	int off, flags, error;
 	struct mbuf *m;
 	struct tcphdr *th;
@@ -373,7 +373,10 @@ again:
 	if (off + len < so->so_snd.sb_cc)
 		flags &= ~TH_FIN;
 
-	win = sbspace(so, &so->so_rcv);
+	mtx_enter(&so->so_rcv.sb_mtx);
+	win = sbspace_locked(so, &so->so_rcv);
+	sb_hiwat = (long) so->so_rcv.sb_hiwat;
+	mtx_leave(&so->so_rcv.sb_mtx);
 
 	/*
 	 * Sender silly window avoidance.  If connection is idle
@@ -420,7 +423,7 @@ again:
 
 		if (adv >= (long) (2 * tp->t_maxseg))
 			goto send;
-		if (2 * adv >= (long) so->so_rcv.sb_hiwat)
+		if (2 * adv >= sb_hiwat)
 			goto send;
 	}
 
@@ -854,7 +857,7 @@ send:
 	 * Calculate receive window.  Don't shrink window,
 	 * but avoid silly window syndrome.
 	 */
-	if (win < (long)(so->so_rcv.sb_hiwat / 4) && win < (long)tp->t_maxseg)
+	if (win < (sb_hiwat / 4) && win < (long)tp->t_maxseg)
 		win = 0;
 	if (win > (long)TCP_MAXWIN << tp->rcv_scale)
 		win = (long)TCP_MAXWIN << tp->rcv_scale;
Index: sys/netinet/tcp_usrreq.c
===================================================================
RCS file: /cvs/src/sys/netinet/tcp_usrreq.c,v
diff -u -p -r1.232 tcp_usrreq.c
--- sys/netinet/tcp_usrreq.c	8 Nov 2024 15:46:55 -0000	1.232
+++ sys/netinet/tcp_usrreq.c	7 Dec 2024 21:55:36 -0000
@@ -296,10 +296,12 @@ tcp_fill_info(struct tcpcb *tp, struct s
 	ti->tcpi_rfbuf_cnt = tp->rfbuf_cnt;
 	ti->tcpi_rfbuf_ts = (now - tp->rfbuf_ts) * t;
 
+	mtx_enter(&so->so_rcv.sb_mtx);
 	ti->tcpi_so_rcv_sb_cc = so->so_rcv.sb_cc;
 	ti->tcpi_so_rcv_sb_hiwat = so->so_rcv.sb_hiwat;
 	ti->tcpi_so_rcv_sb_lowat = so->so_rcv.sb_lowat;
 	ti->tcpi_so_rcv_sb_wat = so->so_rcv.sb_wat;
+	mtx_leave(&so->so_rcv.sb_mtx);
 	ti->tcpi_so_snd_sb_cc = so->so_snd.sb_cc;
 	ti->tcpi_so_snd_sb_hiwat = so->so_snd.sb_hiwat;
 	ti->tcpi_so_snd_sb_lowat = so->so_snd.sb_lowat;
@@ -1044,7 +1046,9 @@ tcp_dodisconnect(struct tcpcb *tp)
 		tp = tcp_drop(tp, 0);
 	else {
 		soisdisconnecting(so);
+		mtx_enter(&so->so_rcv.sb_mtx);
 		sbflush(so, &so->so_rcv);
+		mtx_leave(&so->so_rcv.sb_mtx);
 		tp = tcp_usrclosed(tp);
 		if (tp)
 			(void) tcp_output(tp);
@@ -1556,7 +1560,11 @@ void
 tcp_update_rcvspace(struct tcpcb *tp)
 {
 	struct socket *so = tp->t_inpcb->inp_socket;
-	u_long nmax = so->so_rcv.sb_hiwat;
+	u_long nmax;
+
+	mtx_enter(&so->so_rcv.sb_mtx);
+
+	nmax = so->so_rcv.sb_hiwat;
 
 	if (sbchecklowmem()) {
 		/* low on memory try to get rid of some */
@@ -1577,10 +1585,11 @@ tcp_update_rcvspace(struct tcpcb *tp)
 	    nmax < so->so_snd.sb_lowat)
 		nmax = so->so_snd.sb_lowat;
 
-	if (nmax == so->so_rcv.sb_hiwat)
-		return;
+	if (nmax != so->so_rcv.sb_hiwat) {
+		/* round to MSS boundary */
+		nmax = roundup(nmax, tp->t_maxseg);
+		sbreserve(so, &so->so_rcv, nmax);
+	}
 
-	/* round to MSS boundary */
-	nmax = roundup(nmax, tp->t_maxseg);
-	sbreserve(so, &so->so_rcv, nmax);
+	mtx_leave(&so->so_rcv.sb_mtx);
 }