From: Alexander Bluhm Subject: Re: pcb notify socket lock To: tech@openbsd.org Cc: Vitaliy Makkoveev Date: Wed, 12 Feb 2025 20:44:14 +0100 On Fri, Feb 07, 2025 at 11:31:35PM +0100, Alexander Bluhm wrote: > Hi, > > The notify and ctlinput functions are not MP safe yet. They need > socket lock which can be acquired by in_pcbsolock_ref(). Of course > in_pcbnotifyall() has to be called without holding a socket lock > then. > > Rename in_rtchange() to in_pcbrtchange(). This is the correct > namespace and the function cares about the inpcb route. > > ok? anyone? > bluhm > > Index: netinet/in_pcb.c > =================================================================== > RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/in_pcb.c,v > diff -u -p -r1.311 in_pcb.c > --- netinet/in_pcb.c 5 Feb 2025 10:15:10 -0000 1.311 > +++ netinet/in_pcb.c 7 Feb 2025 21:05:40 -0000 > @@ -810,6 +810,8 @@ in_pcbnotifyall(struct inpcbtable *table > rdomain = rtable_l2(rtable); > mtx_enter(&table->inpt_mtx); > while ((inp = in_pcb_iterator(table, inp, &iter)) != NULL) { > + struct socket *so; > + > KASSERT(!ISSET(inp->inp_flags, INP_IPV6)); > > if (inp->inp_faddr.s_addr != dst->sin_addr.s_addr || > @@ -817,7 +819,10 @@ in_pcbnotifyall(struct inpcbtable *table > continue; > } > mtx_leave(&table->inpt_mtx); > - (*notify)(inp, errno); > + so = in_pcbsolock_ref(inp); > + if (so != NULL) > + (*notify)(inp, errno); > + in_pcbsounlock_rele(inp, so); > mtx_enter(&table->inpt_mtx); > } > mtx_leave(&table->inpt_mtx); > @@ -866,8 +871,10 @@ in_losing(struct inpcb *inp) > * and allocate a (hopefully) better one. 
> */ > void > -in_rtchange(struct inpcb *inp, int errno) > +in_pcbrtchange(struct inpcb *inp, int errno) > { > + soassertlocked(inp->inp_socket); > + > if (inp->inp_route.ro_rt) { > rtfree(inp->inp_route.ro_rt); > inp->inp_route.ro_rt = NULL; > Index: netinet/in_pcb.h > =================================================================== > RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/in_pcb.h,v > diff -u -p -r1.164 in_pcb.h > --- netinet/in_pcb.h 9 Jan 2025 16:47:24 -0000 1.164 > +++ netinet/in_pcb.h 7 Feb 2025 21:05:40 -0000 > @@ -354,7 +354,7 @@ struct inpcb * > void in_pcbnotifyall(struct inpcbtable *, const struct sockaddr_in *, > u_int, int, void (*)(struct inpcb *, int)); > void in_pcbrehash(struct inpcb *); > -void in_rtchange(struct inpcb *, int); > +void in_pcbrtchange(struct inpcb *, int); > void in_setpeeraddr(struct inpcb *, struct mbuf *); > void in_setsockaddr(struct inpcb *, struct mbuf *); > int in_sockaddr(struct socket *, struct mbuf *); > Index: netinet/tcp_subr.c > =================================================================== > RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/tcp_subr.c,v > diff -u -p -r1.207 tcp_subr.c > --- netinet/tcp_subr.c 30 Jan 2025 14:40:50 -0000 1.207 > +++ netinet/tcp_subr.c 7 Feb 2025 21:05:40 -0000 > @@ -576,6 +576,8 @@ tcp_notify(struct inpcb *inp, int error) > struct tcpcb *tp = intotcpcb(inp); > struct socket *so = inp->inp_socket; > > + soassertlocked(so); > + > /* > * Ignore some errors if we are hooked up. 
> * If connection hasn't completed, has retransmitted several times, > @@ -602,12 +604,10 @@ void > tcp6_ctlinput(int cmd, struct sockaddr *sa, u_int rdomain, void *d) > { > struct tcphdr th; > - struct tcpcb *tp; > void (*notify)(struct inpcb *, int) = tcp_notify; > struct ip6_hdr *ip6; > const struct sockaddr_in6 *sa6_src = NULL; > struct sockaddr_in6 *sa6 = satosin6(sa); > - struct inpcb *inp; > struct mbuf *m; > tcp_seq seq; > int off; > @@ -633,7 +633,7 @@ tcp6_ctlinput(int cmd, struct sockaddr * > /* XXX there's no PRC_QUENCH in IPv6 */ > return; > } else if (PRC_IS_REDIRECT(cmd)) > - notify = in_rtchange, d = NULL; > + notify = in_pcbrtchange, d = NULL; > else if (cmd == PRC_MSGSIZE) > ; /* special code is present, see below */ > else if (cmd == PRC_HOSTDEAD) > @@ -655,6 +655,10 @@ tcp6_ctlinput(int cmd, struct sockaddr * > } > > if (ip6) { > + struct inpcb *inp; > + struct socket *so = NULL; > + struct tcpcb *tp = NULL; > + > /* > * XXX: We assume that when ip6 is non NULL, > * M and OFF are valid. 
> @@ -687,18 +691,27 @@ tcp6_ctlinput(int cmd, struct sockaddr * > in_pcbunref(inp); > return; > } > - if (inp) { > + if (inp != NULL) > + so = in_pcbsolock_ref(inp); > + if (so != NULL) > + tp = intotcpcb(inp); > + if (tp != NULL) { > seq = ntohl(th.th_seq); > if ((tp = intotcpcb(inp)) && > SEQ_GEQ(seq, tp->snd_una) && > SEQ_LT(seq, tp->snd_max)) > notify(inp, inet6ctlerrmap[cmd]); > - } else if (inet6ctlerrmap[cmd] == EHOSTUNREACH || > + } > + in_pcbsounlock_rele(inp, so); > + in_pcbunref(inp); > + > + if (tp == NULL && > + (inet6ctlerrmap[cmd] == EHOSTUNREACH || > inet6ctlerrmap[cmd] == ENETUNREACH || > - inet6ctlerrmap[cmd] == EHOSTDOWN) > + inet6ctlerrmap[cmd] == EHOSTDOWN)) { > syn_cache_unreach(sin6tosa_const(sa6_src), sa, &th, > rdomain); > - in_pcbunref(inp); > + } > } else { > in6_pcbnotify(&tcb6table, sa6, 0, > sa6_src, 0, rdomain, cmd, NULL, notify); > @@ -711,8 +724,6 @@ tcp_ctlinput(int cmd, struct sockaddr *s > { > struct ip *ip = v; > struct tcphdr *th; > - struct tcpcb *tp; > - struct inpcb *inp; > struct in_addr faddr; > tcp_seq seq; > u_int mtu; > @@ -735,8 +746,12 @@ tcp_ctlinput(int cmd, struct sockaddr *s > */ > return; > else if (PRC_IS_REDIRECT(cmd)) > - notify = in_rtchange, ip = 0; > + notify = in_pcbrtchange, ip = NULL; > else if (cmd == PRC_MSGSIZE && ip_mtudisc && ip) { > + struct inpcb *inp; > + struct socket *so = NULL; > + struct tcpcb *tp = NULL; > + > /* > * Verify that the packet in the icmp payload refers > * to an existing TCP connection. 
> @@ -746,7 +761,11 @@ tcp_ctlinput(int cmd, struct sockaddr *s > inp = in_pcblookup(&tcbtable, > ip->ip_dst, th->th_dport, ip->ip_src, th->th_sport, > rdomain); > - if (inp && (tp = intotcpcb(inp)) && > + if (inp != NULL) > + so = in_pcbsolock_ref(inp); > + if (so != NULL) > + tp = intotcpcb(inp); > + if (tp != NULL && > SEQ_GEQ(seq, tp->snd_una) && > SEQ_LT(seq, tp->snd_max)) { > struct icmp *icp; > @@ -760,6 +779,7 @@ tcp_ctlinput(int cmd, struct sockaddr *s > */ > mtu = (u_int)ntohs(icp->icmp_nextmtu); > if (mtu >= tp->t_pmtud_mtu_sent) { > + in_pcbsounlock_rele(inp, so); > in_pcbunref(inp); > return; > } > @@ -780,6 +800,7 @@ tcp_ctlinput(int cmd, struct sockaddr *s > */ > if (tp->t_flags & TF_PMTUD_PEND) { > if (SEQ_LT(tp->t_pmtud_th_seq, seq)) { > + in_pcbsounlock_rele(inp, so); > in_pcbunref(inp); > return; > } > @@ -789,37 +810,52 @@ tcp_ctlinput(int cmd, struct sockaddr *s > tp->t_pmtud_nextmtu = icp->icmp_nextmtu; > tp->t_pmtud_ip_len = icp->icmp_ip.ip_len; > tp->t_pmtud_ip_hl = icp->icmp_ip.ip_hl; > + in_pcbsounlock_rele(inp, so); > in_pcbunref(inp); > return; > } > } else { > /* ignore if we don't have a matching connection */ > + in_pcbsounlock_rele(inp, so); > in_pcbunref(inp); > return; > } > + in_pcbsounlock_rele(inp, so); > in_pcbunref(inp); > - notify = tcp_mtudisc, ip = 0; > + notify = tcp_mtudisc, ip = NULL; > } else if (cmd == PRC_MTUINC) > - notify = tcp_mtudisc_increase, ip = 0; > + notify = tcp_mtudisc_increase, ip = NULL; > else if (cmd == PRC_HOSTDEAD) > - ip = 0; > + ip = NULL; > else if (errno == 0) > return; > > if (ip) { > + struct inpcb *inp; > + struct socket *so = NULL; > + struct tcpcb *tp = NULL; > + > th = (struct tcphdr *)((caddr_t)ip + (ip->ip_hl << 2)); > inp = in_pcblookup(&tcbtable, > ip->ip_dst, th->th_dport, ip->ip_src, th->th_sport, > rdomain); > - if (inp) { > + if (inp != NULL) > + so = in_pcbsolock_ref(inp); > + if (so != NULL) > + tp = intotcpcb(inp); > + if (tp != NULL) { > seq = ntohl(th->th_seq); > - if ((tp = 
intotcpcb(inp)) && > - SEQ_GEQ(seq, tp->snd_una) && > + if (SEQ_GEQ(seq, tp->snd_una) && > SEQ_LT(seq, tp->snd_max)) > notify(inp, errno); > - } else if (inetctlerrmap[cmd] == EHOSTUNREACH || > + } > + in_pcbsounlock_rele(inp, so); > + in_pcbunref(inp); > + > + if (tp == NULL && > + (inetctlerrmap[cmd] == EHOSTUNREACH || > inetctlerrmap[cmd] == ENETUNREACH || > - inetctlerrmap[cmd] == EHOSTDOWN) { > + inetctlerrmap[cmd] == EHOSTDOWN)) { > struct sockaddr_in sin; > > bzero(&sin, sizeof(sin)); > @@ -829,7 +865,6 @@ tcp_ctlinput(int cmd, struct sockaddr *s > sin.sin_addr = ip->ip_src; > syn_cache_unreach(sintosa(&sin), sa, th, rdomain); > } > - in_pcbunref(inp); > } else > in_pcbnotifyall(&tcbtable, satosin(sa), rdomain, errno, notify); > } > @@ -871,7 +906,7 @@ tcp_mtudisc(struct inpcb *inp, int errno > * If this was not a host route, remove and realloc. > */ > if ((rt->rt_flags & RTF_HOST) == 0) { > - in_rtchange(inp, errno); > + in_pcbrtchange(inp, errno); > if ((rt = in_pcbrtentry(inp)) == NULL) > return; > } > @@ -901,7 +936,7 @@ tcp_mtudisc_increase(struct inpcb *inp, > * If this was a host route, remove and realloc. 
> */ > if (rt->rt_flags & RTF_HOST) > - in_rtchange(inp, errno); > + in_pcbrtchange(inp, errno); > > /* also takes care of congestion window */ > tcp_mss(tp, -1); > Index: netinet/tcp_timer.c > =================================================================== > RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/tcp_timer.c,v > diff -u -p -r1.82 tcp_timer.c > --- netinet/tcp_timer.c 16 Jan 2025 11:59:20 -0000 1.82 > +++ netinet/tcp_timer.c 7 Feb 2025 21:05:40 -0000 > @@ -214,17 +214,19 @@ tcp_timer_rexmt(void *arg) > SEQ_LT(tp->t_pmtud_th_seq, (int)(tp->snd_una + tp->t_maxseg))) { > struct sockaddr_in sin; > struct icmp icmp; > + u_int rtableid; > > /* TF_PMTUD_PEND is set in tcp_ctlinput() which is IPv4 only */ > KASSERT(!ISSET(inp->inp_flags, INP_IPV6)); > tp->t_flags &= ~TF_PMTUD_PEND; > > + rtableid = inp->inp_rtableid; > + > /* XXX create fake icmp message with relevant entries */ > icmp.icmp_nextmtu = tp->t_pmtud_nextmtu; > icmp.icmp_ip.ip_len = tp->t_pmtud_ip_len; > icmp.icmp_ip.ip_hl = tp->t_pmtud_ip_hl; > icmp.icmp_ip.ip_dst = inp->inp_faddr; > - icmp_mtudisc(&icmp, inp->inp_rtableid); > > /* > * Notify all connections to the same peer about > @@ -234,9 +236,16 @@ tcp_timer_rexmt(void *arg) > sin.sin_len = sizeof(sin); > sin.sin_family = AF_INET; > sin.sin_addr = inp->inp_faddr; > - in_pcbnotifyall(&tcbtable, &sin, inp->inp_rtableid, EMSGSIZE, > + > + in_pcbsounlock_rele(inp, so); > + in_pcbunref(inp); > + > + icmp_mtudisc(&icmp, rtableid); > + in_pcbnotifyall(&tcbtable, &sin, rtableid, EMSGSIZE, > tcp_mtudisc); > - goto out; > + > + NET_UNLOCK_SHARED(); > + return; > } > > tcp_timer_freesack(tp); > @@ -303,7 +312,7 @@ tcp_timer_rexmt(void *arg) > /* Disable path MTU discovery */ > if ((rt->rt_locks & RTV_MTU) == 0) { > rt->rt_locks |= RTV_MTU; > - in_rtchange(inp, 0); > + in_pcbrtchange(inp, 0); > } > > rtfree(rt); > Index: netinet/udp_usrreq.c > =================================================================== > RCS file: 
/data/mirror/openbsd/cvs/src/sys/netinet/udp_usrreq.c,v > diff -u -p -r1.331 udp_usrreq.c > --- netinet/udp_usrreq.c 6 Feb 2025 13:40:58 -0000 1.331 > +++ netinet/udp_usrreq.c 7 Feb 2025 21:05:40 -0000 > @@ -748,7 +748,7 @@ udp6_ctlinput(int cmd, struct sockaddr * > if ((unsigned)cmd >= PRC_NCMDS) > return; > if (PRC_IS_REDIRECT(cmd)) > - notify = in_rtchange, d = NULL; > + notify = in_pcbrtchange, d = NULL; > else if (cmd == PRC_HOSTDEAD) > d = NULL; > else if (cmd == PRC_MSGSIZE) > @@ -880,7 +880,6 @@ udp_ctlinput(int cmd, struct sockaddr *s > struct ip *ip = v; > struct udphdr *uhp; > struct in_addr faddr; > - struct inpcb *inp; > void (*notify)(struct inpcb *, int) = udp_notify; > int errno; > > @@ -897,14 +896,17 @@ udp_ctlinput(int cmd, struct sockaddr *s > return; > errno = inetctlerrmap[cmd]; > if (PRC_IS_REDIRECT(cmd)) > - notify = in_rtchange, ip = 0; > + notify = in_pcbrtchange, ip = NULL; > else if (cmd == PRC_HOSTDEAD) > - ip = 0; > + ip = NULL; > else if (errno == 0) > return; > + > if (ip) { > - uhp = (struct udphdr *)((caddr_t)ip + (ip->ip_hl << 2)); > + struct inpcb *inp; > + struct socket *so = NULL; > > + uhp = (struct udphdr *)((caddr_t)ip + (ip->ip_hl << 2)); > #ifdef IPSEC > /* PMTU discovery for udpencap */ > if (cmd == PRC_MSGSIZE && ip_mtudisc && udpencap_enable && > @@ -917,7 +919,10 @@ udp_ctlinput(int cmd, struct sockaddr *s > ip->ip_dst, uhp->uh_dport, ip->ip_src, uhp->uh_sport, > rdomain); > if (inp != NULL) > + so = in_pcbsolock_ref(inp); > + if (so != NULL) > notify(inp, errno); > + in_pcbsounlock_rele(inp, so); > in_pcbunref(inp); > } else > in_pcbnotifyall(&udbtable, satosin(sa), rdomain, errno, notify); > Index: netinet6/in6_pcb.c > =================================================================== > RCS file: /data/mirror/openbsd/cvs/src/sys/netinet6/in6_pcb.c,v > diff -u -p -r1.146 in6_pcb.c > --- netinet6/in6_pcb.c 21 Dec 2024 00:10:04 -0000 1.146 > +++ netinet6/in6_pcb.c 7 Feb 2025 21:05:40 -0000 > @@ -456,8 +456,8 @@ 
in6_pcbnotify(struct inpcbtable *table, > > /* > * Redirects go to all references to the destination, > - * and use in_rtchange to invalidate the route cache. > - * Dead host indications: also use in_rtchange to invalidate > + * and use in_pcbrtchange to invalidate the route cache. > + * Dead host indications: also use in_pcbrtchange to invalidate > * the cache, and deliver the error to all the sockets. > * Otherwise, if we have knowledge of the local port and address, > * deliver only to that socket. > @@ -468,7 +468,7 @@ in6_pcbnotify(struct inpcbtable *table, > sa6_src.sin6_addr = in6addr_any; > > if (cmd != PRC_HOSTDEAD) > - notify = in_rtchange; > + notify = in_pcbrtchange; > } > errno = inet6ctlerrmap[cmd]; > if (notify == NULL) > @@ -477,6 +477,8 @@ in6_pcbnotify(struct inpcbtable *table, > rdomain = rtable_l2(rtable); > mtx_enter(&table->inpt_mtx); > while ((inp = in_pcb_iterator(table, inp, &iter)) != NULL) { > + struct socket *so; > + > KASSERT(ISSET(inp->inp_flags, INP_IPV6)); > > /* > @@ -543,7 +545,10 @@ in6_pcbnotify(struct inpcbtable *table, > } > do_notify: > mtx_leave(&table->inpt_mtx); > - (*notify)(inp, errno); > + so = in_pcbsolock_ref(inp); > + if (so != NULL) > + (*notify)(inp, errno); > + in_pcbsounlock_rele(inp, so); > mtx_enter(&table->inpt_mtx); > } > mtx_leave(&table->inpt_mtx); > Index: netinet6/raw_ip6.c > =================================================================== > RCS file: /data/mirror/openbsd/cvs/src/sys/netinet6/raw_ip6.c,v > diff -u -p -r1.187 raw_ip6.c > --- netinet6/raw_ip6.c 6 Feb 2025 13:40:58 -0000 1.187 > +++ netinet6/raw_ip6.c 7 Feb 2025 21:05:40 -0000 > @@ -308,7 +308,7 @@ rip6_ctlinput(int cmd, struct sockaddr * > struct sockaddr_in6 *sa6 = satosin6(sa); > const struct sockaddr_in6 *sa6_src = NULL; > void *cmdarg; > - void (*notify)(struct inpcb *, int) = in_rtchange; > + void (*notify)(struct inpcb *, int) = in_pcbrtchange; > int nxt; > > if (sa->sa_family != AF_INET6 || > @@ -318,7 +318,7 @@ rip6_ctlinput(int 
cmd, struct sockaddr * > if ((unsigned)cmd >= PRC_NCMDS) > return; > if (PRC_IS_REDIRECT(cmd)) > - notify = in_rtchange, d = NULL; > + notify = in_pcbrtchange, d = NULL; > else if (cmd == PRC_HOSTDEAD) > d = NULL; > else if (cmd == PRC_MSGSIZE)