Index | Thread | Search

From:
Vitaliy Makkoveev <mvs@openbsd.org>
Subject:
Don't take solock() in sosend() and soreceive() paths for unix(4) sockets
To:
tech@openbsd.org
Date:
Tue, 30 Apr 2024 23:01:04 +0300

Download raw body.

Thread
Use `sb_mtx' mutex(9) and `sb_lock' rwlock(9) to protect `so_snd' and
`so_rcv' of unix(4) sockets.

The transmission path of unix(4) sockets is already half-unlocked,
because the connected peer is not locked by solock() during the
sbappend*() call. Since `so_snd' is protected by the `sb_mtx' mutex(9),
re-locking is not required in uipc_rcvd() either.

SB_OWNLOCK has become redundant with SB_MTXLOCK, so remove it.
SB_MTXLOCK was kept because checks against SB_MTXLOCK within the sb*()
routines look more consistent to me.

Please note that the unlocked unix(4) peer `so2' can't be disconnected
while solock() is held on `so'. That's why the unlocked sorwakeup() and
sowwakeup() calls are fine: the corresponding paths will never be
followed.

Index: sys/kern/uipc_socket.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_socket.c,v
diff -u -p -r1.331 uipc_socket.c
--- sys/kern/uipc_socket.c	30 Apr 2024 17:59:15 -0000	1.331
+++ sys/kern/uipc_socket.c	30 Apr 2024 19:57:27 -0000
@@ -159,14 +159,15 @@ soalloc(const struct protosw *prp, int w
 	case AF_INET6:
 		switch (prp->pr_type) {
 		case SOCK_RAW:
-			so->so_snd.sb_flags |= SB_MTXLOCK | SB_OWNLOCK;
+			so->so_snd.sb_flags |= SB_MTXLOCK;
 			/* FALLTHROUGH */
 		case SOCK_DGRAM:
-			so->so_rcv.sb_flags |= SB_MTXLOCK | SB_OWNLOCK;
+			so->so_rcv.sb_flags |= SB_MTXLOCK;
 			break;
 		}
 		break;
 	case AF_UNIX:
+		so->so_snd.sb_flags |= SB_MTXLOCK;
 		so->so_rcv.sb_flags |= SB_MTXLOCK;
 		break;
 	}
@@ -355,17 +356,17 @@ sofree(struct socket *so, int keep_lock)
 
 	/*
 	 * Regardless on '_locked' postfix, must release solock() before
-	 * call sorflush_locked() for SB_OWNLOCK marked socket. Can't
+	 * call sorflush_locked() for SB_MTXLOCK marked socket. Can't
 	 * release solock() and call sorflush() because solock() release
 	 * is unwanted for tcp(4) socket. 
 	 */
 
-	if (so->so_rcv.sb_flags & SB_OWNLOCK)
+	if (so->so_rcv.sb_flags & SB_MTXLOCK)
 		sounlock(so);
 
 	sorflush_locked(so);
 
-	if (!((so->so_rcv.sb_flags & SB_OWNLOCK) || keep_lock))
+	if (!((so->so_rcv.sb_flags & SB_MTXLOCK) || keep_lock))
 		sounlock(so);
 
 #ifdef SOCKET_SPLICE
