Index | Thread | Search

From:
Alexander Bluhm <bluhm@openbsd.org>
Subject:
pcb notify socket lock
To:
tech@openbsd.org
Cc:
Vitaliy Makkoveev <mvs@openbsd.org>
Date:
Fri, 7 Feb 2025 23:31:35 +0100

Download raw body.

Thread
Hi,

The notify and ctlinput functions are not MP safe yet.  They need
the socket lock, which can be acquired by in_pcbsolock_ref().  As a
consequence, in_pcbnotifyall() has to be called without holding a
socket lock.

Rename in_rtchange() to in_pcbrtchange().  This is the correct
namespace, as the function deals with the inpcb route.

ok?

bluhm

Index: netinet/in_pcb.c
===================================================================
RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/in_pcb.c,v
diff -u -p -r1.311 in_pcb.c
--- netinet/in_pcb.c	5 Feb 2025 10:15:10 -0000	1.311
+++ netinet/in_pcb.c	7 Feb 2025 21:05:40 -0000
@@ -810,6 +810,8 @@ in_pcbnotifyall(struct inpcbtable *table
 	rdomain = rtable_l2(rtable);
 	mtx_enter(&table->inpt_mtx);
 	while ((inp = in_pcb_iterator(table, inp, &iter)) != NULL) {
+		struct socket *so;
+
 		KASSERT(!ISSET(inp->inp_flags, INP_IPV6));
 
 		if (inp->inp_faddr.s_addr != dst->sin_addr.s_addr ||
@@ -817,7 +819,10 @@ in_pcbnotifyall(struct inpcbtable *table
 			continue;
 		}
 		mtx_leave(&table->inpt_mtx);
-		(*notify)(inp, errno);
+		so = in_pcbsolock_ref(inp);
+		if (so != NULL)
+			(*notify)(inp, errno);
+		in_pcbsounlock_rele(inp, so);
 		mtx_enter(&table->inpt_mtx);
 	}
 	mtx_leave(&table->inpt_mtx);
@@ -866,8 +871,10 @@ in_losing(struct inpcb *inp)
  * and allocate a (hopefully) better one.
  */
 void
-in_rtchange(struct inpcb *inp, int errno)
+in_pcbrtchange(struct inpcb *inp, int errno)
 {
+	soassertlocked(inp->inp_socket);
+
 	if (inp->inp_route.ro_rt) {
 		rtfree(inp->inp_route.ro_rt);
 		inp->inp_route.ro_rt = NULL;
Index: netinet/in_pcb.h
===================================================================
RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/in_pcb.h,v
diff -u -p -r1.164 in_pcb.h
--- netinet/in_pcb.h	9 Jan 2025 16:47:24 -0000	1.164
+++ netinet/in_pcb.h	7 Feb 2025 21:05:40 -0000
@@ -354,7 +354,7 @@ struct inpcb *
 void	 in_pcbnotifyall(struct inpcbtable *, const struct sockaddr_in *,
 	    u_int, int, void (*)(struct inpcb *, int));
 void	 in_pcbrehash(struct inpcb *);
-void	 in_rtchange(struct inpcb *, int);
+void	 in_pcbrtchange(struct inpcb *, int);
 void	 in_setpeeraddr(struct inpcb *, struct mbuf *);
 void	 in_setsockaddr(struct inpcb *, struct mbuf *);
 int	 in_sockaddr(struct socket *, struct mbuf *);
Index: netinet/tcp_subr.c
===================================================================
RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/tcp_subr.c,v
diff -u -p -r1.207 tcp_subr.c
--- netinet/tcp_subr.c	30 Jan 2025 14:40:50 -0000	1.207
+++ netinet/tcp_subr.c	7 Feb 2025 21:05:40 -0000
@@ -576,6 +576,8 @@ tcp_notify(struct inpcb *inp, int error)
 	struct tcpcb *tp = intotcpcb(inp);
 	struct socket *so = inp->inp_socket;
 
+	soassertlocked(so);
+
 	/*
 	 * Ignore some errors if we are hooked up.
 	 * If connection hasn't completed, has retransmitted several times,
@@ -602,12 +604,10 @@ void
 tcp6_ctlinput(int cmd, struct sockaddr *sa, u_int rdomain, void *d)
 {
 	struct tcphdr th;
-	struct tcpcb *tp;
 	void (*notify)(struct inpcb *, int) = tcp_notify;
 	struct ip6_hdr *ip6;
 	const struct sockaddr_in6 *sa6_src = NULL;
 	struct sockaddr_in6 *sa6 = satosin6(sa);
-	struct inpcb *inp;
 	struct mbuf *m;
 	tcp_seq seq;
 	int off;
@@ -633,7 +633,7 @@ tcp6_ctlinput(int cmd, struct sockaddr *
 		/* XXX there's no PRC_QUENCH in IPv6 */
 		return;
 	} else if (PRC_IS_REDIRECT(cmd))
-		notify = in_rtchange, d = NULL;
+		notify = in_pcbrtchange, d = NULL;
 	else if (cmd == PRC_MSGSIZE)
 		; /* special code is present, see below */
 	else if (cmd == PRC_HOSTDEAD)
@@ -655,6 +655,10 @@ tcp6_ctlinput(int cmd, struct sockaddr *
 	}
 
 	if (ip6) {
+		struct inpcb *inp;
+		struct socket *so = NULL;
+		struct tcpcb *tp = NULL;
+
 		/*
 		 * XXX: We assume that when ip6 is non NULL,
 		 * M and OFF are valid.
@@ -687,18 +691,27 @@ tcp6_ctlinput(int cmd, struct sockaddr *
 			in_pcbunref(inp);
 			return;
 		}
-		if (inp) {
+		if (inp != NULL)
+			so = in_pcbsolock_ref(inp);
+		if (so != NULL)
+			tp = intotcpcb(inp);
+		if (tp != NULL) {
 			seq = ntohl(th.th_seq);
 			if ((tp = intotcpcb(inp)) &&
 			    SEQ_GEQ(seq, tp->snd_una) &&
 			    SEQ_LT(seq, tp->snd_max))
 				notify(inp, inet6ctlerrmap[cmd]);
-		} else if (inet6ctlerrmap[cmd] == EHOSTUNREACH ||
+		}
+		in_pcbsounlock_rele(inp, so);
+		in_pcbunref(inp);
+
+		if (tp == NULL &&
+		    (inet6ctlerrmap[cmd] == EHOSTUNREACH ||
 		    inet6ctlerrmap[cmd] == ENETUNREACH ||
-		    inet6ctlerrmap[cmd] == EHOSTDOWN)
+		    inet6ctlerrmap[cmd] == EHOSTDOWN)) {
 			syn_cache_unreach(sin6tosa_const(sa6_src), sa, &th,
 			    rdomain);
-		in_pcbunref(inp);
+		}
 	} else {
 		in6_pcbnotify(&tcb6table, sa6, 0,
 		    sa6_src, 0, rdomain, cmd, NULL, notify);
@@ -711,8 +724,6 @@ tcp_ctlinput(int cmd, struct sockaddr *s
 {
 	struct ip *ip = v;
 	struct tcphdr *th;
-	struct tcpcb *tp;
-	struct inpcb *inp;
 	struct in_addr faddr;
 	tcp_seq seq;
 	u_int mtu;
@@ -735,8 +746,12 @@ tcp_ctlinput(int cmd, struct sockaddr *s
 		 */
 		return;
 	else if (PRC_IS_REDIRECT(cmd))
-		notify = in_rtchange, ip = 0;
+		notify = in_pcbrtchange, ip = NULL;
 	else if (cmd == PRC_MSGSIZE && ip_mtudisc && ip) {
+		struct inpcb *inp;
+		struct socket *so = NULL;
+		struct tcpcb *tp = NULL;
+
 		/*
 		 * Verify that the packet in the icmp payload refers
 		 * to an existing TCP connection.
@@ -746,7 +761,11 @@ tcp_ctlinput(int cmd, struct sockaddr *s
 		inp = in_pcblookup(&tcbtable,
 		    ip->ip_dst, th->th_dport, ip->ip_src, th->th_sport,
 		    rdomain);
-		if (inp && (tp = intotcpcb(inp)) &&
+		if (inp != NULL)
+			so = in_pcbsolock_ref(inp);
+		if (so != NULL)
+			tp = intotcpcb(inp);
+		if (tp != NULL &&
 		    SEQ_GEQ(seq, tp->snd_una) &&
 		    SEQ_LT(seq, tp->snd_max)) {
 			struct icmp *icp;
@@ -760,6 +779,7 @@ tcp_ctlinput(int cmd, struct sockaddr *s
 			 */
 			mtu = (u_int)ntohs(icp->icmp_nextmtu);
 			if (mtu >= tp->t_pmtud_mtu_sent) {
+				in_pcbsounlock_rele(inp, so);
 				in_pcbunref(inp);
 				return;
 			}
@@ -780,6 +800,7 @@ tcp_ctlinput(int cmd, struct sockaddr *s
 				 */
 				if (tp->t_flags & TF_PMTUD_PEND) {
 					if (SEQ_LT(tp->t_pmtud_th_seq, seq)) {
+						in_pcbsounlock_rele(inp, so);
 						in_pcbunref(inp);
 						return;
 					}
@@ -789,37 +810,52 @@ tcp_ctlinput(int cmd, struct sockaddr *s
 				tp->t_pmtud_nextmtu = icp->icmp_nextmtu;
 				tp->t_pmtud_ip_len = icp->icmp_ip.ip_len;
 				tp->t_pmtud_ip_hl = icp->icmp_ip.ip_hl;
+				in_pcbsounlock_rele(inp, so);
 				in_pcbunref(inp);
 				return;
 			}
 		} else {
 			/* ignore if we don't have a matching connection */
+			in_pcbsounlock_rele(inp, so);
 			in_pcbunref(inp);
 			return;
 		}
+		in_pcbsounlock_rele(inp, so);
 		in_pcbunref(inp);
-		notify = tcp_mtudisc, ip = 0;
+		notify = tcp_mtudisc, ip = NULL;
 	} else if (cmd == PRC_MTUINC)
-		notify = tcp_mtudisc_increase, ip = 0;
+		notify = tcp_mtudisc_increase, ip = NULL;
 	else if (cmd == PRC_HOSTDEAD)
-		ip = 0;
+		ip = NULL;
 	else if (errno == 0)
 		return;
 
 	if (ip) {
+		struct inpcb *inp;
+		struct socket *so = NULL;
+		struct tcpcb *tp = NULL;
+
 		th = (struct tcphdr *)((caddr_t)ip + (ip->ip_hl << 2));
 		inp = in_pcblookup(&tcbtable,
 		    ip->ip_dst, th->th_dport, ip->ip_src, th->th_sport,
 		    rdomain);
-		if (inp) {
+		if (inp != NULL)
+			so = in_pcbsolock_ref(inp);
+		if (so != NULL)
+			tp = intotcpcb(inp);
+		if (tp != NULL) {
 			seq = ntohl(th->th_seq);
-			if ((tp = intotcpcb(inp)) &&
-			    SEQ_GEQ(seq, tp->snd_una) &&
+			if (SEQ_GEQ(seq, tp->snd_una) &&
 			    SEQ_LT(seq, tp->snd_max))
 				notify(inp, errno);
-		} else if (inetctlerrmap[cmd] == EHOSTUNREACH ||
+		}
+		in_pcbsounlock_rele(inp, so);
+		in_pcbunref(inp);
+
+		if (tp == NULL &&
+		    (inetctlerrmap[cmd] == EHOSTUNREACH ||
 		    inetctlerrmap[cmd] == ENETUNREACH ||
-		    inetctlerrmap[cmd] == EHOSTDOWN) {
+		    inetctlerrmap[cmd] == EHOSTDOWN)) {
 			struct sockaddr_in sin;
 
 			bzero(&sin, sizeof(sin));
@@ -829,7 +865,6 @@ tcp_ctlinput(int cmd, struct sockaddr *s
 			sin.sin_addr = ip->ip_src;
 			syn_cache_unreach(sintosa(&sin), sa, th, rdomain);
 		}
-		in_pcbunref(inp);
 	} else
 		in_pcbnotifyall(&tcbtable, satosin(sa), rdomain, errno, notify);
 }
@@ -871,7 +906,7 @@ tcp_mtudisc(struct inpcb *inp, int errno
 		 * If this was not a host route, remove and realloc.
 		 */
 		if ((rt->rt_flags & RTF_HOST) == 0) {
-			in_rtchange(inp, errno);
+			in_pcbrtchange(inp, errno);
 			if ((rt = in_pcbrtentry(inp)) == NULL)
 				return;
 		}
@@ -901,7 +936,7 @@ tcp_mtudisc_increase(struct inpcb *inp, 
 		 * If this was a host route, remove and realloc.
 		 */
 		if (rt->rt_flags & RTF_HOST)
-			in_rtchange(inp, errno);
+			in_pcbrtchange(inp, errno);
 
 		/* also takes care of congestion window */
 		tcp_mss(tp, -1);
Index: netinet/tcp_timer.c
===================================================================
RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/tcp_timer.c,v
diff -u -p -r1.82 tcp_timer.c
--- netinet/tcp_timer.c	16 Jan 2025 11:59:20 -0000	1.82
+++ netinet/tcp_timer.c	7 Feb 2025 21:05:40 -0000
@@ -214,17 +214,19 @@ tcp_timer_rexmt(void *arg)
 	    SEQ_LT(tp->t_pmtud_th_seq, (int)(tp->snd_una + tp->t_maxseg))) {
 		struct sockaddr_in sin;
 		struct icmp icmp;
+		u_int rtableid;
 
 		/* TF_PMTUD_PEND is set in tcp_ctlinput() which is IPv4 only */
 		KASSERT(!ISSET(inp->inp_flags, INP_IPV6));
 		tp->t_flags &= ~TF_PMTUD_PEND;
 
+		rtableid = inp->inp_rtableid;
+
 		/* XXX create fake icmp message with relevant entries */
 		icmp.icmp_nextmtu = tp->t_pmtud_nextmtu;
 		icmp.icmp_ip.ip_len = tp->t_pmtud_ip_len;
 		icmp.icmp_ip.ip_hl = tp->t_pmtud_ip_hl;
 		icmp.icmp_ip.ip_dst = inp->inp_faddr;
-		icmp_mtudisc(&icmp, inp->inp_rtableid);
 
 		/*
 		 * Notify all connections to the same peer about
@@ -234,9 +236,16 @@ tcp_timer_rexmt(void *arg)
 		sin.sin_len = sizeof(sin);
 		sin.sin_family = AF_INET;
 		sin.sin_addr = inp->inp_faddr;
-		in_pcbnotifyall(&tcbtable, &sin, inp->inp_rtableid, EMSGSIZE,
+
+		in_pcbsounlock_rele(inp, so);
+		in_pcbunref(inp);
+
+		icmp_mtudisc(&icmp, rtableid);
+		in_pcbnotifyall(&tcbtable, &sin, rtableid, EMSGSIZE,
 		    tcp_mtudisc);
-		goto out;
+
+		NET_UNLOCK_SHARED();
+		return;
 	}
 
 	tcp_timer_freesack(tp);
@@ -303,7 +312,7 @@ tcp_timer_rexmt(void *arg)
 			/* Disable path MTU discovery */
 			if ((rt->rt_locks & RTV_MTU) == 0) {
 				rt->rt_locks |= RTV_MTU;
-				in_rtchange(inp, 0);
+				in_pcbrtchange(inp, 0);
 			}
 
 			rtfree(rt);
Index: netinet/udp_usrreq.c
===================================================================
RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/udp_usrreq.c,v
diff -u -p -r1.331 udp_usrreq.c
--- netinet/udp_usrreq.c	6 Feb 2025 13:40:58 -0000	1.331
+++ netinet/udp_usrreq.c	7 Feb 2025 21:05:40 -0000
@@ -748,7 +748,7 @@ udp6_ctlinput(int cmd, struct sockaddr *
 	if ((unsigned)cmd >= PRC_NCMDS)
 		return;
 	if (PRC_IS_REDIRECT(cmd))
-		notify = in_rtchange, d = NULL;
+		notify = in_pcbrtchange, d = NULL;
 	else if (cmd == PRC_HOSTDEAD)
 		d = NULL;
 	else if (cmd == PRC_MSGSIZE)
@@ -880,7 +880,6 @@ udp_ctlinput(int cmd, struct sockaddr *s
 	struct ip *ip = v;
 	struct udphdr *uhp;
 	struct in_addr faddr;
-	struct inpcb *inp;
 	void (*notify)(struct inpcb *, int) = udp_notify;
 	int errno;
 
@@ -897,14 +896,17 @@ udp_ctlinput(int cmd, struct sockaddr *s
 		return;
 	errno = inetctlerrmap[cmd];
 	if (PRC_IS_REDIRECT(cmd))
-		notify = in_rtchange, ip = 0;
+		notify = in_pcbrtchange, ip = NULL;
 	else if (cmd == PRC_HOSTDEAD)
-		ip = 0;
+		ip = NULL;
 	else if (errno == 0)
 		return;
+
 	if (ip) {
-		uhp = (struct udphdr *)((caddr_t)ip + (ip->ip_hl << 2));
+		struct inpcb *inp;
+		struct socket *so = NULL;
 
+		uhp = (struct udphdr *)((caddr_t)ip + (ip->ip_hl << 2));
 #ifdef IPSEC
 		/* PMTU discovery for udpencap */
 		if (cmd == PRC_MSGSIZE && ip_mtudisc && udpencap_enable &&
@@ -917,7 +919,10 @@ udp_ctlinput(int cmd, struct sockaddr *s
 		    ip->ip_dst, uhp->uh_dport, ip->ip_src, uhp->uh_sport,
 		    rdomain);
 		if (inp != NULL)
+			so = in_pcbsolock_ref(inp);
+		if (so != NULL)
 			notify(inp, errno);
+		in_pcbsounlock_rele(inp, so);
 		in_pcbunref(inp);
 	} else
 		in_pcbnotifyall(&udbtable, satosin(sa), rdomain, errno, notify);
Index: netinet6/in6_pcb.c
===================================================================
RCS file: /data/mirror/openbsd/cvs/src/sys/netinet6/in6_pcb.c,v
diff -u -p -r1.146 in6_pcb.c
--- netinet6/in6_pcb.c	21 Dec 2024 00:10:04 -0000	1.146
+++ netinet6/in6_pcb.c	7 Feb 2025 21:05:40 -0000
@@ -456,8 +456,8 @@ in6_pcbnotify(struct inpcbtable *table, 
 
 	/*
 	 * Redirects go to all references to the destination,
-	 * and use in_rtchange to invalidate the route cache.
-	 * Dead host indications: also use in_rtchange to invalidate
+	 * and use in_pcbrtchange to invalidate the route cache.
+	 * Dead host indications: also use in_pcbrtchange to invalidate
 	 * the cache, and deliver the error to all the sockets.
 	 * Otherwise, if we have knowledge of the local port and address,
 	 * deliver only to that socket.
@@ -468,7 +468,7 @@ in6_pcbnotify(struct inpcbtable *table, 
 		sa6_src.sin6_addr = in6addr_any;
 
 		if (cmd != PRC_HOSTDEAD)
-			notify = in_rtchange;
+			notify = in_pcbrtchange;
 	}
 	errno = inet6ctlerrmap[cmd];
 	if (notify == NULL)
@@ -477,6 +477,8 @@ in6_pcbnotify(struct inpcbtable *table, 
 	rdomain = rtable_l2(rtable);
 	mtx_enter(&table->inpt_mtx);
 	while ((inp = in_pcb_iterator(table, inp, &iter)) != NULL) {
+		struct socket *so;
+
 		KASSERT(ISSET(inp->inp_flags, INP_IPV6));
 
 		/*
@@ -543,7 +545,10 @@ in6_pcbnotify(struct inpcbtable *table, 
 		}
 	  do_notify:
 		mtx_leave(&table->inpt_mtx);
-		(*notify)(inp, errno);
+		so = in_pcbsolock_ref(inp);
+		if (so != NULL)
+			(*notify)(inp, errno);
+		in_pcbsounlock_rele(inp, so);
 		mtx_enter(&table->inpt_mtx);
 	}
 	mtx_leave(&table->inpt_mtx);
Index: netinet6/raw_ip6.c
===================================================================
RCS file: /data/mirror/openbsd/cvs/src/sys/netinet6/raw_ip6.c,v
diff -u -p -r1.187 raw_ip6.c
--- netinet6/raw_ip6.c	6 Feb 2025 13:40:58 -0000	1.187
+++ netinet6/raw_ip6.c	7 Feb 2025 21:05:40 -0000
@@ -308,7 +308,7 @@ rip6_ctlinput(int cmd, struct sockaddr *
 	struct sockaddr_in6 *sa6 = satosin6(sa);
 	const struct sockaddr_in6 *sa6_src = NULL;
 	void *cmdarg;
-	void (*notify)(struct inpcb *, int) = in_rtchange;
+	void (*notify)(struct inpcb *, int) = in_pcbrtchange;
 	int nxt;
 
 	if (sa->sa_family != AF_INET6 ||
@@ -318,7 +318,7 @@ rip6_ctlinput(int cmd, struct sockaddr *
 	if ((unsigned)cmd >= PRC_NCMDS)
 		return;
 	if (PRC_IS_REDIRECT(cmd))
-		notify = in_rtchange, d = NULL;
+		notify = in_pcbrtchange, d = NULL;
 	else if (cmd == PRC_HOSTDEAD)
 		d = NULL;
 	else if (cmd == PRC_MSGSIZE)