Index | Thread | Search

From:
Vitaliy Makkoveev <mvs@openbsd.org>
Subject:
Re: Don't take solock() in soreceive() for SOCK_RAW inet sockets
To:
Alexander Bluhm <bluhm@openbsd.org>
Cc:
tech@openbsd.org
Date:
Wed, 10 Apr 2024 16:55:49 +0300

Download raw body.

Thread
Updated to apply on top of the current tree. Compared with the previous
version, the shared solock() is now used around sorflush_locked(). Only
unix(4) sockets set the (*dom_dispose)() handler, and it doesn't actually
need any protection.

Index: sys/kern/uipc_socket.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_socket.c,v
retrieving revision 1.328
diff -u -p -r1.328 uipc_socket.c
--- sys/kern/uipc_socket.c	10 Apr 2024 12:04:41 -0000	1.328
+++ sys/kern/uipc_socket.c	10 Apr 2024 13:54:34 -0000
@@ -66,6 +66,7 @@ void	soreaper(void *);
 void	soput(void *);
 int	somove(struct socket *, int);
 void	sorflush(struct socket *);
+void	sorflush_locked(struct socket *);
 
 void	filt_sordetach(struct knote *kn);
 int	filt_soread(struct knote *kn, long hint);
@@ -143,6 +144,8 @@ soalloc(const struct protosw *prp, int w
 		return (NULL);
 	rw_init_flags(&so->so_lock, dp->dom_name, RWL_DUPOK);
 	refcnt_init(&so->so_refcnt);
+	rw_init(&so->so_rcv.sb_lock, "sbufrcv");
+	rw_init(&so->so_snd.sb_lock, "sbufsnd");
 	mtx_init(&so->so_rcv.sb_mtx, IPL_MPFLOOR);
 	mtx_init(&so->so_snd.sb_mtx, IPL_MPFLOOR);
 	klist_init_mutex(&so->so_rcv.sb_klist, &so->so_rcv.sb_mtx);
@@ -156,15 +159,15 @@ soalloc(const struct protosw *prp, int w
 	case AF_INET6:
 		switch (prp->pr_type) {
 		case SOCK_DGRAM:
-			so->so_rcv.sb_flags |= SB_OWNLOCK;
-			/* FALLTHROUGH */
-		case SOCK_RAW:
 			so->so_rcv.sb_flags |= SB_MTXLOCK;
 			break;
+		case SOCK_RAW:
+			so->so_rcv.sb_flags |= SB_MTXLOCK | SB_OWNLOCK;
+			break;
 		}
 		break;
 	case AF_UNIX:
-		so->so_rcv.sb_flags |= SB_MTXLOCK | SB_OWNLOCK;
+		so->so_rcv.sb_flags |= SB_MTXLOCK;
 		break;
 	}
 
@@ -346,9 +349,22 @@ sofree(struct socket *so, int keep_lock)
 	}
 #endif /* SOCKET_SPLICE */
 	sbrelease(so, &so->so_snd);
-	sorflush(so);
-	if (!keep_lock)
+
+	/*
+	 * Despite the '_locked' suffix, solock() must be released before
+	 * calling sorflush_locked() on an SB_OWNLOCK socket. We can't just
+	 * release solock() and call sorflush(), because releasing solock()
+	 * is unwanted for tcp(4) sockets.
+	 */
+
+	if (so->so_rcv.sb_flags & SB_OWNLOCK)
+		sounlock(so);
+
+	sorflush_locked(so);
+
+	if (!((so->so_rcv.sb_flags & SB_OWNLOCK) || keep_lock))
 		sounlock(so);
