Index | Thread | Search

From:
Alexander Bluhm <bluhm@openbsd.org>
Subject:
Re: pcb notify socket lock
To:
tech@openbsd.org
Cc:
Vitaliy Makkoveev <mvs@openbsd.org>
Date:
Wed, 12 Feb 2025 20:44:14 +0100

Download raw body.

Thread
On Fri, Feb 07, 2025 at 11:31:35PM +0100, Alexander Bluhm wrote:
> Hi,
> 
> The notify and ctlinput functions are not MP safe yet.  They need
> the socket lock, which can be acquired by in_pcbsolock_ref().  Of
> course, in_pcbnotifyall() then has to be called without holding a
> socket lock.
> 
> Rename in_rtchange() to in_pcbrtchange().  This is the correct
> namespace, since the function manages the inpcb route.
> 
> ok?

anyone?

> bluhm
> 
> Index: netinet/in_pcb.c
> ===================================================================
> RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/in_pcb.c,v
> diff -u -p -r1.311 in_pcb.c
> --- netinet/in_pcb.c	5 Feb 2025 10:15:10 -0000	1.311
> +++ netinet/in_pcb.c	7 Feb 2025 21:05:40 -0000
> @@ -810,6 +810,8 @@ in_pcbnotifyall(struct inpcbtable *table
>  	rdomain = rtable_l2(rtable);
>  	mtx_enter(&table->inpt_mtx);
>  	while ((inp = in_pcb_iterator(table, inp, &iter)) != NULL) {
> +		struct socket *so;
> +
>  		KASSERT(!ISSET(inp->inp_flags, INP_IPV6));
> 
>  		if (inp->inp_faddr.s_addr != dst->sin_addr.s_addr ||
> @@ -817,7 +819,10 @@ in_pcbnotifyall(struct inpcbtable *table
>  			continue;
>  		}
>  		mtx_leave(&table->inpt_mtx);
> -		(*notify)(inp, errno);
> +		so = in_pcbsolock_ref(inp);
> +		if (so != NULL)
> +			(*notify)(inp, errno);
> +		in_pcbsounlock_rele(inp, so);
>  		mtx_enter(&table->inpt_mtx);
>  	}
>  	mtx_leave(&table->inpt_mtx);
> @@ -866,8 +871,10 @@ in_losing(struct inpcb *inp)
>   * and allocate a (hopefully) better one.
>   */
>  void
> -in_rtchange(struct inpcb *inp, int errno)
> +in_pcbrtchange(struct inpcb *inp, int errno)
>  {
> +	soassertlocked(inp->inp_socket);
> +
>  	if (inp->inp_route.ro_rt) {
>  		rtfree(inp->inp_route.ro_rt);
>  		inp->inp_route.ro_rt = NULL;
> Index: netinet/in_pcb.h
> ===================================================================
> RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/in_pcb.h,v
> diff -u -p -r1.164 in_pcb.h
> --- netinet/in_pcb.h	9 Jan 2025 16:47:24 -0000	1.164
> +++ netinet/in_pcb.h	7 Feb 2025 21:05:40 -0000
> @@ -354,7 +354,7 @@ struct inpcb *
>  void	 in_pcbnotifyall(struct inpcbtable *, const struct sockaddr_in *,
>  	    u_int, int, void (*)(struct inpcb *, int));
>  void	 in_pcbrehash(struct inpcb *);
> -void	 in_rtchange(struct inpcb *, int);
> +void	 in_pcbrtchange(struct inpcb *, int);
>  void	 in_setpeeraddr(struct inpcb *, struct mbuf *);
>  void	 in_setsockaddr(struct inpcb *, struct mbuf *);
>  int	 in_sockaddr(struct socket *, struct mbuf *);
> Index: netinet/tcp_subr.c
> ===================================================================
> RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/tcp_subr.c,v
> diff -u -p -r1.207 tcp_subr.c
> --- netinet/tcp_subr.c	30 Jan 2025 14:40:50 -0000	1.207
> +++ netinet/tcp_subr.c	7 Feb 2025 21:05:40 -0000
> @@ -576,6 +576,8 @@ tcp_notify(struct inpcb *inp, int error)
>  	struct tcpcb *tp = intotcpcb(inp);
>  	struct socket *so = inp->inp_socket;
> 
> +	soassertlocked(so);
> +
>  	/*
>  	 * Ignore some errors if we are hooked up.
>  	 * If connection hasn't completed, has retransmitted several times,
> @@ -602,12 +604,10 @@ void
>  tcp6_ctlinput(int cmd, struct sockaddr *sa, u_int rdomain, void *d)
>  {
>  	struct tcphdr th;
> -	struct tcpcb *tp;
>  	void (*notify)(struct inpcb *, int) = tcp_notify;
>  	struct ip6_hdr *ip6;
>  	const struct sockaddr_in6 *sa6_src = NULL;
>  	struct sockaddr_in6 *sa6 = satosin6(sa);
> -	struct inpcb *inp;
>  	struct mbuf *m;
>  	tcp_seq seq;
>  	int off;
> @@ -633,7 +633,7 @@ tcp6_ctlinput(int cmd, struct sockaddr *
>  		/* XXX there's no PRC_QUENCH in IPv6 */
>  		return;
>  	} else if (PRC_IS_REDIRECT(cmd))
> -		notify = in_rtchange, d = NULL;
> +		notify = in_pcbrtchange, d = NULL;
>  	else if (cmd == PRC_MSGSIZE)
>  		; /* special code is present, see below */
>  	else if (cmd == PRC_HOSTDEAD)
> @@ -655,6 +655,10 @@ tcp6_ctlinput(int cmd, struct sockaddr *
>  	}
> 
>  	if (ip6) {
> +		struct inpcb *inp;
> +		struct socket *so = NULL;
> +		struct tcpcb *tp = NULL;
> +
>  		/*
>  		 * XXX: We assume that when ip6 is non NULL,
>  		 * M and OFF are valid.
> @@ -687,18 +691,27 @@ tcp6_ctlinput(int cmd, struct sockaddr *
>  			in_pcbunref(inp);
>  			return;
>  		}
> -		if (inp) {
> +		if (inp != NULL)
> +			so = in_pcbsolock_ref(inp);
> +		if (so != NULL)
> +			tp = intotcpcb(inp);
> +		if (tp != NULL) {
>  			seq = ntohl(th.th_seq);
>  			if ((tp = intotcpcb(inp)) &&
>  			    SEQ_GEQ(seq, tp->snd_una) &&
>  			    SEQ_LT(seq, tp->snd_max))
>  				notify(inp, inet6ctlerrmap[cmd]);
> -		} else if (inet6ctlerrmap[cmd] == EHOSTUNREACH ||
> +		}
> +		in_pcbsounlock_rele(inp, so);
> +		in_pcbunref(inp);
> +
> +		if (tp == NULL &&
> +		    (inet6ctlerrmap[cmd] == EHOSTUNREACH ||
>  		    inet6ctlerrmap[cmd] == ENETUNREACH ||
> -		    inet6ctlerrmap[cmd] == EHOSTDOWN)
> +		    inet6ctlerrmap[cmd] == EHOSTDOWN)) {
>  			syn_cache_unreach(sin6tosa_const(sa6_src), sa, &th,
>  			    rdomain);
> -		in_pcbunref(inp);
> +		}
>  	} else {
>  		in6_pcbnotify(&tcb6table, sa6, 0,
>  		    sa6_src, 0, rdomain, cmd, NULL, notify);
> @@ -711,8 +724,6 @@ tcp_ctlinput(int cmd, struct sockaddr *s
>  {
>  	struct ip *ip = v;
>  	struct tcphdr *th;
> -	struct tcpcb *tp;
> -	struct inpcb *inp;
>  	struct in_addr faddr;
>  	tcp_seq seq;
>  	u_int mtu;
> @@ -735,8 +746,12 @@ tcp_ctlinput(int cmd, struct sockaddr *s
>  		 */
>  		return;
>  	else if (PRC_IS_REDIRECT(cmd))
> -		notify = in_rtchange, ip = 0;
> +		notify = in_pcbrtchange, ip = NULL;
>  	else if (cmd == PRC_MSGSIZE && ip_mtudisc && ip) {
> +		struct inpcb *inp;
> +		struct socket *so = NULL;
> +		struct tcpcb *tp = NULL;
> +
>  		/*
>  		 * Verify that the packet in the icmp payload refers
>  		 * to an existing TCP connection.
> @@ -746,7 +761,11 @@ tcp_ctlinput(int cmd, struct sockaddr *s
>  		inp = in_pcblookup(&tcbtable,
>  		    ip->ip_dst, th->th_dport, ip->ip_src, th->th_sport,
>  		    rdomain);
> -		if (inp && (tp = intotcpcb(inp)) &&
> +		if (inp != NULL)
> +			so = in_pcbsolock_ref(inp);
> +		if (so != NULL)
> +			tp = intotcpcb(inp);
> +		if (tp != NULL &&
>  		    SEQ_GEQ(seq, tp->snd_una) &&
>  		    SEQ_LT(seq, tp->snd_max)) {
>  			struct icmp *icp;
> @@ -760,6 +779,7 @@ tcp_ctlinput(int cmd, struct sockaddr *s
>  			 */
>  			mtu = (u_int)ntohs(icp->icmp_nextmtu);
>  			if (mtu >= tp->t_pmtud_mtu_sent) {
> +				in_pcbsounlock_rele(inp, so);
>  				in_pcbunref(inp);
>  				return;
>  			}
> @@ -780,6 +800,7 @@ tcp_ctlinput(int cmd, struct sockaddr *s
>  				 */
>  				if (tp->t_flags & TF_PMTUD_PEND) {
>  					if (SEQ_LT(tp->t_pmtud_th_seq, seq)) {
> +						in_pcbsounlock_rele(inp, so);
>  						in_pcbunref(inp);
>  						return;
>  					}
> @@ -789,37 +810,52 @@ tcp_ctlinput(int cmd, struct sockaddr *s
>  				tp->t_pmtud_nextmtu = icp->icmp_nextmtu;
>  				tp->t_pmtud_ip_len = icp->icmp_ip.ip_len;
>  				tp->t_pmtud_ip_hl = icp->icmp_ip.ip_hl;
> +				in_pcbsounlock_rele(inp, so);
>  				in_pcbunref(inp);
>  				return;
>  			}
>  		} else {
>  			/* ignore if we don't have a matching connection */
> +			in_pcbsounlock_rele(inp, so);
>  			in_pcbunref(inp);
>  			return;
>  		}
> +		in_pcbsounlock_rele(inp, so);
>  		in_pcbunref(inp);
> -		notify = tcp_mtudisc, ip = 0;
> +		notify = tcp_mtudisc, ip = NULL;
>  	} else if (cmd == PRC_MTUINC)
> -		notify = tcp_mtudisc_increase, ip = 0;
> +		notify = tcp_mtudisc_increase, ip = NULL;
>  	else if (cmd == PRC_HOSTDEAD)
> -		ip = 0;
> +		ip = NULL;
>  	else if (errno == 0)
>  		return;
> 
>  	if (ip) {
> +		struct inpcb *inp;
> +		struct socket *so = NULL;
> +		struct tcpcb *tp = NULL;
> +
>  		th = (struct tcphdr *)((caddr_t)ip + (ip->ip_hl << 2));
>  		inp = in_pcblookup(&tcbtable,
>  		    ip->ip_dst, th->th_dport, ip->ip_src, th->th_sport,
>  		    rdomain);
> -		if (inp) {
> +		if (inp != NULL)
> +			so = in_pcbsolock_ref(inp);
> +		if (so != NULL)
> +			tp = intotcpcb(inp);
> +		if (tp != NULL) {
>  			seq = ntohl(th->th_seq);
> -			if ((tp = intotcpcb(inp)) &&
> -			    SEQ_GEQ(seq, tp->snd_una) &&
> +			if (SEQ_GEQ(seq, tp->snd_una) &&
>  			    SEQ_LT(seq, tp->snd_max))
>  				notify(inp, errno);
> -		} else if (inetctlerrmap[cmd] == EHOSTUNREACH ||
> +		}
> +		in_pcbsounlock_rele(inp, so);
> +		in_pcbunref(inp);
> +
> +		if (tp == NULL &&
> +		    (inetctlerrmap[cmd] == EHOSTUNREACH ||
>  		    inetctlerrmap[cmd] == ENETUNREACH ||
> -		    inetctlerrmap[cmd] == EHOSTDOWN) {
> +		    inetctlerrmap[cmd] == EHOSTDOWN)) {
>  			struct sockaddr_in sin;
> 
>  			bzero(&sin, sizeof(sin));
> @@ -829,7 +865,6 @@ tcp_ctlinput(int cmd, struct sockaddr *s
>  			sin.sin_addr = ip->ip_src;
>  			syn_cache_unreach(sintosa(&sin), sa, th, rdomain);
>  		}
> -		in_pcbunref(inp);
>  	} else
>  		in_pcbnotifyall(&tcbtable, satosin(sa), rdomain, errno, notify);
>  }
> @@ -871,7 +906,7 @@ tcp_mtudisc(struct inpcb *inp, int errno
>  		 * If this was not a host route, remove and realloc.
>  		 */
>  		if ((rt->rt_flags & RTF_HOST) == 0) {
> -			in_rtchange(inp, errno);
> +			in_pcbrtchange(inp, errno);
>  			if ((rt = in_pcbrtentry(inp)) == NULL)
>  				return;
>  		}
> @@ -901,7 +936,7 @@ tcp_mtudisc_increase(struct inpcb *inp,
>  		 * If this was a host route, remove and realloc.
>  		 */
>  		if (rt->rt_flags & RTF_HOST)
> -			in_rtchange(inp, errno);
> +			in_pcbrtchange(inp, errno);
> 
>  		/* also takes care of congestion window */
>  		tcp_mss(tp, -1);
> Index: netinet/tcp_timer.c
> ===================================================================
> RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/tcp_timer.c,v
> diff -u -p -r1.82 tcp_timer.c
> --- netinet/tcp_timer.c	16 Jan 2025 11:59:20 -0000	1.82
> +++ netinet/tcp_timer.c	7 Feb 2025 21:05:40 -0000
> @@ -214,17 +214,19 @@ tcp_timer_rexmt(void *arg)
>  	    SEQ_LT(tp->t_pmtud_th_seq, (int)(tp->snd_una + tp->t_maxseg))) {
>  		struct sockaddr_in sin;
>  		struct icmp icmp;
> +		u_int rtableid;
> 
>  		/* TF_PMTUD_PEND is set in tcp_ctlinput() which is IPv4 only */
>  		KASSERT(!ISSET(inp->inp_flags, INP_IPV6));
>  		tp->t_flags &= ~TF_PMTUD_PEND;
> 
> +		rtableid = inp->inp_rtableid;
> +
>  		/* XXX create fake icmp message with relevant entries */
>  		icmp.icmp_nextmtu = tp->t_pmtud_nextmtu;
>  		icmp.icmp_ip.ip_len = tp->t_pmtud_ip_len;
>  		icmp.icmp_ip.ip_hl = tp->t_pmtud_ip_hl;
>  		icmp.icmp_ip.ip_dst = inp->inp_faddr;
> -		icmp_mtudisc(&icmp, inp->inp_rtableid);
> 
>  		/*
>  		 * Notify all connections to the same peer about
> @@ -234,9 +236,16 @@ tcp_timer_rexmt(void *arg)
>  		sin.sin_len = sizeof(sin);
>  		sin.sin_family = AF_INET;
>  		sin.sin_addr = inp->inp_faddr;
> -		in_pcbnotifyall(&tcbtable, &sin, inp->inp_rtableid, EMSGSIZE,
> +
> +		in_pcbsounlock_rele(inp, so);
> +		in_pcbunref(inp);
> +
> +		icmp_mtudisc(&icmp, rtableid);
> +		in_pcbnotifyall(&tcbtable, &sin, rtableid, EMSGSIZE,
>  		    tcp_mtudisc);
> -		goto out;
> +
> +		NET_UNLOCK_SHARED();
> +		return;
>  	}
> 
>  	tcp_timer_freesack(tp);
> @@ -303,7 +312,7 @@ tcp_timer_rexmt(void *arg)
>  			/* Disable path MTU discovery */
>  			if ((rt->rt_locks & RTV_MTU) == 0) {
>  				rt->rt_locks |= RTV_MTU;
> -				in_rtchange(inp, 0);
> +				in_pcbrtchange(inp, 0);
>  			}
> 
>  			rtfree(rt);
> Index: netinet/udp_usrreq.c
> ===================================================================
> RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/udp_usrreq.c,v
> diff -u -p -r1.331 udp_usrreq.c
> --- netinet/udp_usrreq.c	6 Feb 2025 13:40:58 -0000	1.331
> +++ netinet/udp_usrreq.c	7 Feb 2025 21:05:40 -0000
> @@ -748,7 +748,7 @@ udp6_ctlinput(int cmd, struct sockaddr *
>  	if ((unsigned)cmd >= PRC_NCMDS)
>  		return;
>  	if (PRC_IS_REDIRECT(cmd))
> -		notify = in_rtchange, d = NULL;
> +		notify = in_pcbrtchange, d = NULL;
>  	else if (cmd == PRC_HOSTDEAD)
>  		d = NULL;
>  	else if (cmd == PRC_MSGSIZE)
> @@ -880,7 +880,6 @@ udp_ctlinput(int cmd, struct sockaddr *s
>  	struct ip *ip = v;
>  	struct udphdr *uhp;
>  	struct in_addr faddr;
> -	struct inpcb *inp;
>  	void (*notify)(struct inpcb *, int) = udp_notify;
>  	int errno;
> 
> @@ -897,14 +896,17 @@ udp_ctlinput(int cmd, struct sockaddr *s
>  		return;
>  	errno = inetctlerrmap[cmd];
>  	if (PRC_IS_REDIRECT(cmd))
> -		notify = in_rtchange, ip = 0;
> +		notify = in_pcbrtchange, ip = NULL;
>  	else if (cmd == PRC_HOSTDEAD)
> -		ip = 0;
> +		ip = NULL;
>  	else if (errno == 0)
>  		return;
> +
>  	if (ip) {
> -		uhp = (struct udphdr *)((caddr_t)ip + (ip->ip_hl << 2));
> +		struct inpcb *inp;
> +		struct socket *so = NULL;
> 
> +		uhp = (struct udphdr *)((caddr_t)ip + (ip->ip_hl << 2));
>  #ifdef IPSEC
>  		/* PMTU discovery for udpencap */
>  		if (cmd == PRC_MSGSIZE && ip_mtudisc && udpencap_enable &&
> @@ -917,7 +919,10 @@ udp_ctlinput(int cmd, struct sockaddr *s
>  		    ip->ip_dst, uhp->uh_dport, ip->ip_src, uhp->uh_sport,
>  		    rdomain);
>  		if (inp != NULL)
> +			so = in_pcbsolock_ref(inp);
> +		if (so != NULL)
>  			notify(inp, errno);
> +		in_pcbsounlock_rele(inp, so);
>  		in_pcbunref(inp);
>  	} else
>  		in_pcbnotifyall(&udbtable, satosin(sa), rdomain, errno, notify);
> Index: netinet6/in6_pcb.c
> ===================================================================
> RCS file: /data/mirror/openbsd/cvs/src/sys/netinet6/in6_pcb.c,v
> diff -u -p -r1.146 in6_pcb.c
> --- netinet6/in6_pcb.c	21 Dec 2024 00:10:04 -0000	1.146
> +++ netinet6/in6_pcb.c	7 Feb 2025 21:05:40 -0000
> @@ -456,8 +456,8 @@ in6_pcbnotify(struct inpcbtable *table,
> 
>  	/*
>  	 * Redirects go to all references to the destination,
> -	 * and use in_rtchange to invalidate the route cache.
> -	 * Dead host indications: also use in_rtchange to invalidate
> +	 * and use in_pcbrtchange to invalidate the route cache.
> +	 * Dead host indications: also use in_pcbrtchange to invalidate
>  	 * the cache, and deliver the error to all the sockets.
>  	 * Otherwise, if we have knowledge of the local port and address,
>  	 * deliver only to that socket.
> @@ -468,7 +468,7 @@ in6_pcbnotify(struct inpcbtable *table,
>  		sa6_src.sin6_addr = in6addr_any;
> 
>  		if (cmd != PRC_HOSTDEAD)
> -			notify = in_rtchange;
> +			notify = in_pcbrtchange;
>  	}
>  	errno = inet6ctlerrmap[cmd];
>  	if (notify == NULL)
> @@ -477,6 +477,8 @@ in6_pcbnotify(struct inpcbtable *table,
>  	rdomain = rtable_l2(rtable);
>  	mtx_enter(&table->inpt_mtx);
>  	while ((inp = in_pcb_iterator(table, inp, &iter)) != NULL) {
> +		struct socket *so;
> +
>  		KASSERT(ISSET(inp->inp_flags, INP_IPV6));
> 
>  		/*
> @@ -543,7 +545,10 @@ in6_pcbnotify(struct inpcbtable *table,
>  		}
>  	  do_notify:
>  		mtx_leave(&table->inpt_mtx);
> -		(*notify)(inp, errno);
> +		so = in_pcbsolock_ref(inp);
> +		if (so != NULL)
> +			(*notify)(inp, errno);
> +		in_pcbsounlock_rele(inp, so);
>  		mtx_enter(&table->inpt_mtx);
>  	}
>  	mtx_leave(&table->inpt_mtx);
> Index: netinet6/raw_ip6.c
> ===================================================================
> RCS file: /data/mirror/openbsd/cvs/src/sys/netinet6/raw_ip6.c,v
> diff -u -p -r1.187 raw_ip6.c
> --- netinet6/raw_ip6.c	6 Feb 2025 13:40:58 -0000	1.187
> +++ netinet6/raw_ip6.c	7 Feb 2025 21:05:40 -0000
> @@ -308,7 +308,7 @@ rip6_ctlinput(int cmd, struct sockaddr *
>  	struct sockaddr_in6 *sa6 = satosin6(sa);
>  	const struct sockaddr_in6 *sa6_src = NULL;
>  	void *cmdarg;
> -	void (*notify)(struct inpcb *, int) = in_rtchange;
> +	void (*notify)(struct inpcb *, int) = in_pcbrtchange;
>  	int nxt;
> 
>  	if (sa->sa_family != AF_INET6 ||
> @@ -318,7 +318,7 @@ rip6_ctlinput(int cmd, struct sockaddr *
>  	if ((unsigned)cmd >= PRC_NCMDS)
>  		return;
>  	if (PRC_IS_REDIRECT(cmd))
> -		notify = in_rtchange, d = NULL;
> +		notify = in_pcbrtchange, d = NULL;
>  	else if (cmd == PRC_HOSTDEAD)
>  		d = NULL;
>  	else if (cmd == PRC_MSGSIZE)