Index | Thread | Search

From:
Vitaliy Makkoveev <mvs@openbsd.org>
Subject:
Re: pcb notify socket lock
To:
Alexander Bluhm <alexander.bluhm@gmx.net>
Cc:
tech@openbsd.org
Date:
Wed, 12 Feb 2025 23:47:17 +0300

Download raw body.

Thread
> On 12 Feb 2025, at 22:44, Alexander Bluhm <alexander.bluhm@gmx.net> wrote:
> 
> On Fri, Feb 07, 2025 at 11:31:35PM +0100, Alexander Bluhm wrote:
>> Hi,
>> 
>> The notify and ctlinput functions are not MP safe yet.  They need
>> socket lock which can be acquired by in_pcbsolock_ref().  Of course
>> in_pcbnotifyall() has to be called without holding a socket lock
>> then.
>> 
>> Rename in_rtchange() to in_pcbrtchange().  This is the correct
>> namespace and the function cares about the inpcb route.
>> 
>> ok?
> 
> anyone?
> 

ok mvs

>> bluhm
>> 
>> Index: netinet/in_pcb.c
>> ===================================================================
>> RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/in_pcb.c,v
>> diff -u -p -r1.311 in_pcb.c
>> --- netinet/in_pcb.c	5 Feb 2025 10:15:10 -0000	1.311
>> +++ netinet/in_pcb.c	7 Feb 2025 21:05:40 -0000
>> @@ -810,6 +810,8 @@ in_pcbnotifyall(struct inpcbtable *table
>> 	rdomain = rtable_l2(rtable);
>> 	mtx_enter(&table->inpt_mtx);
>> 	while ((inp = in_pcb_iterator(table, inp, &iter)) != NULL) {
>> +		struct socket *so;
>> +
>> 		KASSERT(!ISSET(inp->inp_flags, INP_IPV6));
>> 
>> 		if (inp->inp_faddr.s_addr != dst->sin_addr.s_addr ||
>> @@ -817,7 +819,10 @@ in_pcbnotifyall(struct inpcbtable *table
>> 			continue;
>> 		}
>> 		mtx_leave(&table->inpt_mtx);
>> -		(*notify)(inp, errno);
>> +		so = in_pcbsolock_ref(inp);
>> +		if (so != NULL)
>> +			(*notify)(inp, errno);
>> +		in_pcbsounlock_rele(inp, so);
>> 		mtx_enter(&table->inpt_mtx);
>> 	}
>> 	mtx_leave(&table->inpt_mtx);
>> @@ -866,8 +871,10 @@ in_losing(struct inpcb *inp)
>>  * and allocate a (hopefully) better one.
>>  */
>> void
>> -in_rtchange(struct inpcb *inp, int errno)
>> +in_pcbrtchange(struct inpcb *inp, int errno)
>> {
>> +	soassertlocked(inp->inp_socket);
>> +
>> 	if (inp->inp_route.ro_rt) {
>> 		rtfree(inp->inp_route.ro_rt);
>> 		inp->inp_route.ro_rt = NULL;
>> Index: netinet/in_pcb.h
>> ===================================================================
>> RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/in_pcb.h,v
>> diff -u -p -r1.164 in_pcb.h
>> --- netinet/in_pcb.h	9 Jan 2025 16:47:24 -0000	1.164
>> +++ netinet/in_pcb.h	7 Feb 2025 21:05:40 -0000
>> @@ -354,7 +354,7 @@ struct inpcb *
>> void	 in_pcbnotifyall(struct inpcbtable *, const struct sockaddr_in *,
>> 	    u_int, int, void (*)(struct inpcb *, int));
>> void	 in_pcbrehash(struct inpcb *);
>> -void	 in_rtchange(struct inpcb *, int);
>> +void	 in_pcbrtchange(struct inpcb *, int);
>> void	 in_setpeeraddr(struct inpcb *, struct mbuf *);
>> void	 in_setsockaddr(struct inpcb *, struct mbuf *);
>> int	 in_sockaddr(struct socket *, struct mbuf *);
>> Index: netinet/tcp_subr.c
>> ===================================================================
>> RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/tcp_subr.c,v
>> diff -u -p -r1.207 tcp_subr.c
>> --- netinet/tcp_subr.c	30 Jan 2025 14:40:50 -0000	1.207
>> +++ netinet/tcp_subr.c	7 Feb 2025 21:05:40 -0000
>> @@ -576,6 +576,8 @@ tcp_notify(struct inpcb *inp, int error)
>> 	struct tcpcb *tp = intotcpcb(inp);
>> 	struct socket *so = inp->inp_socket;
>> 
>> +	soassertlocked(so);
>> +
>> 	/*
>> 	 * Ignore some errors if we are hooked up.
>> 	 * If connection hasn't completed, has retransmitted several times,
>> @@ -602,12 +604,10 @@ void
>> tcp6_ctlinput(int cmd, struct sockaddr *sa, u_int rdomain, void *d)
>> {
>> 	struct tcphdr th;
>> -	struct tcpcb *tp;
>> 	void (*notify)(struct inpcb *, int) = tcp_notify;
>> 	struct ip6_hdr *ip6;
>> 	const struct sockaddr_in6 *sa6_src = NULL;
>> 	struct sockaddr_in6 *sa6 = satosin6(sa);
>> -	struct inpcb *inp;
>> 	struct mbuf *m;
>> 	tcp_seq seq;
>> 	int off;
>> @@ -633,7 +633,7 @@ tcp6_ctlinput(int cmd, struct sockaddr *
>> 		/* XXX there's no PRC_QUENCH in IPv6 */
>> 		return;
>> 	} else if (PRC_IS_REDIRECT(cmd))
>> -		notify = in_rtchange, d = NULL;
>> +		notify = in_pcbrtchange, d = NULL;
>> 	else if (cmd == PRC_MSGSIZE)
>> 		; /* special code is present, see below */
>> 	else if (cmd == PRC_HOSTDEAD)
>> @@ -655,6 +655,10 @@ tcp6_ctlinput(int cmd, struct sockaddr *
>> 	}
>> 
>> 	if (ip6) {
>> +		struct inpcb *inp;
>> +		struct socket *so = NULL;
>> +		struct tcpcb *tp = NULL;
>> +
>> 		/*
>> 		 * XXX: We assume that when ip6 is non NULL,
>> 		 * M and OFF are valid.
>> @@ -687,18 +691,27 @@ tcp6_ctlinput(int cmd, struct sockaddr *
>> 			in_pcbunref(inp);
>> 			return;
>> 		}
>> -		if (inp) {
>> +		if (inp != NULL)
>> +			so = in_pcbsolock_ref(inp);
>> +		if (so != NULL)
>> +			tp = intotcpcb(inp);
>> +		if (tp != NULL) {
>> 			seq = ntohl(th.th_seq);
>> 			if ((tp = intotcpcb(inp)) &&
>> 			    SEQ_GEQ(seq, tp->snd_una) &&
>> 			    SEQ_LT(seq, tp->snd_max))
>> 				notify(inp, inet6ctlerrmap[cmd]);
>> -		} else if (inet6ctlerrmap[cmd] == EHOSTUNREACH ||
>> +		}
>> +		in_pcbsounlock_rele(inp, so);
>> +		in_pcbunref(inp);
>> +
>> +		if (tp == NULL &&
>> +		    (inet6ctlerrmap[cmd] == EHOSTUNREACH ||
>> 		    inet6ctlerrmap[cmd] == ENETUNREACH ||
>> -		    inet6ctlerrmap[cmd] == EHOSTDOWN)
>> +		    inet6ctlerrmap[cmd] == EHOSTDOWN)) {
>> 			syn_cache_unreach(sin6tosa_const(sa6_src), sa, &th,
>> 			    rdomain);
>> -		in_pcbunref(inp);
>> +		}
>> 	} else {
>> 		in6_pcbnotify(&tcb6table, sa6, 0,
>> 		    sa6_src, 0, rdomain, cmd, NULL, notify);
>> @@ -711,8 +724,6 @@ tcp_ctlinput(int cmd, struct sockaddr *s
>> {
>> 	struct ip *ip = v;
>> 	struct tcphdr *th;
>> -	struct tcpcb *tp;
>> -	struct inpcb *inp;
>> 	struct in_addr faddr;
>> 	tcp_seq seq;
>> 	u_int mtu;
>> @@ -735,8 +746,12 @@ tcp_ctlinput(int cmd, struct sockaddr *s
>> 		 */
>> 		return;
>> 	else if (PRC_IS_REDIRECT(cmd))
>> -		notify = in_rtchange, ip = 0;
>> +		notify = in_pcbrtchange, ip = NULL;
>> 	else if (cmd == PRC_MSGSIZE && ip_mtudisc && ip) {
>> +		struct inpcb *inp;
>> +		struct socket *so = NULL;
>> +		struct tcpcb *tp = NULL;
>> +
>> 		/*
>> 		 * Verify that the packet in the icmp payload refers
>> 		 * to an existing TCP connection.
>> @@ -746,7 +761,11 @@ tcp_ctlinput(int cmd, struct sockaddr *s
>> 		inp = in_pcblookup(&tcbtable,
>> 		    ip->ip_dst, th->th_dport, ip->ip_src, th->th_sport,
>> 		    rdomain);
>> -		if (inp && (tp = intotcpcb(inp)) &&
>> +		if (inp != NULL)
>> +			so = in_pcbsolock_ref(inp);
>> +		if (so != NULL)
>> +			tp = intotcpcb(inp);
>> +		if (tp != NULL &&
>> 		    SEQ_GEQ(seq, tp->snd_una) &&
>> 		    SEQ_LT(seq, tp->snd_max)) {
>> 			struct icmp *icp;
>> @@ -760,6 +779,7 @@ tcp_ctlinput(int cmd, struct sockaddr *s
>> 			 */
>> 			mtu = (u_int)ntohs(icp->icmp_nextmtu);
>> 			if (mtu >= tp->t_pmtud_mtu_sent) {
>> +				in_pcbsounlock_rele(inp, so);
>> 				in_pcbunref(inp);
>> 				return;
>> 			}
>> @@ -780,6 +800,7 @@ tcp_ctlinput(int cmd, struct sockaddr *s
>> 				 */
>> 				if (tp->t_flags & TF_PMTUD_PEND) {
>> 					if (SEQ_LT(tp->t_pmtud_th_seq, seq)) {
>> +						in_pcbsounlock_rele(inp, so);
>> 						in_pcbunref(inp);
>> 						return;
>> 					}
>> @@ -789,37 +810,52 @@ tcp_ctlinput(int cmd, struct sockaddr *s
>> 				tp->t_pmtud_nextmtu = icp->icmp_nextmtu;
>> 				tp->t_pmtud_ip_len = icp->icmp_ip.ip_len;
>> 				tp->t_pmtud_ip_hl = icp->icmp_ip.ip_hl;
>> +				in_pcbsounlock_rele(inp, so);
>> 				in_pcbunref(inp);
>> 				return;
>> 			}
>> 		} else {
>> 			/* ignore if we don't have a matching connection */
>> +			in_pcbsounlock_rele(inp, so);
>> 			in_pcbunref(inp);
>> 			return;
>> 		}
>> +		in_pcbsounlock_rele(inp, so);
>> 		in_pcbunref(inp);
>> -		notify = tcp_mtudisc, ip = 0;
>> +		notify = tcp_mtudisc, ip = NULL;
>> 	} else if (cmd == PRC_MTUINC)
>> -		notify = tcp_mtudisc_increase, ip = 0;
>> +		notify = tcp_mtudisc_increase, ip = NULL;
>> 	else if (cmd == PRC_HOSTDEAD)
>> -		ip = 0;
>> +		ip = NULL;
>> 	else if (errno == 0)
>> 		return;
>> 
>> 	if (ip) {
>> +		struct inpcb *inp;
>> +		struct socket *so = NULL;
>> +		struct tcpcb *tp = NULL;
>> +
>> 		th = (struct tcphdr *)((caddr_t)ip + (ip->ip_hl << 2));
>> 		inp = in_pcblookup(&tcbtable,
>> 		    ip->ip_dst, th->th_dport, ip->ip_src, th->th_sport,
>> 		    rdomain);
>> -		if (inp) {
>> +		if (inp != NULL)
>> +			so = in_pcbsolock_ref(inp);
>> +		if (so != NULL)
>> +			tp = intotcpcb(inp);
>> +		if (tp != NULL) {
>> 			seq = ntohl(th->th_seq);
>> -			if ((tp = intotcpcb(inp)) &&
>> -			    SEQ_GEQ(seq, tp->snd_una) &&
>> +			if (SEQ_GEQ(seq, tp->snd_una) &&
>> 			    SEQ_LT(seq, tp->snd_max))
>> 				notify(inp, errno);
>> -		} else if (inetctlerrmap[cmd] == EHOSTUNREACH ||
>> +		}
>> +		in_pcbsounlock_rele(inp, so);
>> +		in_pcbunref(inp);
>> +
>> +		if (tp == NULL &&
>> +		    (inetctlerrmap[cmd] == EHOSTUNREACH ||
>> 		    inetctlerrmap[cmd] == ENETUNREACH ||
>> -		    inetctlerrmap[cmd] == EHOSTDOWN) {
>> +		    inetctlerrmap[cmd] == EHOSTDOWN)) {
>> 			struct sockaddr_in sin;
>> 
>> 			bzero(&sin, sizeof(sin));
>> @@ -829,7 +865,6 @@ tcp_ctlinput(int cmd, struct sockaddr *s
>> 			sin.sin_addr = ip->ip_src;
>> 			syn_cache_unreach(sintosa(&sin), sa, th, rdomain);
>> 		}
>> -		in_pcbunref(inp);
>> 	} else
>> 		in_pcbnotifyall(&tcbtable, satosin(sa), rdomain, errno, notify);
>> }
>> @@ -871,7 +906,7 @@ tcp_mtudisc(struct inpcb *inp, int errno
>> 		 * If this was not a host route, remove and realloc.
>> 		 */
>> 		if ((rt->rt_flags & RTF_HOST) == 0) {
>> -			in_rtchange(inp, errno);
>> +			in_pcbrtchange(inp, errno);
>> 			if ((rt = in_pcbrtentry(inp)) == NULL)
>> 				return;
>> 		}
>> @@ -901,7 +936,7 @@ tcp_mtudisc_increase(struct inpcb *inp,
>> 		 * If this was a host route, remove and realloc.
>> 		 */
>> 		if (rt->rt_flags & RTF_HOST)
>> -			in_rtchange(inp, errno);
>> +			in_pcbrtchange(inp, errno);
>> 
>> 		/* also takes care of congestion window */
>> 		tcp_mss(tp, -1);
>> Index: netinet/tcp_timer.c
>> ===================================================================
>> RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/tcp_timer.c,v
>> diff -u -p -r1.82 tcp_timer.c
>> --- netinet/tcp_timer.c	16 Jan 2025 11:59:20 -0000	1.82
>> +++ netinet/tcp_timer.c	7 Feb 2025 21:05:40 -0000
>> @@ -214,17 +214,19 @@ tcp_timer_rexmt(void *arg)
>> 	    SEQ_LT(tp->t_pmtud_th_seq, (int)(tp->snd_una + tp->t_maxseg))) {
>> 		struct sockaddr_in sin;
>> 		struct icmp icmp;
>> +		u_int rtableid;
>> 
>> 		/* TF_PMTUD_PEND is set in tcp_ctlinput() which is IPv4 only */
>> 		KASSERT(!ISSET(inp->inp_flags, INP_IPV6));
>> 		tp->t_flags &= ~TF_PMTUD_PEND;
>> 
>> +		rtableid = inp->inp_rtableid;
>> +
>> 		/* XXX create fake icmp message with relevant entries */
>> 		icmp.icmp_nextmtu = tp->t_pmtud_nextmtu;
>> 		icmp.icmp_ip.ip_len = tp->t_pmtud_ip_len;
>> 		icmp.icmp_ip.ip_hl = tp->t_pmtud_ip_hl;
>> 		icmp.icmp_ip.ip_dst = inp->inp_faddr;
>> -		icmp_mtudisc(&icmp, inp->inp_rtableid);
>> 
>> 		/*
>> 		 * Notify all connections to the same peer about
>> @@ -234,9 +236,16 @@ tcp_timer_rexmt(void *arg)
>> 		sin.sin_len = sizeof(sin);
>> 		sin.sin_family = AF_INET;
>> 		sin.sin_addr = inp->inp_faddr;
>> -		in_pcbnotifyall(&tcbtable, &sin, inp->inp_rtableid, EMSGSIZE,
>> +
>> +		in_pcbsounlock_rele(inp, so);
>> +		in_pcbunref(inp);
>> +
>> +		icmp_mtudisc(&icmp, rtableid);
>> +		in_pcbnotifyall(&tcbtable, &sin, rtableid, EMSGSIZE,
>> 		    tcp_mtudisc);
>> -		goto out;
>> +
>> +		NET_UNLOCK_SHARED();
>> +		return;
>> 	}
>> 
>> 	tcp_timer_freesack(tp);
>> @@ -303,7 +312,7 @@ tcp_timer_rexmt(void *arg)
>> 			/* Disable path MTU discovery */
>> 			if ((rt->rt_locks & RTV_MTU) == 0) {
>> 				rt->rt_locks |= RTV_MTU;
>> -				in_rtchange(inp, 0);
>> +				in_pcbrtchange(inp, 0);
>> 			}
>> 
>> 			rtfree(rt);
>> Index: netinet/udp_usrreq.c
>> ===================================================================
>> RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/udp_usrreq.c,v
>> diff -u -p -r1.331 udp_usrreq.c
>> --- netinet/udp_usrreq.c	6 Feb 2025 13:40:58 -0000	1.331
>> +++ netinet/udp_usrreq.c	7 Feb 2025 21:05:40 -0000
>> @@ -748,7 +748,7 @@ udp6_ctlinput(int cmd, struct sockaddr *
>> 	if ((unsigned)cmd >= PRC_NCMDS)
>> 		return;
>> 	if (PRC_IS_REDIRECT(cmd))
>> -		notify = in_rtchange, d = NULL;
>> +		notify = in_pcbrtchange, d = NULL;
>> 	else if (cmd == PRC_HOSTDEAD)
>> 		d = NULL;
>> 	else if (cmd == PRC_MSGSIZE)
>> @@ -880,7 +880,6 @@ udp_ctlinput(int cmd, struct sockaddr *s
>> 	struct ip *ip = v;
>> 	struct udphdr *uhp;
>> 	struct in_addr faddr;
>> -	struct inpcb *inp;
>> 	void (*notify)(struct inpcb *, int) = udp_notify;
>> 	int errno;
>> 
>> @@ -897,14 +896,17 @@ udp_ctlinput(int cmd, struct sockaddr *s
>> 		return;
>> 	errno = inetctlerrmap[cmd];
>> 	if (PRC_IS_REDIRECT(cmd))
>> -		notify = in_rtchange, ip = 0;
>> +		notify = in_pcbrtchange, ip = NULL;
>> 	else if (cmd == PRC_HOSTDEAD)
>> -		ip = 0;
>> +		ip = NULL;
>> 	else if (errno == 0)
>> 		return;
>> +
>> 	if (ip) {
>> -		uhp = (struct udphdr *)((caddr_t)ip + (ip->ip_hl << 2));
>> +		struct inpcb *inp;
>> +		struct socket *so = NULL;
>> 
>> +		uhp = (struct udphdr *)((caddr_t)ip + (ip->ip_hl << 2));
>> #ifdef IPSEC
>> 		/* PMTU discovery for udpencap */
>> 		if (cmd == PRC_MSGSIZE && ip_mtudisc && udpencap_enable &&
>> @@ -917,7 +919,10 @@ udp_ctlinput(int cmd, struct sockaddr *s
>> 		    ip->ip_dst, uhp->uh_dport, ip->ip_src, uhp->uh_sport,
>> 		    rdomain);
>> 		if (inp != NULL)
>> +			so = in_pcbsolock_ref(inp);
>> +		if (so != NULL)
>> 			notify(inp, errno);
>> +		in_pcbsounlock_rele(inp, so);
>> 		in_pcbunref(inp);
>> 	} else
>> 		in_pcbnotifyall(&udbtable, satosin(sa), rdomain, errno, notify);
>> Index: netinet6/in6_pcb.c
>> ===================================================================
>> RCS file: /data/mirror/openbsd/cvs/src/sys/netinet6/in6_pcb.c,v
>> diff -u -p -r1.146 in6_pcb.c
>> --- netinet6/in6_pcb.c	21 Dec 2024 00:10:04 -0000	1.146
>> +++ netinet6/in6_pcb.c	7 Feb 2025 21:05:40 -0000
>> @@ -456,8 +456,8 @@ in6_pcbnotify(struct inpcbtable *table,
>> 
>> 	/*
>> 	 * Redirects go to all references to the destination,
>> -	 * and use in_rtchange to invalidate the route cache.
>> -	 * Dead host indications: also use in_rtchange to invalidate
>> +	 * and use in_pcbrtchange to invalidate the route cache.
>> +	 * Dead host indications: also use in_pcbrtchange to invalidate
>> 	 * the cache, and deliver the error to all the sockets.
>> 	 * Otherwise, if we have knowledge of the local port and address,
>> 	 * deliver only to that socket.
>> @@ -468,7 +468,7 @@ in6_pcbnotify(struct inpcbtable *table,
>> 		sa6_src.sin6_addr = in6addr_any;
>> 
>> 		if (cmd != PRC_HOSTDEAD)
>> -			notify = in_rtchange;
>> +			notify = in_pcbrtchange;
>> 	}
>> 	errno = inet6ctlerrmap[cmd];
>> 	if (notify == NULL)
>> @@ -477,6 +477,8 @@ in6_pcbnotify(struct inpcbtable *table,
>> 	rdomain = rtable_l2(rtable);
>> 	mtx_enter(&table->inpt_mtx);
>> 	while ((inp = in_pcb_iterator(table, inp, &iter)) != NULL) {
>> +		struct socket *so;
>> +
>> 		KASSERT(ISSET(inp->inp_flags, INP_IPV6));
>> 
>> 		/*
>> @@ -543,7 +545,10 @@ in6_pcbnotify(struct inpcbtable *table,
>> 		}
>> 	  do_notify:
>> 		mtx_leave(&table->inpt_mtx);
>> -		(*notify)(inp, errno);
>> +		so = in_pcbsolock_ref(inp);
>> +		if (so != NULL)
>> +			(*notify)(inp, errno);
>> +		in_pcbsounlock_rele(inp, so);
>> 		mtx_enter(&table->inpt_mtx);
>> 	}
>> 	mtx_leave(&table->inpt_mtx);
>> Index: netinet6/raw_ip6.c
>> ===================================================================
>> RCS file: /data/mirror/openbsd/cvs/src/sys/netinet6/raw_ip6.c,v
>> diff -u -p -r1.187 raw_ip6.c
>> --- netinet6/raw_ip6.c	6 Feb 2025 13:40:58 -0000	1.187
>> +++ netinet6/raw_ip6.c	7 Feb 2025 21:05:40 -0000
>> @@ -308,7 +308,7 @@ rip6_ctlinput(int cmd, struct sockaddr *
>> 	struct sockaddr_in6 *sa6 = satosin6(sa);
>> 	const struct sockaddr_in6 *sa6_src = NULL;
>> 	void *cmdarg;
>> -	void (*notify)(struct inpcb *, int) = in_rtchange;
>> +	void (*notify)(struct inpcb *, int) = in_pcbrtchange;
>> 	int nxt;
>> 
>> 	if (sa->sa_family != AF_INET6 ||
>> @@ -318,7 +318,7 @@ rip6_ctlinput(int cmd, struct sockaddr *
>> 	if ((unsigned)cmd >= PRC_NCMDS)
>> 		return;
>> 	if (PRC_IS_REDIRECT(cmd))
>> -		notify = in_rtchange, d = NULL;
>> +		notify = in_pcbrtchange, d = NULL;
>> 	else if (cmd == PRC_HOSTDEAD)
>> 		d = NULL;
>> 	else if (cmd == PRC_MSGSIZE)