@@ -574,7 +575,7 @@ sosend(struct socket *so, struct mbuf *a
 	size_t resid;
 	int error;
 	int atomic = sosendallatonce(so) || top;
-	int dosolock = ((so->so_snd.sb_flags & SB_OWNLOCK) == 0);
+	int dosolock = ((so->so_snd.sb_flags & SB_MTXLOCK) == 0);
 
 	if (uio)
 		resid = uio->uio_resid;
@@ -846,7 +847,7 @@ soreceive(struct socket *so, struct mbuf
 	const struct protosw *pr = so->so_proto;
 	struct mbuf *nextrecord;
 	size_t resid, orig_resid = uio->uio_resid;
-	int dosolock = ((so->so_rcv.sb_flags & SB_OWNLOCK) == 0);
+	int dosolock = ((so->so_rcv.sb_flags & SB_MTXLOCK) == 0);
 
 	mp = mp0;
 	if (paddr)
@@ -945,7 +946,7 @@ restart:
 		SBLASTRECORDCHK(&so->so_rcv, "soreceive sbwait 1");
 		SBLASTMBUFCHK(&so->so_rcv, "soreceive sbwait 1");
 
-		if (so->so_rcv.sb_flags & (SB_MTXLOCK | SB_OWNLOCK)) {
+		if (so->so_rcv.sb_flags & SB_MTXLOCK) {
 			sbunlock_locked(so, &so->so_rcv);
 			if (dosolock)
 				sounlock_shared(so);
@@ -1247,7 +1248,11 @@ dontblock:
 		SBLASTMBUFCHK(&so->so_rcv, "soreceive 4");
 		if (pr->pr_flags & PR_WANTRCVD) {
 			sb_mtx_unlock(&so->so_rcv);
+			if (!dosolock)
+				solock_shared(so);
 			pru_rcvd(so);
+			if (!dosolock)
+				sounlock_shared(so);
 			sb_mtx_lock(&so->so_rcv);
 		}
 	}
@@ -1306,17 +1311,17 @@ sorflush_locked(struct socket *so)
 	const struct protosw *pr = so->so_proto;
 	int error;
 
-	if ((sb->sb_flags & SB_OWNLOCK) == 0)
+	if ((sb->sb_flags & SB_MTXLOCK) == 0)
 		soassertlocked(so);
 
 	error = sblock(so, sb, SBL_WAIT | SBL_NOINTR);
 	/* with SBL_WAIT and SLB_NOINTR sblock() must not fail */
 	KASSERT(error == 0);
 
-	if (sb->sb_flags & SB_OWNLOCK)
+	if (sb->sb_flags & SB_MTXLOCK)
 		solock(so);
 	socantrcvmore(so);
-	if (sb->sb_flags & SB_OWNLOCK)
+	if (sb->sb_flags & SB_MTXLOCK)
 		sounlock(so);
 
 	mtx_enter(&sb->sb_mtx);
@@ -1334,10 +1339,10 @@ sorflush_locked(struct socket *so)
 void
 sorflush(struct socket *so)
 {
-	if ((so->so_rcv.sb_flags & SB_OWNLOCK) == 0)
+	if ((so->so_rcv.sb_flags & SB_MTXLOCK) == 0)
 		solock_shared(so);
 	sorflush_locked(so);
-	if ((so->so_rcv.sb_flags & SB_OWNLOCK) == 0)
+	if ((so->so_rcv.sb_flags & SB_MTXLOCK) == 0)
 		sounlock_shared(so);
 }
 
@@ -1383,7 +1388,7 @@ sosplice(struct socket *so, int fd, off_
 		membar_consumer();
 	}
 
-	if (so->so_rcv.sb_flags & SB_OWNLOCK) {
+	if (so->so_rcv.sb_flags & SB_MTXLOCK) {
 		if ((error = sblock(so, &so->so_rcv, SBL_WAIT)) != 0)
 			return (error);
 		solock(so);
@@ -1471,7 +1476,7 @@ sosplice(struct socket *so, int fd, off_
  release:
 	sbunlock(sosp, &sosp->so_snd);
  out:
-	if (so->so_rcv.sb_flags & SB_OWNLOCK) {
+	if (so->so_rcv.sb_flags & SB_MTXLOCK) {
 		sounlock(so);
 		sbunlock(so, &so->so_rcv);
 	} else {
@@ -1885,7 +1890,8 @@ sorwakeup(struct socket *so)
 void
 sowwakeup(struct socket *so)
 {
-	soassertlocked_readonly(so);
+	if ((so->so_snd.sb_flags & SB_MTXLOCK) == 0)
+		soassertlocked_readonly(so);
 
 #ifdef SOCKET_SPLICE
 	if (so->so_snd.sb_flags & SB_SPLICE)
@@ -1976,7 +1982,7 @@ sosetopt(struct socket *so, int level, i
 			if ((long)cnt <= 0)
 				cnt = 1;
 
-			if (((sb->sb_flags & SB_OWNLOCK) == 0))
+			if (((sb->sb_flags & SB_MTXLOCK) == 0))
 				solock(so);
 			mtx_enter(&sb->sb_mtx);
 
@@ -2003,7 +2009,7 @@ sosetopt(struct socket *so, int level, i
 			}
 
 			mtx_leave(&sb->sb_mtx);
-			if (((sb->sb_flags & SB_OWNLOCK) == 0))
+			if (((sb->sb_flags & SB_MTXLOCK) == 0))
 				sounlock(so);
 
 			break;
@@ -2380,7 +2386,8 @@ filt_sowrite(struct knote *kn, long hint
 	int rv;
 
 	MUTEX_ASSERT_LOCKED(&so->so_snd.sb_mtx);
-	soassertlocked_readonly(so);
+	if ((so->so_snd.sb_flags & SB_MTXLOCK) == 0)
+		soassertlocked_readonly(so);
 
 	kn->kn_data = sbspace(so, &so->so_snd);
 	if (so->so_snd.sb_state & SS_CANTSENDMORE) {
Index: sys/kern/uipc_socket2.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_socket2.c,v
diff -u -p -r1.151 uipc_socket2.c
--- sys/kern/uipc_socket2.c	30 Apr 2024 17:59:15 -0000	1.151
+++ sys/kern/uipc_socket2.c	30 Apr 2024 19:57:27 -0000
@@ -228,9 +228,10 @@ sonewconn(struct socket *head, int conns
 	 */
 	if (soreserve(so, head->so_snd.sb_hiwat, head->so_rcv.sb_hiwat))
 		goto fail;
+
+	mtx_enter(&head->so_snd.sb_mtx);
 	so->so_snd.sb_wat = head->so_snd.sb_wat;
 	so->so_snd.sb_lowat = head->so_snd.sb_lowat;
-	mtx_enter(&head->so_snd.sb_mtx);
 	so->so_snd.sb_timeo_nsecs = head->so_snd.sb_timeo_nsecs;
 	mtx_leave(&head->so_snd.sb_mtx);
 
@@ -334,7 +335,7 @@ socantsendmore(struct socket *so)
 void
 socantrcvmore(struct socket *so)
 {
-	if ((so->so_rcv.sb_flags & SB_OWNLOCK) == 0)
+	if ((so->so_rcv.sb_flags & SB_MTXLOCK) == 0)
 		soassertlocked(so);
 
 	mtx_enter(&so->so_rcv.sb_mtx);
@@ -543,7 +544,7 @@ sblock(struct socket *so, struct sockbuf
 {
 	int error = 0, prio = PSOCK;
 
-	if (sb->sb_flags & SB_OWNLOCK) {
+	if (sb->sb_flags & SB_MTXLOCK) {
 		int rwflags = RW_WRITE;
 
 		if (!(flags & SBL_NOINTR || sb->sb_flags & SB_NOINTR))
@@ -586,7 +587,7 @@ out:
 void
 sbunlock_locked(struct socket *so, struct sockbuf *sb)
 {
-	if (sb->sb_flags & SB_OWNLOCK) {
+	if (sb->sb_flags & SB_MTXLOCK) {
 		rw_exit(&sb->sb_lock);
 		return;
 	}
@@ -603,7 +604,7 @@ sbunlock_locked(struct socket *so, struc
 void
 sbunlock(struct socket *so, struct sockbuf *sb)
 {
-	if (sb->sb_flags & SB_OWNLOCK) {
+	if (sb->sb_flags & SB_MTXLOCK) {
 		rw_exit(&sb->sb_lock);
 		return;
 	}
Index: sys/kern/uipc_usrreq.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_usrreq.c,v
diff -u -p -r1.204 uipc_usrreq.c
--- sys/kern/uipc_usrreq.c	10 Apr 2024 12:04:41 -0000	1.204
+++ sys/kern/uipc_usrreq.c	30 Apr 2024 19:57:27 -0000
@@ -477,20 +477,24 @@ uipc_dgram_shutdown(struct socket *so)
 void
 uipc_rcvd(struct socket *so)
 {
+	struct unpcb *unp = sotounpcb(so);
 	struct socket *so2;
 
-	if ((so2 = unp_solock_peer(so)) == NULL)
+	if (unp->unp_conn == NULL)
 		return;
+	so2 = unp->unp_conn->unp_socket;
+
 	/*
 	 * Adjust backpressure on sender
 	 * and wakeup any waiting to write.
 	 */
 	mtx_enter(&so->so_rcv.sb_mtx);
+	mtx_enter(&so2->so_snd.sb_mtx);
 	so2->so_snd.sb_mbcnt = so->so_rcv.sb_mbcnt;
 	so2->so_snd.sb_cc = so->so_rcv.sb_cc;
+	mtx_leave(&so2->so_snd.sb_mtx);
 	mtx_leave(&so->so_rcv.sb_mtx);
 	sowwakeup(so2);
-	sounlock(so2);
 }
 
 int
@@ -509,6 +513,10 @@ uipc_send(struct socket *so, struct mbuf
 			goto out;
 	}
 
+	/*
+	 * This check is not unlocked: socantsendmore()
+	 * is called with solock() held, so solock() is
+	 * sufficient to serialize against it.
+	 */
 	if (so->so_snd.sb_state & SS_CANTSENDMORE) {
 		error = EPIPE;
 		goto dispose;
@@ -525,11 +533,17 @@ uipc_send(struct socket *so, struct mbuf
 	 * send buffer counts to maintain backpressure.
 	 * Wake up readers.
 	 */
+	/*
+	 * sbappend*() should be serialized together
+	 * with so_snd modification.
+	 */
 	mtx_enter(&so2->so_rcv.sb_mtx);
+	mtx_enter(&so->so_snd.sb_mtx);
 	if (control) {
 		if (sbappendcontrol(so2, &so2->so_rcv, m, control)) {
 			control = NULL;
 		} else {
+			mtx_leave(&so->so_snd.sb_mtx);
 			mtx_leave(&so2->so_rcv.sb_mtx);
 			error = ENOBUFS;
 			goto dispose;
@@ -542,6 +556,7 @@ uipc_send(struct socket *so, struct mbuf
 	so->so_snd.sb_cc = so2->so_rcv.sb_cc;
 	if (so2->so_rcv.sb_cc > 0)
 		dowakeup = 1;
+	mtx_leave(&so->so_snd.sb_mtx);
 	mtx_leave(&so2->so_rcv.sb_mtx);
 
 	if (dowakeup)
Index: sys/sys/socketvar.h
===================================================================
RCS file: /cvs/src/sys/sys/socketvar.h,v
diff -u -p -r1.129 socketvar.h
--- sys/sys/socketvar.h	11 Apr 2024 13:32:51 -0000	1.129
+++ sys/sys/socketvar.h	30 Apr 2024 19:57:27 -0000
@@ -134,8 +134,7 @@ struct socket {
 #define SB_ASYNC	0x0010		/* ASYNC I/O, need signals */
 #define SB_SPLICE	0x0020		/* buffer is splice source or drain */
 #define SB_NOINTR	0x0040		/* operations not interruptible */
-#define SB_MTXLOCK	0x0080		/* use sb_mtx for sockbuf protection */
-#define SB_OWNLOCK	0x0100		/* sblock() doesn't need solock() */
+#define SB_MTXLOCK	0x0080		/* sblock() doesn't need solock() */
 
 	void	(*so_upcall)(struct socket *so, caddr_t arg, int waitf);
 	caddr_t	so_upcallarg;		/* Arg for above */