Download raw body.
pf inpcb link mutex
On Mon, Jan 01, 2024 at 07:10:37PM +0100, Alexandr Nedvedicky wrote:
> in other words: can we use inpcb::inp_mtx in your diff? this way we
> can avoid introduction of global mutex.
This was also my first attempt. Below is an earlier version of the
diff. It has races that I could not fix.
pf_state_key_unlink_inpcb() dereferences inp = sk->sk_inp. There
is no guarantee that another CPU does not change sk_inp concurrently.
This code is not MP safe, as the mutex protects the lifetime of inp.
+ if (inp == NULL)
+ return;
+ mtx_enter(&inp->inp_mtx);
Another problem is in pf_inp_lookup(). The inp can be freed between
READ_ONCE() and in_pcbref().
if (!pf_state_key_isvalid(sk))
pf_mbuf_unlink_state_key(m);
else
+ inp = READ_ONCE(sk->sk_inp);
in_pcbref(inp);
A global mutex protects inp = sk->sk_inp, so I changed the diff to
use one. Note that pf_inp_mtx covers only very short sections of
code. I do not expect much lock contention.
bluhm
Index: net/pf.c
===================================================================
RCS file: /cvs/src/sys/net/pf.c,v
diff -u -p -r1.1191 pf.c
--- net/pf.c 1 Jan 2024 17:00:57 -0000 1.1191
+++ net/pf.c 1 Jan 2024 19:14:29 -0000
@@ -256,7 +256,6 @@ void pf_state_key_unlink_reverse(stru
void pf_state_key_link_inpcb(struct pf_state_key *,
struct inpcb *);
void pf_state_key_unlink_inpcb(struct pf_state_key *);
-void pf_inpcb_unlink_state_key(struct inpcb *);
void pf_pktenqueue_delayed(void *);
int32_t pf_state_expires(const struct pf_state *, uint8_t);
@@ -1128,7 +1127,7 @@ int
pf_find_state(struct pf_pdesc *pd, struct pf_state_key_cmp *key,
struct pf_state **stp)
{
- struct pf_state_key *sk, *pkt_sk, *inp_sk;
+ struct pf_state_key *sk, *pkt_sk;
struct pf_state_item *si;
struct pf_state *st = NULL;
@@ -1140,7 +1139,6 @@ pf_find_state(struct pf_pdesc *pd, struc
addlog("\n");
}
- inp_sk = NULL;
pkt_sk = NULL;
sk = NULL;
if (pd->dir == PF_OUT) {
@@ -1156,14 +1154,27 @@ pf_find_state(struct pf_pdesc *pd, struc
sk = pkt_sk->sk_reverse;
if (pkt_sk == NULL) {
+ struct inpcb *inp = pd->m->m_pkthdr.pf.inp;
+
/* here we deal with local outbound packet */
- if (pd->m->m_pkthdr.pf.inp != NULL) {
- inp_sk = pd->m->m_pkthdr.pf.inp->inp_pf_sk;
- if (pf_state_key_isvalid(inp_sk))
+ if (inp != NULL) {
+ struct pf_state_key *inp_sk;
+
+ mtx_enter(&inp->inp_mtx);
+ inp_sk = inp->inp_pf_sk;
+ if (pf_state_key_isvalid(inp_sk)) {
sk = inp_sk;
- else
- pf_inpcb_unlink_state_key(
- pd->m->m_pkthdr.pf.inp);
+ mtx_leave(&inp->inp_mtx);
+ } else if (inp_sk != NULL) {
+ KASSERT(inp_sk->sk_inp == inp);
+ inp_sk->sk_inp = NULL;
+ inp->inp_pf_sk = NULL;
+ mtx_leave(&inp->inp_mtx);
+
+ pf_state_key_unref(inp_sk);
+ in_pcbunref(inp);
+ } else
+ mtx_leave(&inp->inp_mtx);
}
}
}
@@ -1175,8 +1186,7 @@ pf_find_state(struct pf_pdesc *pd, struc
if (pd->dir == PF_OUT && pkt_sk &&
pf_compare_state_keys(pkt_sk, sk, pd->kif, pd->dir) == 0)
pf_state_key_link_reverse(sk, pkt_sk);
- else if (pd->dir == PF_OUT && pd->m->m_pkthdr.pf.inp &&
- !pd->m->m_pkthdr.pf.inp->inp_pf_sk && !sk->sk_inp)
+ else if (pd->dir == PF_OUT)
pf_state_key_link_inpcb(sk, pd->m->m_pkthdr.pf.inp);
}
@@ -7842,9 +7852,7 @@ done:
pd.m->m_pkthdr.pf.qid = qid;
if (pd.dir == PF_IN && st && st->key[PF_SK_STACK])
pf_mbuf_link_state_key(pd.m, st->key[PF_SK_STACK]);
- if (pd.dir == PF_OUT &&
- pd.m->m_pkthdr.pf.inp && !pd.m->m_pkthdr.pf.inp->inp_pf_sk &&
- st && st->key[PF_SK_STACK] && !st->key[PF_SK_STACK]->sk_inp)
+ if (pd.dir == PF_OUT && st && st->key[PF_SK_STACK])
pf_state_key_link_inpcb(st->key[PF_SK_STACK],
pd.m->m_pkthdr.pf.inp);
@@ -8015,7 +8023,7 @@ pf_ouraddr(struct mbuf *m)
sk = m->m_pkthdr.pf.statekey;
if (sk != NULL) {
- if (sk->sk_inp != NULL)
+ if (READ_ONCE(sk->sk_inp) != NULL)
return (1);
}
@@ -8042,10 +8050,7 @@ pf_inp_lookup(struct mbuf *m)
if (!pf_state_key_isvalid(sk))
pf_mbuf_unlink_state_key(m);
else
- inp = m->m_pkthdr.pf.statekey->sk_inp;
-
- if (inp && inp->inp_pf_sk)
- KASSERT(m->m_pkthdr.pf.statekey == inp->inp_pf_sk);
+ inp = READ_ONCE(sk->sk_inp);
in_pcbref(inp);
return (inp);
@@ -8066,8 +8071,7 @@ pf_inp_link(struct mbuf *m, struct inpcb
* state, which might be just being marked as deleted by another
* thread.
*/
- if (inp && !sk->sk_inp && !inp->inp_pf_sk)
- pf_state_key_link_inpcb(sk, inp);
+ pf_state_key_link_inpcb(sk, inp);
/* The statekey has finished finding the inp, it is no longer needed. */
pf_mbuf_unlink_state_key(m);
@@ -8076,7 +8080,21 @@ pf_inp_link(struct mbuf *m, struct inpcb
void
pf_inp_unlink(struct inpcb *inp)
{
- pf_inpcb_unlink_state_key(inp);
+ struct pf_state_key *sk;
+
+ mtx_enter(&inp->inp_mtx);
+ sk = inp->inp_pf_sk;
+ if (sk == NULL) {
+ mtx_leave(&inp->inp_mtx);
+ return;
+ }
+ KASSERT(sk->sk_inp == inp);
+ sk->sk_inp = NULL;
+ inp->inp_pf_sk = NULL;
+ mtx_leave(&inp->inp_mtx);
+
+ pf_state_key_unref(sk);
+ in_pcbunref(inp);
}
void
@@ -8189,24 +8207,18 @@ pf_mbuf_unlink_inpcb(struct mbuf *m)
void
pf_state_key_link_inpcb(struct pf_state_key *sk, struct inpcb *inp)
{
+ if (inp == NULL || sk->sk_inp != NULL)
+ return;
+
+ mtx_enter(&inp->inp_mtx);
+ if (inp->inp_pf_sk != NULL || sk->sk_inp != NULL) {
+ mtx_leave(&inp->inp_mtx);
+ return;
+ }
KASSERT(sk->sk_inp == NULL);
sk->sk_inp = in_pcbref(inp);
- KASSERT(inp->inp_pf_sk == NULL);
inp->inp_pf_sk = pf_state_key_ref(sk);
-}
-
-void
-pf_inpcb_unlink_state_key(struct inpcb *inp)
-{
- struct pf_state_key *sk = inp->inp_pf_sk;
-
- if (sk != NULL) {
- KASSERT(sk->sk_inp == inp);
- sk->sk_inp = NULL;
- inp->inp_pf_sk = NULL;
- pf_state_key_unref(sk);
- in_pcbunref(inp);
- }
+ mtx_leave(&inp->inp_mtx);
}
void
@@ -8214,13 +8226,16 @@ pf_state_key_unlink_inpcb(struct pf_stat
{
struct inpcb *inp = sk->sk_inp;
- if (inp != NULL) {
- KASSERT(inp->inp_pf_sk == sk);
- sk->sk_inp = NULL;
- inp->inp_pf_sk = NULL;
- pf_state_key_unref(sk);
- in_pcbunref(inp);
- }
+ if (inp == NULL)
+ return;
+ mtx_enter(&inp->inp_mtx);
+ KASSERT(inp->inp_pf_sk == sk);
+ sk->sk_inp = NULL;
+ inp->inp_pf_sk = NULL;
+ mtx_leave(&inp->inp_mtx);
+
+ pf_state_key_unref(sk);
+ in_pcbunref(inp);
}
void
Index: netinet/in_pcb.c
===================================================================
RCS file: /cvs/src/sys/netinet/in_pcb.c,v
diff -u -p -r1.282 in_pcb.c
--- netinet/in_pcb.c 7 Dec 2023 16:08:30 -0000 1.282
+++ netinet/in_pcb.c 1 Jan 2024 19:14:32 -0000
@@ -573,17 +573,13 @@ in_pcbconnect(struct inpcb *inp, struct
void
in_pcbdisconnect(struct inpcb *inp)
{
- /*
- * XXXSMP pf lock sleeps, so we cannot use table->inpt_mtx
- * to keep inp_pf_sk in sync with pcb. Use net lock for now.
- */
- NET_ASSERT_LOCKED_EXCLUSIVE();
#if NPF > 0
- if (inp->inp_pf_sk) {
- pf_remove_divert_state(inp->inp_pf_sk);
- /* pf_remove_divert_state() may have detached the state */
- pf_inp_unlink(inp);
- }
+ struct pf_state_key *sk;
+
+ sk = READ_ONCE(inp->inp_pf_sk);
+ if (sk != NULL)
+ pf_remove_divert_state(sk);
+ pf_inp_unlink(inp);
#endif
inp->inp_flowid = 0;
if (inp->inp_socket->so_state & SS_NOFDREF)
@@ -595,6 +591,9 @@ in_pcbdetach(struct inpcb *inp)
{
struct socket *so = inp->inp_socket;
struct inpcbtable *table = inp->inp_table;
+#if NPF > 0
+ struct pf_state_key *sk;
+#endif
so->so_pcb = NULL;
/*
@@ -616,17 +615,11 @@ in_pcbdetach(struct inpcb *inp)
#endif
ip_freemoptions(inp->inp_moptions);
- /*
- * XXXSMP pf lock sleeps, so we cannot use table->inpt_mtx
- * to keep inp_pf_sk in sync with pcb. Use net lock for now.
- */
- NET_ASSERT_LOCKED_EXCLUSIVE();
#if NPF > 0
- if (inp->inp_pf_sk) {
- pf_remove_divert_state(inp->inp_pf_sk);
- /* pf_remove_divert_state() may have detached the state */
- pf_inp_unlink(inp);
- }
+ sk = READ_ONCE(inp->inp_pf_sk);
+ if (sk != NULL)
+ pf_remove_divert_state(sk);
+ pf_inp_unlink(inp);
#endif
mtx_enter(&table->inpt_mtx);
LIST_REMOVE(inp, inp_lhash);
Index: netinet/in_pcb.h
===================================================================
RCS file: /cvs/src/sys/netinet/in_pcb.h,v
diff -u -p -r1.145 in_pcb.h
--- netinet/in_pcb.h 18 Dec 2023 13:11:20 -0000 1.145
+++ netinet/in_pcb.h 1 Jan 2024 19:14:32 -0000
@@ -187,7 +187,7 @@ struct inpcb {
#define inp_csumoffset inp_cksum6
#endif
struct icmp6_filter *inp_icmp6filt;
- struct pf_state_key *inp_pf_sk;
+ struct pf_state_key *inp_pf_sk; /* [p] */
struct mbuf *(*inp_upcall)(void *, struct mbuf *,
struct ip *, struct ip6_hdr *, void *, int);
void *inp_upcall_arg;
pf inpcb link mutex