From: Josh Grosse Subject: Re: Please test: wg(4): drop "while (!ifq_empty())" hack in wg_peer_destroy() To: tech@openbsd.org Date: Fri, 2 Feb 2024 19:12:08 -0500 On Tue, Jan 23, 2024 at 03:23:32PM +0300, Vitaliy Makkoveev wrote: > Updated diff. Against the previous, it introduces reference counters for > 'wg_peer' structure. The unprotected `peer' dereference within > wg_output() looks very fragile and could be the problem in future. So, > now wg_aip_lookup() returns `peer' with bumped reference counter. The > rest of `peer' acquisition and dereference left as is for a while. I ran with the prior diff, and I've been running with this diff since Jan 24 on amd64. No issues. > > Index: sys/dev/dt/dt_prov_static.c > =================================================================== > RCS file: /cvs/src/sys/dev/dt/dt_prov_static.c,v > retrieving revision 1.22 > diff -u -p -r1.22 dt_prov_static.c > --- sys/dev/dt/dt_prov_static.c 28 Aug 2023 14:50:01 -0000 1.22 > +++ sys/dev/dt/dt_prov_static.c 23 Jan 2024 11:42:24 -0000 > @@ -102,6 +102,7 @@ DT_STATIC_PROBE3(refcnt, inpcb, "void *" > DT_STATIC_PROBE3(refcnt, rtentry, "void *", "int", "int"); > DT_STATIC_PROBE3(refcnt, syncache, "void *", "int", "int"); > DT_STATIC_PROBE3(refcnt, tdb, "void *", "int", "int"); > +DT_STATIC_PROBE3(refcnt, wg_peer, "void *", "int", "int"); > > /* > * List of all static probes > @@ -155,6 +156,7 @@ struct dt_probe *const dtps_static[] = { > &_DT_STATIC_P(refcnt, rtentry), > &_DT_STATIC_P(refcnt, syncache), > &_DT_STATIC_P(refcnt, tdb), > + &_DT_STATIC_P(refcnt, wg_peer), > }; > > struct dt_probe *const *dtps_index_refcnt; > Index: sys/net/if_wg.c > =================================================================== > RCS file: /cvs/src/sys/net/if_wg.c,v > retrieving revision 1.36 > diff -u -p -r1.36 if_wg.c > --- sys/net/if_wg.c 18 Jan 2024 08:46:41 -0000 1.36 > +++ sys/net/if_wg.c 23 Jan 2024 11:42:27 -0000 > @@ -30,6 +30,7 @@ > #include > #include > #include > +#include > > #include > #include > @@ -190,6 +191,9 @@ struct wg_ring { > struct wg_peer { > LIST_ENTRY(wg_peer) p_pubkey_entry; > TAILQ_ENTRY(wg_peer) p_seq_entry; > + > + struct refcnt p_refcnt; > + > uint64_t p_id; > struct wg_softc *p_sc; > > @@ -270,6 +274,7 @@ struct wg_peer * > wg_peer_create(struct wg_softc *, uint8_t[WG_KEY_SIZE]); > struct wg_peer * > wg_peer_lookup(struct wg_softc *, const uint8_t[WG_KEY_SIZE]); > +void wg_peer_rele(struct wg_peer *); > void wg_peer_destroy(struct wg_peer *); > void wg_peer_set_endpoint_from_tag(struct wg_peer *, struct wg_tag *); > void wg_peer_set_sockaddr(struct wg_peer *, struct sockaddr *); > @@ -398,6 +403,8 @@ wg_peer_create(struct wg_softc *sc, uint > if ((peer = pool_get(&wg_peer_pool, PR_NOWAIT)) == NULL) > return NULL; > > + refcnt_init_trace(&peer->p_refcnt, DT_REFCNT_IDX_WGPEER); > + > peer->p_id = peer_counter++; > peer->p_sc = sc; > > @@ -474,6 +481,22 @@ done: > } > > void > +wg_peer_rele(struct wg_peer *peer) > +{ > + struct wg_softc *sc = peer->p_sc; > + > + if (refcnt_rele(&peer->p_refcnt) == 0) > + return; > + > + if (!mq_empty(&peer->p_stage_queue)) > + mq_purge(&peer->p_stage_queue); > + > + DPRINTF(sc, "Peer %llu destroyed\n", peer->p_id); > + explicit_bzero(peer, sizeof(*peer)); > + pool_put(&wg_peer_pool, peer); > +} > + > +void > wg_peer_destroy(struct wg_peer *peer) > { > struct wg_softc *sc = peer->p_sc; > @@ -507,31 +530,10 @@ wg_peer_destroy(struct wg_peer *peer) > > noise_remote_clear(&peer->p_remote); > > - NET_LOCK(); > - while (!ifq_empty(&sc->sc_if.if_snd)) { > - /* > - * XXX: `if_snd' of stopped interface could still > - * contain packets > - */ > - if (!ISSET(sc->sc_if.if_flags, IFF_RUNNING)) { > - ifq_purge(&sc->sc_if.if_snd); > - continue; > - } > - NET_UNLOCK(); > - tsleep_nsec(&nowake, PWAIT, "wg_ifq", 1000); > - NET_LOCK(); > - } > - NET_UNLOCK(); > - > taskq_barrier(wg_crypt_taskq); > taskq_barrier(net_tq(sc->sc_if.if_index)); > > - if (!mq_empty(&peer->p_stage_queue)) > - mq_purge(&peer->p_stage_queue); > - > - DPRINTF(sc, "Peer %llu destroyed\n", peer->p_id); > - explicit_bzero(peer, sizeof(*peer)); > - pool_put(&wg_peer_pool, peer); > + wg_peer_rele(peer); > } > > void > @@ -619,6 +621,7 @@ wg_aip_add(struct wg_softc *sc, struct w > node = art_insert(root, &aip->a_node, &d->a_addr, d->a_cidr); > > if (node == &aip->a_node) { > + refcnt_take(&peer->p_refcnt); > aip->a_peer = peer; > aip->a_data = *d; > LIST_INSERT_HEAD(&peer->p_aip, aip, a_entry); > @@ -627,9 +630,13 @@ wg_aip_add(struct wg_softc *sc, struct w > pool_put(&wg_aip_pool, aip); > aip = (struct wg_aip *) node; > if (aip->a_peer != peer) { > + struct wg_peer *peer_old = aip->a_peer; > + > + refcnt_take(&peer->p_refcnt); > LIST_REMOVE(aip, a_entry); > LIST_INSERT_HEAD(&peer->p_aip, aip, a_entry); > aip->a_peer = peer; > + wg_peer_rele(peer_old); > } > } > rw_exit_write(&root->ar_lock); > @@ -641,11 +648,18 @@ wg_aip_lookup(struct art_root *root, voi > { > struct srp_ref sr; > struct art_node *node; > + struct wg_peer *peer = NULL; > > node = art_match(root, addr, &sr); > + > + if (node) { > + peer = ((struct wg_aip *) node)->a_peer; > + refcnt_take(&peer->p_refcnt); > + } > + > srp_leave(&sr); > > - return node == NULL ? NULL : ((struct wg_aip *) node)->a_peer; > + return peer; > } > > int > @@ -678,6 +692,7 @@ wg_aip_remove(struct wg_softc *sc, struc > sc->sc_aip_num--; > LIST_REMOVE(aip, a_entry); > pool_put(&wg_aip_pool, aip); > + wg_peer_rele(peer); > } > > srp_leave(&sr); > @@ -1672,6 +1687,9 @@ wg_decap(struct wg_softc *sc, struct mbu > goto error; > } > > + if (allowed_peer) > + wg_peer_rele(allowed_peer); > + > if (__predict_false(peer != allowed_peer)) { > DPRINTF(sc, "Packet has unallowed src IP from peer " > "%llu\n", peer->p_id); > @@ -2092,7 +2110,6 @@ wg_qstart(struct ifqueue *ifq) > struct ifnet *ifp = ifq->ifq_if; > struct wg_softc *sc = ifp->if_softc; > struct wg_peer *peer; > - struct wg_tag *t; > struct mbuf *m; > SLIST_HEAD(,wg_peer) start_list; > > @@ -2104,14 +2121,34 @@ wg_qstart(struct ifqueue *ifq) > * time. > */ > while ((m = ifq_dequeue(ifq)) != NULL) { > - t = wg_tag_get(m); > - peer = t->t_peer; > + switch (m->m_pkthdr.ph_family) { > + case AF_INET: > + peer = wg_aip_lookup(sc->sc_aip4, > + &mtod(m, struct ip *)->ip_dst); > + break; > +#ifdef INET6 > + case AF_INET6: > + peer = wg_aip_lookup(sc->sc_aip6, > + &mtod(m, struct ip6_hdr *)->ip6_dst); > + break; > +#endif > + default: > + m_freem(m); > + continue; > + } > + > + if (peer == NULL) { > + m_freem(m); > + continue; > + } > + > if (mq_push(&peer->p_stage_queue, m) != 0) > counters_inc(ifp->if_counters, ifc_oqdrops); > if (!peer->p_start_onlist) { > SLIST_INSERT_HEAD(&start_list, peer, p_start_list); > peer->p_start_onlist = 1; > - } > + } else > + wg_peer_rele(peer); > } > SLIST_FOREACH(peer, &start_list, p_start_list) { > if (noise_remote_ready(&peer->p_remote) == 0) > @@ -2119,7 +2156,9 @@ wg_qstart(struct ifqueue *ifq) > else > wg_timers_event_want_initiation(&peer->p_timers); > peer->p_start_onlist = 0; > + wg_peer_rele(peer); > } > + > task_add(wg_crypt_taskq, &sc->sc_encap); > } > > @@ -2169,26 +2208,13 @@ wg_output(struct ifnet *ifp, struct mbuf > DPRINTF(sc, "No valid endpoint has been configured or " > "discovered for peer %llu\n", peer->p_id); > ret = EDESTADDRREQ; > - goto error; > + goto rele; > } > > if (m->m_pkthdr.ph_loopcnt++ > M_MAXLOOP) { > DPRINTF(sc, "Packet looped\n"); > ret = ELOOP; > - goto error; > - } > - > - /* > - * As we hold a reference to peer in the mbuf, we can't handle a > - * delayed packet without doing some refcnting. If a peer is removed > - * while a delayed holds a reference, bad things will happen. For the > - * time being, delayed packets are unsupported. This may be fixed with > - * another aip_lookup in wg_qstart, or refcnting as mentioned before. > - */ > - if (m->m_pkthdr.pf.delay > 0) { > - DPRINTF(sc, "PF delay unsupported\n"); > - ret = EOPNOTSUPP; > - goto error; > + goto rele; > } > > t->t_peer = peer; > @@ -2196,12 +2222,16 @@ wg_output(struct ifnet *ifp, struct mbuf > t->t_done = 0; > t->t_mtu = ifp->if_mtu; > > + wg_peer_rele(peer); > + > /* > * We still have an issue with ifq that will count a packet that gets > * dropped in wg_qstart, or not encrypted. These get counted as > * ofails or oqdrops, so the packet gets counted twice. > */ > return if_enqueue(ifp, m); > +rele: > + wg_peer_rele(peer); > error: > counters_inc(ifp->if_counters, ifc_oerrors); > m_freem(m); > Index: sys/sys/refcnt.h > =================================================================== > RCS file: /cvs/src/sys/sys/refcnt.h,v > retrieving revision 1.12 > diff -u -p -r1.12 refcnt.h > --- sys/sys/refcnt.h 28 Aug 2023 14:50:02 -0000 1.12 > +++ sys/sys/refcnt.h 23 Jan 2024 11:42:27 -0000 > @@ -51,6 +51,7 @@ unsigned int refcnt_read(struct refcnt * > #define DT_REFCNT_IDX_RTENTRY 5 > #define DT_REFCNT_IDX_SYNCACHE 6 > #define DT_REFCNT_IDX_TDB 7 > +#define DT_REFCNT_IDX_WGPEER 8 > > #endif /* _KERNEL */ > >