Index | Thread | Search

From:
Vitaliy Makkoveev <mvs@openbsd.org>
Subject:
tcp(4): use per-sockbuf mutex to protect `so_snd' socket buffer
To:
Alexander Bluhm <bluhm@openbsd.org>, tech@openbsd.org
Date:
Sat, 21 Dec 2024 08:16:05 +0300

Download raw body.

Thread
This diff only unlocks sosend() path, all the rest is still locked
exclusively.

Even in the tcp(4) case, sosend() only checks the `so_snd' free space and
sleeps if necessary; the actual buffer handling happens within the exclusively
locked PCB layer. In the PCB layer we don't need to protect read-only
access to `so_snd' because corresponding read-write access is serialized
by socket lock.

sosend() needs to be serialized with somove(), so sotask() takes
sblock() on `so_snd' of spliced socket. This discards previous
protection of tcp(4) spliced sockets where solock() was used to prevent
concurrent unsplicing. This scheme was used to avoid sleep in sofree().

However, sleep in sofree() is possible. We have two cases:

1. Socket was not yet accept(2)ed. Such sockets can't be accessed from
the userland and can't be spliced. sofree() could be called only from
PCB layer and don't need to sleep.

2. Socket was accepted. While called from PCB layer, sofree() ignores it
because SS_NOFDREF bit is not set. Socket remains spliced, but without
PCB. Sockets without PCB can't be accessed from PCB layer, only from
userland. So, the soclose()/sofree() thread needs to wait for the concurrent
soclose()/sofree() thread of the spliced socket, as is already done in the
udp(4) case. In such case it is safe to release solock() and sleep
within sofree().

Note, I intentionally left all "if (dosolock)" dances as is.

Index: sys/kern/uipc_socket.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_socket.c,v
diff -u -p -r1.347 uipc_socket.c
--- sys/kern/uipc_socket.c	19 Dec 2024 22:11:35 -0000	1.347
+++ sys/kern/uipc_socket.c	21 Dec 2024 04:12:02 -0000
@@ -153,27 +153,8 @@ soalloc(const struct protosw *prp, int w
 	TAILQ_INIT(&so->so_q0);
 	TAILQ_INIT(&so->so_q);
 
-	switch (dp->dom_family) {
-	case AF_INET:
-	case AF_INET6:
-		switch (prp->pr_type) {
-		case SOCK_RAW:
-		case SOCK_DGRAM:
-			so->so_snd.sb_flags |= SB_MTXLOCK;
-			/* FALLTHROUGH */
-		case SOCK_STREAM:
-			so->so_rcv.sb_flags |= SB_MTXLOCK;
-			break;
-		}
-		break;
-	case AF_KEY:
-	case AF_ROUTE:
-	case AF_UNIX:
-	case AF_FRAME:
-		so->so_snd.sb_flags |= SB_MTXLOCK;
-		so->so_rcv.sb_flags |= SB_MTXLOCK;
-		break;
-	}
+	so->so_snd.sb_flags |= SB_MTXLOCK;
+	so->so_rcv.sb_flags |= SB_MTXLOCK;
 
 	return (so);
 }
@@ -327,17 +308,13 @@ sofree(struct socket *so, int keep_lock)
 			sounlock(head);
 	}
 