+
 #ifdef SOCKET_SPLICE
 	if (so->so_sp) {
 		/* Reuse splice idle, sounsplice() has been called before. */
@@ -807,6 +823,7 @@ soreceive(struct socket *so, struct mbuf
 	const struct protosw *pr = so->so_proto;
 	struct mbuf *nextrecord;
 	size_t resid, orig_resid = uio->uio_resid;
+	int dosolock = ((so->so_rcv.sb_flags & SB_OWNLOCK) == 0);
 
 	mp = mp0;
 	if (paddr)
@@ -836,12 +853,11 @@ bad:
 	if (mp)
 		*mp = NULL;
 
-	solock_shared(so);
+	if (dosolock)
+		solock_shared(so);
 restart:
-	if ((error = sblock(so, &so->so_rcv, SBLOCKWAIT(flags))) != 0) {
-		sounlock_shared(so);
-		return (error);
-	}
+	if ((error = sblock(so, &so->so_rcv, SBLOCKWAIT(flags))) != 0)
+		goto out;
 	sb_mtx_lock(&so->so_rcv);
 
 	m = so->so_rcv.sb_mb;
@@ -906,14 +922,16 @@ restart:
 		SBLASTRECORDCHK(&so->so_rcv, "soreceive sbwait 1");
 		SBLASTMBUFCHK(&so->so_rcv, "soreceive sbwait 1");
 
-		if (so->so_rcv.sb_flags & SB_OWNLOCK) {
+		if (so->so_rcv.sb_flags & (SB_MTXLOCK | SB_OWNLOCK)) {
 			sbunlock_locked(so, &so->so_rcv);
-			sounlock_shared(so);
+			if (dosolock)
+				sounlock_shared(so);
 			error = sbwait_locked(so, &so->so_rcv);
 			sb_mtx_unlock(&so->so_rcv);
 			if (error)
 				return (error);
-			solock_shared(so);
+			if (dosolock)
+				solock_shared(so);
 		} else {
 			sb_mtx_unlock(&so->so_rcv);
 			sbunlock(so, &so->so_rcv);
@@ -990,11 +1008,13 @@ dontblock:
 			if (controlp) {
 				if (pr->pr_domain->dom_externalize) {
 					sb_mtx_unlock(&so->so_rcv);
-					sounlock_shared(so);
+					if (dosolock)
+						sounlock_shared(so);
 					error =
 					    (*pr->pr_domain->dom_externalize)
 					    (cm, controllen, flags);
-					solock_shared(so);
+					if (dosolock)
+						solock_shared(so);
 					sb_mtx_lock(&so->so_rcv);
 				}
 				*controlp = cm;
@@ -1073,9 +1093,11 @@ dontblock:
 			SBLASTMBUFCHK(&so->so_rcv, "soreceive uiomove");
 			resid = uio->uio_resid;
 			sb_mtx_unlock(&so->so_rcv);
-			sounlock_shared(so);
+			if (dosolock)
+				sounlock_shared(so);
 			uio_error = uiomove(mtod(m, caddr_t) + moff, len, uio);
-			solock_shared(so);
+			if (dosolock)
+				solock_shared(so);
 			sb_mtx_lock(&so->so_rcv);
 			if (uio_error)
 				uio->uio_resid = resid - len;
@@ -1158,14 +1180,21 @@ dontblock:
 				break;
 			SBLASTRECORDCHK(&so->so_rcv, "soreceive sbwait 2");
 			SBLASTMBUFCHK(&so->so_rcv, "soreceive sbwait 2");
-			sb_mtx_unlock(&so->so_rcv);
-			error = sbwait(so, &so->so_rcv);
-			if (error) {
-				sbunlock(so, &so->so_rcv);
-				sounlock_shared(so);
-				return (0);
+			if (dosolock) {
+				sb_mtx_unlock(&so->so_rcv);
+				error = sbwait(so, &so->so_rcv);
+				if (error) {
+					sbunlock(so, &so->so_rcv);
+					sounlock_shared(so);
+					return (0);
+				}
+				sb_mtx_lock(&so->so_rcv);
+			} else {
+				if (sbwait_locked(so, &so->so_rcv)) {
+					error = 0;
+					goto release;
+				}
 			}
-			sb_mtx_lock(&so->so_rcv);
 			if ((m = so->so_rcv.sb_mb) != NULL)
 				nextrecord = m->m_nextpkt;
 		}
@@ -1214,7 +1243,9 @@ dontblock:
 release:
 	sb_mtx_unlock(&so->so_rcv);
 	sbunlock(so, &so->so_rcv);
-	sounlock_shared(so);
+out:
+	if (dosolock)
+		sounlock_shared(so);
 	return (error);
 }
 
@@ -1223,7 +1254,6 @@ soshutdown(struct socket *so, int how)
 {
 	int error = 0;
 
-	solock(so);
 	switch (how) {
 	case SHUT_RD:
 		sorflush(so);
@@ -1232,25 +1262,29 @@ soshutdown(struct socket *so, int how)
 		sorflush(so);
 		/* FALLTHROUGH */
 	case SHUT_WR:
+		solock(so);
 		error = pru_shutdown(so);
+		sounlock(so);
 		break;
 	default:
 		error = EINVAL;
 		break;
 	}
-	sounlock(so);
 
 	return (error);
 }
 
 void
-sorflush(struct socket *so)
+sorflush_locked(struct socket *so)
 {
 	struct sockbuf *sb = &so->so_rcv;
 	struct mbuf *m;
 	const struct protosw *pr = so->so_proto;
 	int error;
 
+	if ((sb->sb_flags & SB_OWNLOCK) == 0)
+		soassertlocked(so);
+
 	error = sblock(so, sb, SBL_WAIT | SBL_NOINTR);
 	/* with SBL_WAIT and SLB_NOINTR sblock() must not fail */
 	KASSERT(error == 0);
@@ -1267,6 +1301,16 @@ sorflush(struct socket *so)
 	m_purge(m);
 }
 
+void
+sorflush(struct socket *so)
+{
+	if ((so->so_rcv.sb_flags & SB_OWNLOCK) == 0)
+		solock_shared(so);
+	sorflush_locked(so);
+	if ((so->so_rcv.sb_flags & SB_OWNLOCK) == 0)
+		sounlock_shared(so);
+}
+
 #ifdef SOCKET_SPLICE
 
 #define so_splicelen	so_sp->ssp_len
@@ -1905,7 +1949,8 @@ sosetopt(struct socket *so, int level, i
 			if ((long)cnt <= 0)
 				cnt = 1;
 
-			solock(so);
+			if (((sb->sb_flags & SB_OWNLOCK) == 0))
+				solock(so);
 			mtx_enter(&sb->sb_mtx);
 
 			switch (optname) {
@@ -1931,7 +1976,8 @@ sosetopt(struct socket *so, int level, i
 			}
 
 			mtx_leave(&sb->sb_mtx);
-			sounlock(so);
+			if (((sb->sb_flags & SB_OWNLOCK) == 0))
+				sounlock(so);
 
 			break;
 		    }
Index: sys/kern/uipc_socket2.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_socket2.c,v
retrieving revision 1.148
diff -u -p -r1.148 uipc_socket2.c
--- sys/kern/uipc_socket2.c	10 Apr 2024 12:04:41 -0000	1.148
+++ sys/kern/uipc_socket2.c	10 Apr 2024 13:54:34 -0000
@@ -322,7 +322,9 @@ socantsendmore(struct socket *so)
 void
 socantrcvmore(struct socket *so)
 {
-	soassertlocked(so);
+	if ((so->so_rcv.sb_flags & SB_OWNLOCK) == 0)
+		soassertlocked(so);
+
 	mtx_enter(&so->so_rcv.sb_mtx);
 	so->so_rcv.sb_state |= SS_CANTRCVMORE;
 	mtx_leave(&so->so_rcv.sb_mtx);
@@ -529,6 +531,17 @@ sblock(struct socket *so, struct sockbuf
 {
 	int error = 0, prio = PSOCK;
 
+	if (sb->sb_flags & SB_OWNLOCK) {
+		int rwflags = RW_WRITE;
+
+		if (!(flags & SBL_NOINTR || sb->sb_flags & SB_NOINTR))
+			rwflags |= RW_INTR;
+		if (!(flags & SBL_WAIT))
+			rwflags |= RW_NOSLEEP;
+
+		return rw_enter(&sb->sb_lock, rwflags);
+	}
+
 	soassertlocked(so);
 
 	mtx_enter(&sb->sb_mtx);
@@ -561,6 +574,11 @@ out:
 void
 sbunlock_locked(struct socket *so, struct sockbuf *sb)
 {
+	if (sb->sb_flags & SB_OWNLOCK) {
+		rw_exit(&sb->sb_lock);
+		return;
+	}
+
 	MUTEX_ASSERT_LOCKED(&sb->sb_mtx);
 
 	sb->sb_flags &= ~SB_LOCK;
@@ -573,6 +591,11 @@ sbunlock_locked(struct socket *so, struc
 void
 sbunlock(struct socket *so, struct sockbuf *sb)
 {
+	if (sb->sb_flags & SB_OWNLOCK) {
+		rw_exit(&sb->sb_lock);
+		return;
+	}
+
 	mtx_enter(&sb->sb_mtx);
 	sbunlock_locked(so, sb);
 	mtx_leave(&sb->sb_mtx);
Index: sys/sys/socketvar.h
===================================================================
RCS file: /cvs/src/sys/sys/socketvar.h,v
retrieving revision 1.128
diff -u -p -r1.128 socketvar.h
--- sys/sys/socketvar.h	10 Apr 2024 12:04:41 -0000	1.128
+++ sys/sys/socketvar.h	10 Apr 2024 13:54:34 -0000
@@ -105,7 +105,8 @@ struct socket {
  * Variables for socket buffering.
  */
 	struct	sockbuf {
-		struct mutex sb_mtx;
+		struct rwlock sb_lock; 
+		struct mutex  sb_mtx;
 /* The following fields are all zeroed on flush. */
 #define	sb_startzero	sb_cc
 		u_long	sb_cc;		/* actual chars in buffer */
@@ -134,7 +135,7 @@ struct socket {
 #define SB_SPLICE	0x0020		/* buffer is splice source or drain */
 #define SB_NOINTR	0x0040		/* operations not interruptible */
 #define SB_MTXLOCK	0x0080		/* use sb_mtx for sockbuf protection */
-#define SB_OWNLOCK	0x0100		/* sb_mtx used standalone */
+#define SB_OWNLOCK	0x0100		/* sblock() doesn't need solock() */
 
 	void	(*so_upcall)(struct socket *so, caddr_t arg, int waitf);
 	caddr_t	so_upcallarg;		/* Arg for above */