From: Alexander Bluhm Subject: Re: pf inpcb link mutex To: Alexandr Nedvedicky Cc: tech@openbsd.org Date: Mon, 1 Jan 2024 20:26:00 +0100 On Mon, Jan 01, 2024 at 07:10:37PM +0100, Alexandr Nedvedicky wrote: > in other words: can we use inpcb::inp_mtx in your diff? this way we > can avoid introduction of global mutex. This was also my first attempt. Below is an earlier version of the diff. It has races that I could not fix. pf_state_key_unlink_inpcb() loads inp = sk->sk_inp and then dereferences inp to take its mutex. Nothing guarantees that another CPU does not change sk_inp between that load and mtx_enter(). This code is not MP safe: the mutex is embedded in inp, so it cannot protect the lifetime of inp itself. + if (inp == NULL) + return; + mtx_enter(&inp->inp_mtx); Another problem is in pf_inp_lookup(). The inp can be freed between READ_ONCE() and in_pcbref(). if (!pf_state_key_isvalid(sk)) pf_mbuf_unlink_state_key(m); else + inp = READ_ONCE(sk->sk_inp); in_pcbref(inp); A global mutex can protect the load inp = sk->sk_inp, so I changed the diff to use one. Note that pf_inp_mtx covers only very short sections of code, so I do not expect much lock contention. 
bluhm Index: net/pf.c =================================================================== RCS file: /cvs/src/sys/net/pf.c,v diff -u -p -r1.1191 pf.c --- net/pf.c 1 Jan 2024 17:00:57 -0000 1.1191 +++ net/pf.c 1 Jan 2024 19:14:29 -0000 @@ -256,7 +256,6 @@ void pf_state_key_unlink_reverse(stru void pf_state_key_link_inpcb(struct pf_state_key *, struct inpcb *); void pf_state_key_unlink_inpcb(struct pf_state_key *); -void pf_inpcb_unlink_state_key(struct inpcb *); void pf_pktenqueue_delayed(void *); int32_t pf_state_expires(const struct pf_state *, uint8_t); @@ -1128,7 +1127,7 @@ int pf_find_state(struct pf_pdesc *pd, struct pf_state_key_cmp *key, struct pf_state **stp) { - struct pf_state_key *sk, *pkt_sk, *inp_sk; + struct pf_state_key *sk, *pkt_sk; struct pf_state_item *si; struct pf_state *st = NULL; @@ -1140,7 +1139,6 @@ pf_find_state(struct pf_pdesc *pd, struc addlog("\n"); } - inp_sk = NULL; pkt_sk = NULL; sk = NULL; if (pd->dir == PF_OUT) { @@ -1156,14 +1154,27 @@ pf_find_state(struct pf_pdesc *pd, struc sk = pkt_sk->sk_reverse; if (pkt_sk == NULL) { + struct inpcb *inp = pd->m->m_pkthdr.pf.inp; + /* here we deal with local outbound packet */ - if (pd->m->m_pkthdr.pf.inp != NULL) { - inp_sk = pd->m->m_pkthdr.pf.inp->inp_pf_sk; - if (pf_state_key_isvalid(inp_sk)) + if (inp != NULL) { + struct pf_state_key *inp_sk; + + mtx_enter(&inp->inp_mtx); + inp_sk = inp->inp_pf_sk; + if (pf_state_key_isvalid(inp_sk)) { sk = inp_sk; - else - pf_inpcb_unlink_state_key( - pd->m->m_pkthdr.pf.inp); + mtx_leave(&inp->inp_mtx); + } else if (inp_sk != NULL) { + KASSERT(inp_sk->sk_inp == inp); + inp_sk->sk_inp = NULL; + inp->inp_pf_sk = NULL; + mtx_leave(&inp->inp_mtx); + + pf_state_key_unref(inp_sk); + in_pcbunref(inp); + } else + mtx_leave(&inp->inp_mtx); } } } @@ -1175,8 +1186,7 @@ pf_find_state(struct pf_pdesc *pd, struc if (pd->dir == PF_OUT && pkt_sk && pf_compare_state_keys(pkt_sk, sk, pd->kif, pd->dir) == 0) pf_state_key_link_reverse(sk, pkt_sk); - else if (pd->dir == PF_OUT 
&& pd->m->m_pkthdr.pf.inp && - !pd->m->m_pkthdr.pf.inp->inp_pf_sk && !sk->sk_inp) + else if (pd->dir == PF_OUT) pf_state_key_link_inpcb(sk, pd->m->m_pkthdr.pf.inp); } @@ -7842,9 +7852,7 @@ done: pd.m->m_pkthdr.pf.qid = qid; if (pd.dir == PF_IN && st && st->key[PF_SK_STACK]) pf_mbuf_link_state_key(pd.m, st->key[PF_SK_STACK]); - if (pd.dir == PF_OUT && - pd.m->m_pkthdr.pf.inp && !pd.m->m_pkthdr.pf.inp->inp_pf_sk && - st && st->key[PF_SK_STACK] && !st->key[PF_SK_STACK]->sk_inp) + if (pd.dir == PF_OUT && st && st->key[PF_SK_STACK]) pf_state_key_link_inpcb(st->key[PF_SK_STACK], pd.m->m_pkthdr.pf.inp); @@ -8015,7 +8023,7 @@ pf_ouraddr(struct mbuf *m) sk = m->m_pkthdr.pf.statekey; if (sk != NULL) { - if (sk->sk_inp != NULL) + if (READ_ONCE(sk->sk_inp) != NULL) return (1); } @@ -8042,10 +8050,7 @@ pf_inp_lookup(struct mbuf *m) if (!pf_state_key_isvalid(sk)) pf_mbuf_unlink_state_key(m); else - inp = m->m_pkthdr.pf.statekey->sk_inp; - - if (inp && inp->inp_pf_sk) - KASSERT(m->m_pkthdr.pf.statekey == inp->inp_pf_sk); + inp = READ_ONCE(sk->sk_inp); in_pcbref(inp); return (inp); @@ -8066,8 +8071,7 @@ pf_inp_link(struct mbuf *m, struct inpcb * state, which might be just being marked as deleted by another * thread. */ - if (inp && !sk->sk_inp && !inp->inp_pf_sk) - pf_state_key_link_inpcb(sk, inp); + pf_state_key_link_inpcb(sk, inp); /* The statekey has finished finding the inp, it is no longer needed. 
*/ pf_mbuf_unlink_state_key(m); @@ -8076,7 +8080,21 @@ pf_inp_link(struct mbuf *m, struct inpcb void pf_inp_unlink(struct inpcb *inp) { - pf_inpcb_unlink_state_key(inp); + struct pf_state_key *sk; + + mtx_enter(&inp->inp_mtx); + sk = inp->inp_pf_sk; + if (sk == NULL) { + mtx_leave(&inp->inp_mtx); + return; + } + KASSERT(sk->sk_inp == inp); + sk->sk_inp = NULL; + inp->inp_pf_sk = NULL; + mtx_leave(&inp->inp_mtx); + + pf_state_key_unref(sk); + in_pcbunref(inp); } void @@ -8189,24 +8207,18 @@ pf_mbuf_unlink_inpcb(struct mbuf *m) void pf_state_key_link_inpcb(struct pf_state_key *sk, struct inpcb *inp) { + if (inp == NULL || sk->sk_inp != NULL) + return; + + mtx_enter(&inp->inp_mtx); + if (inp->inp_pf_sk != NULL || sk->sk_inp != NULL) { + mtx_leave(&inp->inp_mtx); + return; + } KASSERT(sk->sk_inp == NULL); sk->sk_inp = in_pcbref(inp); - KASSERT(inp->inp_pf_sk == NULL); inp->inp_pf_sk = pf_state_key_ref(sk); -} - -void -pf_inpcb_unlink_state_key(struct inpcb *inp) -{ - struct pf_state_key *sk = inp->inp_pf_sk; - - if (sk != NULL) { - KASSERT(sk->sk_inp == inp); - sk->sk_inp = NULL; - inp->inp_pf_sk = NULL; - pf_state_key_unref(sk); - in_pcbunref(inp); - } + mtx_leave(&inp->inp_mtx); } void @@ -8214,13 +8226,16 @@ pf_state_key_unlink_inpcb(struct pf_stat { struct inpcb *inp = sk->sk_inp; - if (inp != NULL) { - KASSERT(inp->inp_pf_sk == sk); - sk->sk_inp = NULL; - inp->inp_pf_sk = NULL; - pf_state_key_unref(sk); - in_pcbunref(inp); - } + if (inp == NULL) + return; + mtx_enter(&inp->inp_mtx); + KASSERT(inp->inp_pf_sk == sk); + sk->sk_inp = NULL; + inp->inp_pf_sk = NULL; + mtx_leave(&inp->inp_mtx); + + pf_state_key_unref(sk); + in_pcbunref(inp); } void Index: netinet/in_pcb.c =================================================================== RCS file: /cvs/src/sys/netinet/in_pcb.c,v diff -u -p -r1.282 in_pcb.c --- netinet/in_pcb.c 7 Dec 2023 16:08:30 -0000 1.282 +++ netinet/in_pcb.c 1 Jan 2024 19:14:32 -0000 @@ -573,17 +573,13 @@ in_pcbconnect(struct inpcb *inp, struct void 
in_pcbdisconnect(struct inpcb *inp) { - /* - * XXXSMP pf lock sleeps, so we cannot use table->inpt_mtx - * to keep inp_pf_sk in sync with pcb. Use net lock for now. - */ - NET_ASSERT_LOCKED_EXCLUSIVE(); #if NPF > 0 - if (inp->inp_pf_sk) { - pf_remove_divert_state(inp->inp_pf_sk); - /* pf_remove_divert_state() may have detached the state */ - pf_inp_unlink(inp); - } + struct pf_state_key *sk; + + sk = READ_ONCE(inp->inp_pf_sk); + if (sk != NULL) + pf_remove_divert_state(sk); + pf_inp_unlink(inp); #endif inp->inp_flowid = 0; if (inp->inp_socket->so_state & SS_NOFDREF) @@ -595,6 +591,9 @@ in_pcbdetach(struct inpcb *inp) { struct socket *so = inp->inp_socket; struct inpcbtable *table = inp->inp_table; +#if NPF > 0 + struct pf_state_key *sk; +#endif so->so_pcb = NULL; /* @@ -616,17 +615,11 @@ in_pcbdetach(struct inpcb *inp) #endif ip_freemoptions(inp->inp_moptions); - /* - * XXXSMP pf lock sleeps, so we cannot use table->inpt_mtx - * to keep inp_pf_sk in sync with pcb. Use net lock for now. - */ - NET_ASSERT_LOCKED_EXCLUSIVE(); #if NPF > 0 - if (inp->inp_pf_sk) { - pf_remove_divert_state(inp->inp_pf_sk); - /* pf_remove_divert_state() may have detached the state */ - pf_inp_unlink(inp); - } + sk = READ_ONCE(inp->inp_pf_sk); + if (sk != NULL) + pf_remove_divert_state(sk); + pf_inp_unlink(inp); #endif mtx_enter(&table->inpt_mtx); LIST_REMOVE(inp, inp_lhash); Index: netinet/in_pcb.h =================================================================== RCS file: /cvs/src/sys/netinet/in_pcb.h,v diff -u -p -r1.145 in_pcb.h --- netinet/in_pcb.h 18 Dec 2023 13:11:20 -0000 1.145 +++ netinet/in_pcb.h 1 Jan 2024 19:14:32 -0000 @@ -187,7 +187,7 @@ struct inpcb { #define inp_csumoffset inp_cksum6 #endif struct icmp6_filter *inp_icmp6filt; - struct pf_state_key *inp_pf_sk; + struct pf_state_key *inp_pf_sk; /* [p] */ struct mbuf *(*inp_upcall)(void *, struct mbuf *, struct ip *, struct ip6_hdr *, void *, int); void *inp_upcall_arg;