-	switch (so->so_proto->pr_domain->dom_family) {
-	case AF_INET:
-	case AF_INET6:
-		if (so->so_proto->pr_type == SOCK_STREAM)
-			break;
-		/* FALLTHROUGH */
-	default:
+	if (!keep_lock) {
+		/*
+		 * sofree() was called from soclose(). Sleep is safe
+		 * even for tcp(4) sockets.
+		 */
 		sounlock(so);
 		refcnt_finalize(&so->so_refcnt, "sofinal");
-		solock(so);
-		break;
 	}
 
 	sigio_free(&so->so_sigio);
@@ -358,9 +335,6 @@ sofree(struct socket *so, int keep_lock)
 		(*so->so_proto->pr_domain->dom_dispose)(so->so_rcv.sb_mb);
 	m_purge(so->so_rcv.sb_mb);
 
-	if (!keep_lock)
-		sounlock(so);
-
 #ifdef SOCKET_SPLICE
 	if (so->so_sp) {
 		/* Reuse splice idle, sounsplice() has been called before. */
@@ -458,31 +432,6 @@ discard:
 	if (so->so_sp) {
 		struct socket *soback;
 
-		if (so->so_proto->pr_flags & PR_WANTRCVD) {
-			/*
-			 * Copy - Paste, but can't relock and sleep in
-			 * sofree() in tcp(4) case. That's why tcp(4)
-			 * still rely on solock() for splicing and
-			 * unsplicing.
-			 */
-
-			if (issplicedback(so)) {
-				int freeing = SOSP_FREEING_WRITE;
-
-				if (so->so_sp->ssp_soback == so)
-					freeing |= SOSP_FREEING_READ;
-				sounsplice(so->so_sp->ssp_soback, so, freeing);
-			}
-			if (isspliced(so)) {
-				int freeing = SOSP_FREEING_READ;
-
-				if (so == so->so_sp->ssp_socket)
-					freeing |= SOSP_FREEING_WRITE;
-				sounsplice(so, so->so_sp->ssp_socket, freeing);
-			}
-			goto free;
-		}
-
 		sounlock(so);
 		mtx_enter(&so->so_snd.sb_mtx);
 		/*
@@ -530,7 +479,6 @@ notsplicedback:
 
 		solock(so);
 	}
-free:
 #endif /* SOCKET_SPLICE */
 	/* sofree() calls sounlock(). */
 	sofree(so, 0);
@@ -1587,17 +1535,22 @@ sotask(void *arg)
 	 */
 
 	sblock(&so->so_rcv, SBL_WAIT | SBL_NOINTR);
-	if (sockstream)
-		solock(so);
-
 	if (so->so_rcv.sb_flags & SB_SPLICE) {
-		if (sockstream)
+		struct socket *sosp = so->so_sp->ssp_socket;
+
+		if (sockstream) {
+			sblock(&sosp->so_snd, SBL_WAIT | SBL_NOINTR);
+			solock(so);
 			doyield = 1;
+		}
+
 		somove(so, M_DONTWAIT);
-	}
 
-	if (sockstream)
-		sounlock(so);
+		if (sockstream) {
+			sounlock(so);
+			sbunlock(&sosp->so_snd);
+		}
+	}
 	sbunlock(&so->so_rcv);
 
 	if (doyield) {
Index: sys/netinet/tcp_input.c
===================================================================
RCS file: /cvs/src/sys/netinet/tcp_input.c,v
diff -u -p -r1.410 tcp_input.c
--- sys/netinet/tcp_input.c	20 Dec 2024 19:20:34 -0000	1.410
+++ sys/netinet/tcp_input.c	21 Dec 2024 04:12:02 -0000
@@ -957,7 +957,10 @@ findpcb:
 				    acked);
 				tp->t_rcvacktime = now;
 				ND6_HINT(tp);
+
+				mtx_enter(&so->so_snd.sb_mtx);
 				sbdrop(so, &so->so_snd, acked);
+				mtx_leave(&so->so_snd.sb_mtx);
 
 				/*
 				 * If we had a pending ICMP message that
@@ -1738,10 +1741,14 @@ trimthenstep6:
 				tp->snd_wnd -= so->so_snd.sb_cc;
 			else
 				tp->snd_wnd = 0;
+			mtx_enter(&so->so_snd.sb_mtx);
 			sbdrop(so, &so->so_snd, (int)so->so_snd.sb_cc);
+			mtx_leave(&so->so_snd.sb_mtx);
 			ourfinisacked = 1;
 		} else {
+			mtx_enter(&so->so_snd.sb_mtx);
 			sbdrop(so, &so->so_snd, acked);
+			mtx_leave(&so->so_snd.sb_mtx);
 			if (tp->snd_wnd > acked)
 				tp->snd_wnd -= acked;
 			else
@@ -2999,6 +3006,7 @@ tcp_mss_update(struct tcpcb *tp)
 	if (rt == NULL)
 		return;
 
+	mtx_enter(&so->so_snd.sb_mtx);
 	bufsize = so->so_snd.sb_hiwat;
 	if (bufsize < mss) {
 		mss = bufsize;
@@ -3010,6 +3018,7 @@ tcp_mss_update(struct tcpcb *tp)
 			bufsize = sb_max;
 		(void)sbreserve(so, &so->so_snd, bufsize);
 	}
+	mtx_leave(&so->so_snd.sb_mtx);
 
 	mtx_enter(&so->so_rcv.sb_mtx);
 	bufsize = so->so_rcv.sb_hiwat;
Index: sys/netinet/tcp_output.c
===================================================================
RCS file: /cvs/src/sys/netinet/tcp_output.c,v
diff -u -p -r1.146 tcp_output.c
--- sys/netinet/tcp_output.c	19 Dec 2024 22:11:35 -0000	1.146
+++ sys/netinet/tcp_output.c	21 Dec 2024 04:12:02 -0000
@@ -202,7 +202,7 @@ tcp_output(struct tcpcb *tp)
 	u_int32_t optbuf[howmany(MAX_TCPOPTLEN, sizeof(u_int32_t))];
 	u_char *opt = (u_char *)optbuf;
 	unsigned int optlen, hdrlen, packetlen;
-	int idle, sendalot = 0;
+	int doing_sosend, idle, sendalot = 0;
 	int i, sack_rxmit = 0;
 	struct sackhole *p;
 	uint64_t now;
@@ -227,6 +227,10 @@ tcp_output(struct tcpcb *tp)
 
 	now = tcp_now();
 
+	mtx_enter(&so->so_snd.sb_mtx);
+	doing_sosend=soissending(so);
+	mtx_leave(&so->so_snd.sb_mtx);
+
 	/*
 	 * Determine length of data that should be transmitted,
 	 * and flags that will be used.
@@ -243,7 +247,7 @@ tcp_output(struct tcpcb *tp)
 		tp->snd_cwnd = 2 * tp->t_maxseg;
 
 	/* remember 'idle' for next invocation of tcp_output */
-	if (idle && soissending(so)) {
+	if (idle && doing_sosend) {
 		tp->t_flags |= TF_LASTIDLE;
 		idle = 0;
 	} else
@@ -392,7 +396,7 @@ again:
 		if (len >= txmaxseg)
 			goto send;
 		if ((idle || (tp->t_flags & TF_NODELAY)) &&
-		    len + off >= so->so_snd.sb_cc && !soissending(so) &&
+		    len + off >= so->so_snd.sb_cc && !doing_sosend &&
 		    (tp->t_flags & TF_NOPUSH) == 0)
 			goto send;
 		if (tp->t_force)
@@ -725,7 +729,7 @@ send:
 		 * give data to the user when a buffer fills or
 		 * a PUSH comes in.)
 		 */
-		if (off + len == so->so_snd.sb_cc && !soissending(so))
+		if (off + len == so->so_snd.sb_cc && !doing_sosend)
 			flags |= TH_PUSH;
 		tp->t_sndtime = now;
 	} else {
Index: sys/netinet/tcp_usrreq.c
===================================================================
RCS file: /cvs/src/sys/netinet/tcp_usrreq.c,v
diff -u -p -r1.233 tcp_usrreq.c
--- sys/netinet/tcp_usrreq.c	19 Dec 2024 22:11:35 -0000	1.233
+++ sys/netinet/tcp_usrreq.c	21 Dec 2024 04:12:02 -0000
@@ -302,10 +302,12 @@ tcp_fill_info(struct tcpcb *tp, struct s
 	ti->tcpi_so_rcv_sb_lowat = so->so_rcv.sb_lowat;
 	ti->tcpi_so_rcv_sb_wat = so->so_rcv.sb_wat;
 	mtx_leave(&so->so_rcv.sb_mtx);
+	mtx_enter(&so->so_snd.sb_mtx);
 	ti->tcpi_so_snd_sb_cc = so->so_snd.sb_cc;
 	ti->tcpi_so_snd_sb_hiwat = so->so_snd.sb_hiwat;
 	ti->tcpi_so_snd_sb_lowat = so->so_snd.sb_lowat;
 	ti->tcpi_so_snd_sb_wat = so->so_snd.sb_wat;
+	mtx_leave(&so->so_snd.sb_mtx);
 
 	return 0;
 }
@@ -842,7 +844,9 @@ tcp_send(struct socket *so, struct mbuf 
 	if (so->so_options & SO_DEBUG)
 		ostate = tp->t_state;
 
+	mtx_enter(&so->so_snd.sb_mtx);
 	sbappendstream(so, &so->so_snd, m);
+	mtx_leave(&so->so_snd.sb_mtx);
 	m = NULL;
 
 	error = tcp_output(tp);
@@ -895,7 +899,9 @@ tcp_sense(struct socket *so, struct stat
 	if ((error = tcp_sogetpcb(so, &inp, &tp)))
 		return (error);
 
+	mtx_enter(&so->so_snd.sb_mtx);
 	ub->st_blksize = so->so_snd.sb_hiwat;
+	mtx_leave(&so->so_snd.sb_mtx);
 
 	if (so->so_options & SO_DEBUG)
 		tcp_trace(TA_USER, tp->t_state, tp, tp, NULL, PRU_SENSE, 0);
@@ -970,7 +976,9 @@ tcp_sendoob(struct socket *so, struct mb
 	 * of data past the urgent section.
 	 * Otherwise, snd_up should be one lower.
 	 */
+	mtx_enter(&so->so_snd.sb_mtx);
 	sbappendstream(so, &so->so_snd, m);
+	mtx_leave(&so->so_snd.sb_mtx);
 	m = NULL;
 	tp->snd_up = tp->snd_una + so->so_snd.sb_cc;
 	tp->t_force = 1;
@@ -1519,7 +1527,11 @@ void
 tcp_update_sndspace(struct tcpcb *tp)
 {
 	struct socket *so = tp->t_inpcb->inp_socket;
-	u_long nmax = so->so_snd.sb_hiwat;
+	u_long nmax;
+
+	mtx_enter(&so->so_snd.sb_mtx);
+
+	nmax = so->so_snd.sb_hiwat;
 
 	if (sbchecklowmem()) {
 		/* low on memory try to get rid of some */
@@ -1535,7 +1547,7 @@ tcp_update_sndspace(struct tcpcb *tp)
 	}
 
 	/* a writable socket must be preserved because of poll(2) semantics */
-	if (sbspace(so, &so->so_snd) >= so->so_snd.sb_lowat) {
+	if (sbspace_locked(so, &so->so_snd) >= so->so_snd.sb_lowat) {
 		if (nmax < so->so_snd.sb_cc + so->so_snd.sb_lowat)
 			nmax = so->so_snd.sb_cc + so->so_snd.sb_lowat;
 		/* keep in sync with sbreserve() calculation */
@@ -1548,6 +1560,8 @@ tcp_update_sndspace(struct tcpcb *tp)
 
 	if (nmax != so->so_snd.sb_hiwat)
 		sbreserve(so, &so->so_snd, nmax);
+
+	mtx_leave(&so->so_snd.sb_mtx);
 }
 
 /*
@@ -1581,9 +1595,11 @@ tcp_update_rcvspace(struct tcpcb *tp)
 	}
 
 	/* a readable socket must be preserved because of poll(2) semantics */
+	mtx_enter(&so->so_snd.sb_mtx);
 	if (so->so_rcv.sb_cc >= so->so_rcv.sb_lowat &&
 	    nmax < so->so_snd.sb_lowat)
 		nmax = so->so_snd.sb_lowat;
+	mtx_leave(&so->so_snd.sb_mtx);
 
 	if (nmax != so->so_rcv.sb_hiwat) {
 		/* round to MSS boundary */