Index | Thread | Search

From:
David Gwynne <david@gwynne.id.au>
Subject:
Re: link mbufs/inpcbs to pf_states, not pf_state_keys
To:
Alexander Bluhm <alexander.bluhm@gmx.net>
Cc:
Alexandr Nedvedicky <sashan@fastmail.net>, tech@openbsd.org
Date:
Sat, 10 May 2025 10:28:56 +1000

Download raw body.

Thread
On Fri, Sep 29, 2023 at 08:29:41PM +0200, Alexander Bluhm wrote:
> On Tue, Aug 22, 2023 at 06:30:31AM +0200, Alexandr Nedvedicky wrote:
> >     Currently we have something like this:
> >
> > 	{ mbuf, pcb } <-> state key <-> { state, state ... }
> >
> >     with this diff we get to:
> >
> > 	{ mbuf, pcb } <-> state <-> state key
> >
> >     Basically, when we process a packet we are interested in the state,
> >     not the state key itself.

apologies for taking so long to get back to this. it's been one of those
years.

> The PCB holds the 4-tuple IP/Port/SRC/DST, and the state key uses
> that also.  So we have a 1:1 relation.  That's why the linking is
> that way.

while a pf_state is alive, its reference to the pf_state_key is held.
pf_states always have pf_state_keys now, therefore the tuple is
accessible either way. however, pf_states are what holds actual
connection info, such as the state of the connection (ESTABLISHED
etc) and the tcp windows etc.

currently we have to go from the state_keys to the state in pf anyway.
now that we have better lifetime management for pf_states, linking them
directly is more precise and avoids pointer chasing in the hot path.

> The mbuf linking is just to transport the information up and down
> the stack to get the link with the first packet.  There are corner
> cases that need this.  Especially port reuse and connection abort
> is tricky.
> 
> I always forget why state key to states is a 1:n relation.  But I
> expect trouble with connectionless divert rules when we change the
> PCB with state linking to 1:n.  Idea is to keep one PCB and one
> state key in sync.

i was in the room when they were (re)designing this, but it's a
long time ago so my memory may be fuzzy. what i do remember is that
they were trying to save memory by deduplicating the addressing
info, and better supporting interface bound states by making the
matching more deterministic. this also allowed doing a pointer
comparison between the state keys to detect the NAT case, which
sped things up.

a (connected) pcb manages a connection in the stack, the equivalent of
which is a state in pf. the fact that states have pf_state_keys is an
implementation detail for states. pf_state_keys represent addresses, not
a connection.

> For the same reason we have sk_reverse.  It links both state keys
> 1:1.  Like sk_inp links PCB and state keys 1:1.  I expect subtle
> breakage in our firewall product if we change that.

i've been running an older version of this for nearly two years without
issue. i'm using different features to you though, so there is some
risk.

i updated the diff last week and fixed it up for current, and simplified
it so it follows the existing semantics better. it's passing the
pf_divert regress with flying colours.

apart from pf_find_state, the diff is a pretty straightforward move of
the pointers.

tl;dr: less pointer chasing, and more precise links between things
representing connections in the kernel.

it also lets me move forward with source and state pools.

Index: kern/uipc_mbuf.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_mbuf.c,v
diff -u -p -r1.296 uipc_mbuf.c
--- kern/uipc_mbuf.c	1 Jan 2025 13:44:22 -0000	1.296
+++ kern/uipc_mbuf.c	9 May 2025 02:48:26 -0000
@@ -301,7 +301,7 @@ m_clearhdr(struct mbuf *m)
 	/* delete all mbuf tags to reset the state */
 	m_tag_delete_chain(m);
 #if NPF > 0
-	pf_mbuf_unlink_state_key(m);
+	pf_mbuf_unlink_state(m);
 	pf_mbuf_unlink_inpcb(m);
 #endif	/* NPF > 0 */
 
@@ -429,7 +429,7 @@ m_free(struct mbuf *m)
 	if (m->m_flags & M_PKTHDR) {
 		m_tag_delete_chain(m);
 #if NPF > 0
-		pf_mbuf_unlink_state_key(m);
+		pf_mbuf_unlink_state(m);
 		pf_mbuf_unlink_inpcb(m);
 #endif	/* NPF > 0 */
 	}
@@ -1387,8 +1387,8 @@ m_dup_pkthdr(struct mbuf *to, struct mbu
 	to->m_pkthdr = from->m_pkthdr;
 
 #if NPF > 0
-	to->m_pkthdr.pf.statekey = NULL;
-	pf_mbuf_link_state_key(to, from->m_pkthdr.pf.statekey);
+	to->m_pkthdr.pf.st = NULL;
+	pf_mbuf_link_state(to, from->m_pkthdr.pf.st);
 	to->m_pkthdr.pf.inp = NULL;
 	pf_mbuf_link_inpcb(to, from->m_pkthdr.pf.inp);
 #endif	/* NPF > 0 */
@@ -1517,8 +1517,8 @@ m_print(void *v,
 		    m->m_pkthdr.csum_flags, MCS_BITS);
 		(*pr)("m_pkthdr.ether_vtag: %u\tm_ptkhdr.ph_rtableid: %u\n",
 		    m->m_pkthdr.ether_vtag, m->m_pkthdr.ph_rtableid);
-		(*pr)("m_pkthdr.pf.statekey: %p\tm_pkthdr.pf.inp %p\n",
-		    m->m_pkthdr.pf.statekey, m->m_pkthdr.pf.inp);
+		(*pr)("m_pkthdr.pf.st: %p\tm_pkthdr.pf.inp %p\n",
+		    m->m_pkthdr.pf.st, m->m_pkthdr.pf.inp);
 		(*pr)("m_pkthdr.pf.qid: %u\tm_pkthdr.pf.tag: %u\n",
 		    m->m_pkthdr.pf.qid, m->m_pkthdr.pf.tag);
 		(*pr)("m_pkthdr.pf.flags: %b\n",
Index: net/if_mpw.c
===================================================================
RCS file: /cvs/src/sys/net/if_mpw.c,v
diff -u -p -r1.67 if_mpw.c
--- net/if_mpw.c	2 Mar 2025 21:28:32 -0000	1.67
+++ net/if_mpw.c	9 May 2025 02:48:26 -0000
@@ -612,7 +612,7 @@ mpw_input(struct mpw_softc *sc, struct m
 	m->m_pkthdr.ph_rtableid = ifp->if_rdomain;
 
 	/* packet has not been processed by PF yet. */
-	KASSERT(m->m_pkthdr.pf.statekey == NULL);
+	KASSERT(m->m_pkthdr.pf.st == NULL);
 
 	if_vinput(ifp, m, NULL);
 	return;
Index: net/if_tpmr.c
===================================================================
RCS file: /cvs/src/sys/net/if_tpmr.c,v
diff -u -p -r1.36 if_tpmr.c
--- net/if_tpmr.c	2 Mar 2025 21:28:32 -0000	1.36
+++ net/if_tpmr.c	9 May 2025 02:48:26 -0000
@@ -304,7 +304,7 @@ tpmr_pf(struct ifnet *ifp0, int dir, str
 		return (NULL);
 
 	if (dir == PF_IN && ISSET(m->m_pkthdr.pf.flags, PF_TAG_DIVERTED)) {
-		pf_mbuf_unlink_state_key(m);
+		pf_mbuf_unlink_state(m);
 		pf_mbuf_unlink_inpcb(m);
 		(*fam->ip_input)(ifp0, m, ns);
 		return (NULL);
Index: net/if_veb.c
===================================================================
RCS file: /cvs/src/sys/net/if_veb.c,v
diff -u -p -r1.37 if_veb.c
--- net/if_veb.c	2 Mar 2025 21:28:32 -0000	1.37
+++ net/if_veb.c	9 May 2025 02:48:26 -0000
@@ -654,7 +654,7 @@ veb_pf(struct ifnet *ifp0, int dir, stru
 		return (NULL);
 
 	if (dir == PF_IN && ISSET(m->m_pkthdr.pf.flags, PF_TAG_DIVERTED)) {
-		pf_mbuf_unlink_state_key(m);
+		pf_mbuf_unlink_state(m);
 		pf_mbuf_unlink_inpcb(m);
 		(*fam->ip_input)(ifp0, m, ns);
 		return (NULL);
Index: net/pf.c
===================================================================
RCS file: /cvs/src/sys/net/pf.c,v
diff -u -p -r1.1210 pf.c
--- net/pf.c	1 May 2025 01:10:08 -0000	1.1210
+++ net/pf.c	9 May 2025 02:48:26 -0000
@@ -247,15 +247,16 @@ int			 pf_state_insert(struct pfi_kif *,
 			    struct pf_state_key **, struct pf_state_key **,
 			    struct pf_state *);
 
+int			 pf_state_isvalid(struct pf_state *);
 int			 pf_state_key_isvalid(struct pf_state_key *);
 struct pf_state_key	*pf_state_key_ref(struct pf_state_key *);
 void			 pf_state_key_unref(struct pf_state_key *);
-void			 pf_state_key_link_reverse(struct pf_state_key *,
-			    struct pf_state_key *);
-void			 pf_state_key_unlink_reverse(struct pf_state_key *);
-void			 pf_state_key_link_inpcb(struct pf_state_key *,
+void			 pf_state_link_reverse(struct pf_state *,
+			    struct pf_state *);
+void			 pf_state_unlink_reverse(struct pf_state *);
+void			 pf_state_link_inpcb(struct pf_state *,
 			    struct inpcb *);
-void			 pf_state_key_unlink_inpcb(struct pf_state_key *);
+void			 pf_state_unlink_inpcb(struct pf_state *);
 void			 pf_pktenqueue_delayed(void *);
 int32_t			 pf_state_expires(const struct pf_state *, uint8_t);
 
@@ -860,8 +861,6 @@ pf_state_key_detach(struct pf_state *st,
 	if (TAILQ_EMPTY(&sk->sk_states)) {
 		RBT_REMOVE(pf_state_tree, &pf_statetbl, sk);
 		sk->sk_removed = 1;
-		pf_state_key_unlink_reverse(sk);
-		pf_state_key_unlink_inpcb(sk);
 		pf_state_key_unref(sk);
 	}
 
@@ -1123,13 +1122,52 @@ pf_compare_state_ke
 	}
 }
 
+/*
+ * find the state for a packet travelling in pd->dir on pd->kif.
+ * af-to states only match inbound packets, but on either key.
+ */
+static inline struct pf_state *
+pf_find_state_lookup(struct pf_pdesc *pd, const struct pf_state_key_cmp *key)
+{
+	struct pf_state_key	*sk;
+	struct pf_state_item	*si;
+	struct pf_state		*st;
+	uint8_t			 dir = pd->dir;
+
+	sk = RBT_FIND(pf_state_tree, &pf_statetbl, (struct pf_state_key *)key);
+	if (sk == NULL)
+		return (NULL);
+
+	/* list is sorted, if-bound states before floating ones */
+	TAILQ_FOREACH(si, &sk->sk_states, si_entry) {
+		st = si->si_st;
+		if (st->timeout == PFTM_PURGE)
+			continue;
+		if (st->kif != pfi_all && st->kif != pd->kif)
+			continue;
+
+		if (st->key[PF_SK_WIRE]->af != st->key[PF_SK_STACK]->af) {
+			/* af-to: only inbound, on either key */
+			if (dir == PF_IN && (sk == st->key[PF_SK_WIRE] ||
+			    sk == st->key[PF_SK_STACK]))
+				return (st);
+			continue;
+		}
+
+		if (st->key[dir == PF_IN ? PF_SK_WIRE : PF_SK_STACK] == sk)
+			return (st);
+	}
+
+	return (NULL);
+}
+
 int
 pf_find_state(struct pf_pdesc *pd, struct pf_state_key_cmp *key,
     struct pf_state **stp)
 {
-	struct pf_state_key	*sk, *pkt_sk;
-	struct pf_state_item	*si;
 	struct pf_state		*st = NULL;
+	struct pf_state		*strev = NULL;
+	struct inpcb		*inp = NULL;
 
 	pf_status.fcounters[FCNT_STATE_SEARCH]++;
 	if (pf_status.debug >= LOG_DEBUG) {
@@ -1139,82 +1165,68 @@ pf_find_state(struct pf_pdesc *pd, struc
 		addlog("\n");
 	}
 
-	pkt_sk = NULL;
-	sk = NULL;
 	if (pd->dir == PF_OUT) {
-		/* first if block deals with outbound forwarded packet */
-		pkt_sk = pd->m->m_pkthdr.pf.statekey;
+		strev = pd->m->m_pkthdr.pf.st;
+		inp = pd->m->m_pkthdr.pf.inp;
 
-		if (!pf_state_key_isvalid(pkt_sk)) {
-			pf_mbuf_unlink_state_key(pd->m);
-			pkt_sk = NULL;
-		}
+		/* first if block deals with outbound forwarded packet */
+		if (strev != NULL) {
+			KASSERT(inp == NULL);
 
-		if (pkt_sk && pf_state_key_isvalid(pkt_sk->sk_reverse))
-			sk = pkt_sk->sk_reverse;
+			if (pf_state_isvalid(strev)) {
+				st = strev->reverse;
+				if (st != NULL && pf_state_isvalid(st))
+					goto match;
+			}
 
-		if (pkt_sk == NULL) {
-			struct inpcb *inp = pd->m->m_pkthdr.pf.inp;
+			/* this handles st not being valid too */
+			pf_state_unlink_reverse(strev);
 
+			/* don't link a new state back to a dying one */
+			if (!pf_state_isvalid(strev)) {
+				pf_mbuf_unlink_state(pd->m);
+				strev = NULL;
+			}
+		} else if (inp != NULL && READ_ONCE(inp->inp_pf_st) != NULL) {
 			/* here we deal with local outbound packet */
-			if (inp != NULL) {
-				struct pf_state_key	*inp_sk;
-
-				mtx_enter(&pf_inp_mtx);
-				inp_sk = inp->inp_pf_sk;
-				if (pf_state_key_isvalid(inp_sk)) {
-					sk = inp_sk;
+			mtx_enter(&pf_inp_mtx);
+			st = inp->inp_pf_st;
+			if (st != NULL) {
+				if (pf_state_isvalid(st)) {
 					mtx_leave(&pf_inp_mtx);
-				} else if (inp_sk != NULL) {
-					KASSERT(inp_sk->sk_inp == inp);
-					inp_sk->sk_inp = NULL;
-					inp->inp_pf_sk = NULL;
+					goto match;
+				} else {
+					KASSERT(st->inp == inp);
+					st->inp = NULL;
+					inp->inp_pf_st = NULL;
 					mtx_leave(&pf_inp_mtx);
 
-					pf_state_key_unref(inp_sk);
+					pf_state_unref(st);
 					in_pcbunref(inp);
-				} else
-					mtx_leave(&pf_inp_mtx);
-			}
-		}
-	}
-
-	if (sk == NULL) {
-		if ((sk = RBT_FIND(pf_state_tree, &pf_statetbl,
-		    (struct pf_state_key *)key)) == NULL)
-			return (PF_DROP);
-		if (pd->dir == PF_OUT && pkt_sk &&
-		    pf_compare_state_keys(pkt_sk, sk, pd->kif, pd->dir) == 0)
-			pf_state_key_link_reverse(sk, pkt_sk);
-		else if (pd->dir == PF_OUT)
-			pf_state_key_link_inpcb(sk, pd->m->m_pkthdr.pf.inp);
-	}
-
-	/* remove firewall data from outbound packet */
-	if (pd->dir == PF_OUT)
-		pf_pkt_addr_changed(pd->m);
-
-	/* list is sorted, if-bound states before floating ones */
-	TAILQ_FOREACH(si, &sk->sk_states, si_entry) {
-		struct pf_state *sist = si->si_st;
-		if (sist->timeout != PFTM_PURGE &&
-		    (sist->kif == pfi_all || sist->kif == pd->kif) &&
-		    ((sist->key[PF_SK_WIRE]->af == sist->key[PF_SK_STACK]->af &&
-		      sk == (pd->dir == PF_IN ? sist->key[PF_SK_WIRE] :
-		    sist->key[PF_SK_STACK])) ||
-		    (sist->key[PF_SK_WIRE]->af != sist->key[PF_SK_STACK]->af
-		    && pd->dir == PF_IN && (sk == sist->key[PF_SK_STACK] ||
-		    sk == sist->key[PF_SK_WIRE])))) {
-			st = sist;
-			break;
+				}
+			} else
+				mtx_leave(&pf_inp_mtx);
 		}
 	}
 
+	st = pf_find_state_lookup(pd, key);
 	if (st == NULL)
 		return (PF_DROP);
 	if (ISSET(st->state_flags, PFSTATE_INP_UNLINKED))
 		return (PF_DROP);
 
+	if (pd->dir == PF_OUT) {
+		if (strev != NULL)
+			pf_state_link_reverse(st, strev);
+		else if (inp != NULL)
+			pf_state_link_inpcb(st, inp);
+	}
+
+match:
+	/* remove firewall data (and the mbuf's refs) from outbound packet */
+	if (pd->dir == PF_OUT)
+		pf_pkt_addr_changed(pd->m);
+
 	if (st->rule.ptr->pktrate.limit && pd->dir == st->direction) {
 		pf_add_threshold(&st->rule.ptr->pktrate);
 		if (pf_check_threshold(&st->rule.ptr->pktrate))
@@ -1782,6 +1786,9 @@ pf_remove_state(struct pf_state *st)
 	st->timeout = PFTM_UNLINKED;
 	mtx_leave(&st->mtx);
 
+	pf_state_unlink_reverse(st);
+	pf_state_unlink_inpcb(st);
+
 	/* handle load balancing related tasks */
 	pf_postprocess_addr(st);
 
@@ -1813,52 +1820,47 @@ pf_remove_state(struct pf_state *st)
 void
 pf_remove_divert_state(struct inpcb *inp)
 {
-	struct pf_state_key	*sk;
-	struct pf_state_item	*si;
+	struct pf_state		*st;
 
 	PF_ASSERT_UNLOCKED();
 
-	if (READ_ONCE(inp->inp_pf_sk) == NULL)
+	if (READ_ONCE(inp->inp_pf_st) == NULL)
 		return;
 
 	mtx_enter(&pf_inp_mtx);
-	sk = pf_state_key_ref(inp->inp_pf_sk);
+	st = pf_state_ref(inp->inp_pf_st);
 	mtx_leave(&pf_inp_mtx);
-	if (sk == NULL)
+	if (st == NULL)
 		return;
 
 	PF_LOCK();
 	PF_STATE_ENTER_WRITE();
-	TAILQ_FOREACH(si, &sk->sk_states, si_entry) {
-		struct pf_state *sist = si->si_st;
-		if (sk == sist->key[PF_SK_STACK] && sist->rule.ptr &&
-		    (sist->rule.ptr->divert.type == PF_DIVERT_TO ||
-		     sist->rule.ptr->divert.type == PF_DIVERT_REPLY)) {
-			if (sist->key[PF_SK_STACK]->proto == IPPROTO_TCP &&
-			    sist->key[PF_SK_WIRE] != sist->key[PF_SK_STACK]) {
-				/*
-				 * If the local address is translated, keep
-				 * the state for "tcp.closed" seconds to
-				 * prevent its source port from being reused.
-				 */
-				if (sist->src.state < TCPS_FIN_WAIT_2 ||
-				    sist->dst.state < TCPS_FIN_WAIT_2) {
-					pf_set_protostate(sist, PF_PEER_BOTH,
-					    TCPS_TIME_WAIT);
-					pf_update_state_timeout(sist,
-					    PFTM_TCP_CLOSED);
-					sist->expire = getuptime();
-				}
-				sist->state_flags |= PFSTATE_INP_UNLINKED;
-			} else
-				pf_remove_state(sist);
-			break;
-		}
+	if (st->rule.ptr &&
+	    (st->rule.ptr->divert.type == PF_DIVERT_TO ||
+	     st->rule.ptr->divert.type == PF_DIVERT_REPLY)) {
+		if (st->key[PF_SK_STACK]->proto == IPPROTO_TCP &&
+		    st->key[PF_SK_WIRE] != st->key[PF_SK_STACK]) {
+			/*
+			 * If the local address is translated, keep
+			 * the state for "tcp.closed" seconds to
+			 * prevent its source port from being reused.
+			 */
+			if (st->src.state < TCPS_FIN_WAIT_2 ||
+			    st->dst.state < TCPS_FIN_WAIT_2) {
+				pf_set_protostate(st, PF_PEER_BOTH,
+				    TCPS_TIME_WAIT);
+				pf_update_state_timeout(st,
+				    PFTM_TCP_CLOSED);
+				st->expire = getuptime();
+			}
+			st->state_flags |= PFSTATE_INP_UNLINKED;
+		} else
+			pf_remove_state(st);
 	}
 	PF_STATE_EXIT_WRITE();
 	PF_UNLOCK();
 
-	pf_state_key_unref(sk);
+	pf_state_unref(st);
 }
 
 void
@@ -7916,9 +7918,9 @@ done:
 
 		if (pd.dir == PF_IN) {
 			KASSERT(inp == NULL);
-			pf_mbuf_link_state_key(m, st->key[PF_SK_STACK]);
+			pf_mbuf_link_state(m, st);
 		} else if (pd.dir == PF_OUT)
-			pf_state_key_link_inpcb(st->key[PF_SK_STACK], inp);
+			pf_state_link_inpcb(st, inp);
 
 		if (!ISSET(m->m_pkthdr.csum_flags, M_FLOWID)) {
 			m->m_pkthdr.ph_flowid = st->key[PF_SK_WIRE]->hash;
@@ -8100,14 +8102,14 @@ done:
 int
 pf_ouraddr(struct mbuf *m)
 {
-	struct pf_state_key	*sk;
+	struct pf_state		*st;
 
 	if (m->m_pkthdr.pf.flags & PF_TAG_DIVERTED)
 		return (1);
 
-	sk = m->m_pkthdr.pf.statekey;
-	if (sk != NULL) {
-		if (READ_ONCE(sk->sk_inp) != NULL)
+	st = m->m_pkthdr.pf.st;
+	if (st != NULL) {
+		if (READ_ONCE(st->inp) != NULL)
 			return (1);
 	}
 
@@ -8121,7 +8123,7 @@ pf_ouraddr(struct mbuf *m)
 void
 pf_pkt_addr_changed(struct mbuf *m)
 {
-	pf_mbuf_unlink_state_key(m);
+	pf_mbuf_unlink_state(m);
 	pf_mbuf_unlink_inpcb(m);
 }
 
@@ -8129,85 +8131,97 @@ struct inpcb *
 pf_inp_lookup(struct mbuf *m)
 {
 	struct inpcb *inp = NULL;
-	struct pf_state_key *sk = m->m_pkthdr.pf.statekey;
+	struct pf_state *st;
+
+	st = m->m_pkthdr.pf.st;
+	if (st == NULL)
+		return (NULL);
+	if (!pf_state_isvalid(st)) {
+		pf_mbuf_unlink_state(m);
+		return (NULL);
+	}
 
-	if (!pf_state_key_isvalid(sk))
-		pf_mbuf_unlink_state_key(m);
-	else if (READ_ONCE(sk->sk_inp) != NULL) {
+	if (READ_ONCE(st->inp) != NULL) {
 		mtx_enter(&pf_inp_mtx);
-		inp = in_pcbref(sk->sk_inp);
+		inp = in_pcbref(st->inp);
 		mtx_leave(&pf_inp_mtx);
 	}
 
 	return (inp);
 }
 
+/*
+ * This is called from the IP stack after it's found an inpcb for
+ * an mbuf so it can link the pf_state to that pcb.
+ */
 void
 pf_inp_link(struct mbuf *m, struct inpcb *inp)
 {
-	struct pf_state_key *sk = m->m_pkthdr.pf.statekey;
+	struct pf_state *st;
 
-	if (!pf_state_key_isvalid(sk)) {
-		pf_mbuf_unlink_state_key(m);
+	st = m->m_pkthdr.pf.st;
+	if (st == NULL)
 		return;
-	}
 
 	/*
 	 * we don't need to grab PF-lock here. At worst case we link inp to
 	 * state, which might be just being marked as deleted by another
 	 * thread.
 	 */
-	pf_state_key_link_inpcb(sk, inp);
+	if (pf_state_isvalid(st)) {
+		if (READ_ONCE(st->inp) == NULL)
+			pf_state_link_inpcb(st, inp);
+	}
 
-	/* The statekey has finished finding the inp, it is no longer needed. */
-	pf_mbuf_unlink_state_key(m);
+	/* The state has finished finding the inp, it is no longer needed. */
+	pf_mbuf_unlink_state(m);
 }
 
 void
 pf_inp_unlink(struct inpcb *inp)
 {
-	struct pf_state_key *sk;
+	struct pf_state *st;
 
-	if (READ_ONCE(inp->inp_pf_sk) == NULL)
+	if (READ_ONCE(inp->inp_pf_st) == NULL)
 		return;
 
 	mtx_enter(&pf_inp_mtx);
-	sk = inp->inp_pf_sk;
-	if (sk == NULL) {
+	st = inp->inp_pf_st;
+	if (st == NULL) {
 		mtx_leave(&pf_inp_mtx);
 		return;
 	}
-	KASSERT(sk->sk_inp == inp);
-	sk->sk_inp = NULL;
-	inp->inp_pf_sk = NULL;
+	KASSERT(st->inp == inp);
+	st->inp = NULL;
+	inp->inp_pf_st = NULL;
 	mtx_leave(&pf_inp_mtx);
 
-	pf_state_key_unref(sk);
+	pf_state_unref(st);
 	in_pcbunref(inp);
 }
 
 void
-pf_state_key_link_reverse(struct pf_state_key *sk, struct pf_state_key *skrev)
+pf_state_link_reverse(struct pf_state *st, struct pf_state *strev)
 {
-	struct pf_state_key *old_reverse;
+	struct pf_state *ost;
 
-	old_reverse = atomic_cas_ptr(&sk->sk_reverse, NULL, skrev);
-	if (old_reverse != NULL)
-		KASSERT(old_reverse == skrev);
+	ost = atomic_cas_ptr(&st->reverse, NULL, strev);
+	if (ost != NULL)
+		KASSERT(ost == strev);
 	else {
-		pf_state_key_ref(skrev);
+		pf_state_ref(strev);
 
 		/*
-		 * NOTE: if sk == skrev, then KASSERT() below holds true, we
+		 * NOTE: if st == strev, then KASSERT() below holds true, we
 		 * still want to grab a reference in such case, because
-		 * pf_state_key_unlink_reverse() does not check whether keys
+		 * pf_state_unlink_reverse() does not check whether states
 		 * are identical or not.
 		 */
-		old_reverse = atomic_cas_ptr(&skrev->sk_reverse, NULL, sk);
-		if (old_reverse != NULL)
-			KASSERT(old_reverse == sk);
+		ost = atomic_cas_ptr(&strev->reverse, NULL, st);
+		if (ost != NULL)
+			KASSERT(ost == st);
 
-		pf_state_key_ref(sk);
+		pf_state_ref(st);
 	}
 }
 
@@ -8243,10 +8257,6 @@ pf_state_key_unref(struct pf_state_key *
 	if (PF_REF_RELE(sk->sk_refcnt)) {
 		/* state key must be removed from tree */
 		KASSERT(!pf_state_key_isvalid(sk));
-		/* state key must be unlinked from reverse key */
-		KASSERT(sk->sk_reverse == NULL);
-		/* state key must be unlinked from socket */
-		KASSERT(sk->sk_inp == NULL);
 		pool_put(&pf_state_key_pl, sk);
 	}
 }
@@ -8257,21 +8267,28 @@ pf_state_key_isvalid(struct pf_state_key
 	return ((sk != NULL) && (sk->sk_removed == 0));
 }
 
+int
+pf_state_isvalid(struct pf_state *st)
+{
+	return (st->timeout < PFTM_MAX);
+}
+
 void
-pf_mbuf_link_state_key(struct mbuf *m, struct pf_state_key *sk)
+pf_mbuf_link_state(struct mbuf *m, struct pf_state *st)
 {
-	KASSERT(m->m_pkthdr.pf.statekey == NULL);
-	m->m_pkthdr.pf.statekey = pf_state_key_ref(sk);
+	KASSERT(m->m_pkthdr.pf.st == NULL);
+	m->m_pkthdr.pf.st = pf_state_ref(st);
 }
 
 void
-pf_mbuf_unlink_state_key(struct mbuf *m)
+pf_mbuf_unlink_state(struct mbuf *m)
 {
-	struct pf_state_key *sk = m->m_pkthdr.pf.statekey;
+	struct pf_state *st;
 
-	if (sk != NULL) {
-		m->m_pkthdr.pf.statekey = NULL;
-		pf_state_key_unref(sk);
+	st = m->m_pkthdr.pf.st;
+	if (st != NULL) {
+		m->m_pkthdr.pf.st = NULL;
+		pf_state_unref(st);
 	}
 }
 
@@ -8294,57 +8311,58 @@ pf_mbuf_unlink_inpcb(struct mbuf *m)
 }
 
 void
-pf_state_key_link_inpcb(struct pf_state_key *sk, struct inpcb *inp)
+pf_state_link_inpcb(struct pf_state *st, struct inpcb *inp)
 {
-	if (inp == NULL || READ_ONCE(sk->sk_inp) != NULL)
+	if (inp == NULL || READ_ONCE(st->inp) != NULL)
 		return;
 
 	mtx_enter(&pf_inp_mtx);
-	if (inp->inp_pf_sk != NULL || sk->sk_inp != NULL) {
+	if (inp->inp_pf_st != NULL || st->inp != NULL) {
 		mtx_leave(&pf_inp_mtx);
 		return;
 	}
-	sk->sk_inp = in_pcbref(inp);
-	inp->inp_pf_sk = pf_state_key_ref(sk);
+	st->inp = in_pcbref(inp);
+	inp->inp_pf_st = pf_state_ref(st);
 	mtx_leave(&pf_inp_mtx);
 }
 
 void
-pf_state_key_unlink_inpcb(struct pf_state_key *sk)
+pf_state_unlink_inpcb(struct pf_state *st)
 {
 	struct inpcb *inp;
 
-	if (READ_ONCE(sk->sk_inp) == NULL)
+	if (READ_ONCE(st->inp) == NULL)
 		return;
 
 	mtx_enter(&pf_inp_mtx);
-	inp = sk->sk_inp;
+	inp = st->inp;
 	if (inp == NULL) {
 		mtx_leave(&pf_inp_mtx);
 		return;
 	}
-	KASSERT(inp->inp_pf_sk == sk);
-	sk->sk_inp = NULL;
-	inp->inp_pf_sk = NULL;
+	KASSERT(inp->inp_pf_st == st);
+	st->inp = NULL;
+	inp->inp_pf_st = NULL;
 	mtx_leave(&pf_inp_mtx);
 
-	pf_state_key_unref(sk);
+	pf_state_unref(st);
 	in_pcbunref(inp);
 }
 
 void
-pf_state_key_unlink_reverse(struct pf_state_key *sk)
+pf_state_unlink_reverse(struct pf_state *st)
 {
-	struct pf_state_key *skrev = sk->sk_reverse;
-
-	/* Note that sk and skrev may be equal, then we unref twice. */
-	if (skrev != NULL) {
-		KASSERT(skrev->sk_reverse == sk);
-		sk->sk_reverse = NULL;
-		skrev->sk_reverse = NULL;
-		pf_state_key_unref(skrev);
-		pf_state_key_unref(sk);
-	}
+	struct pf_state *strev;
+
+	/* Note that st and strev may be equal, then we unref twice. */
+	strev = st->reverse;
+	if (strev != NULL) {
+		KASSERT(strev->reverse == st);
+		st->reverse = NULL;
+		strev->reverse = NULL;
+		pf_state_unref(strev);
+		pf_state_unref(st);
+	}
 }
 
 struct pf_state *
@@ -8370,6 +8388,11 @@ pf_state_unref(struct pf_state *st)
 
 		pf_state_key_unref(st->key[PF_SK_WIRE]);
 		pf_state_key_unref(st->key[PF_SK_STACK]);
+
+		/* state must be unlinked from reverse */
+		KASSERT(st->reverse == NULL);
+		/* state must be unlinked from socket */
+		KASSERT(st->inp == NULL);
 
 		pool_put(&pf_state_pl, st);
 	}
Index: net/pfvar.h
===================================================================
RCS file: /cvs/src/sys/net/pfvar.h,v
diff -u -p -r1.543 pfvar.h
--- net/pfvar.h	14 Apr 2025 20:02:34 -0000	1.543
+++ net/pfvar.h	9 May 2025 02:48:26 -0000
@@ -1849,9 +1849,8 @@ int			 pf_map_addr(sa_family_t, struct p
 			    struct pf_pool *, enum pf_sn_types);
 int			 pf_postprocess_addr(struct pf_state *);
 
-void			 pf_mbuf_link_state_key(struct mbuf *,
-			    struct pf_state_key *);
-void			 pf_mbuf_unlink_state_key(struct mbuf *);
+void			 pf_mbuf_link_state(struct mbuf *, struct pf_state *);
+void			 pf_mbuf_unlink_state(struct mbuf *);
 void			 pf_mbuf_link_inpcb(struct mbuf *, struct inpcb *);
 void			 pf_mbuf_unlink_inpcb(struct mbuf *);
 
Index: net/pfvar_priv.h
===================================================================
RCS file: /cvs/src/sys/net/pfvar_priv.h,v
diff -u -p -r1.38 pfvar_priv.h
--- net/pfvar_priv.h	7 Sep 2024 22:41:55 -0000	1.38
+++ net/pfvar_priv.h	9 May 2025 02:48:26 -0000
@@ -41,11 +41,6 @@
 #include <sys/mutex.h>
 #include <sys/percpu.h>
 
-/*
- * Locks used to protect struct members in this file:
- *	L	pf_inp_mtx		link pf to inp mutex
- */
-
 struct pfsync_deferral;
 
 /*
@@ -74,8 +69,6 @@ struct pf_state_key {
 
 	RBT_ENTRY(pf_state_key)	 sk_entry;
 	struct pf_statelisthead	 sk_states;
-	struct pf_state_key	*sk_reverse;
-	struct inpcb		*sk_inp;	/* [L] */
 	pf_refcnt_t		 sk_refcnt;
 	u_int8_t		 sk_removed;
 };
@@ -95,7 +88,8 @@ RBT_PROTOTYPE(pf_state_tree, pf_state_ke
  *	M	pf_state mtx
  *	P	PF_STATE_LOCK
  *	S	pfsync
- *	L	pf_state_list
+ *	L	pf_inp_mtx		link pf to inp mutex
+ *	G	pf_state_list
  *	g	pf_purge gc
  */
 
@@ -107,7 +101,7 @@ struct pf_state {
 
 	TAILQ_ENTRY(pf_state)	 sync_list;	/* [S] */
 	struct pfsync_deferral	*sync_defer;	/* [S] */
-	TAILQ_ENTRY(pf_state)	 entry_list;	/* [L] */
+	TAILQ_ENTRY(pf_state)	 entry_list;	/* [G] */
 	SLIST_ENTRY(pf_state)	 gc_list;	/* [g] */
 	RBT_ENTRY(pf_state)	 entry_id;	/* [P] */
 	struct pf_state_peer	 src;
@@ -120,6 +114,8 @@ struct pf_state {
 	struct pf_sn_head	 src_nodes;	/* [I] */
 	struct pf_state_key	*key[2];	/* [I] stack and wire */
 	struct pfi_kif		*kif;		/* [I] */
+	struct pf_state		*reverse;
+	struct inpcb		*inp;		/* [L] */
 	struct mutex		 mtx;
 	pf_refcnt_t		 refcnt;
 	u_int64_t		 packets[2];
Index: netinet/in_pcb.h
===================================================================
RCS file: /cvs/src/sys/netinet/in_pcb.h,v
diff -u -p -r1.167 in_pcb.h
--- netinet/in_pcb.h	4 May 2025 23:05:17 -0000	1.167
+++ netinet/in_pcb.h	9 May 2025 02:48:26 -0000
@@ -110,7 +110,7 @@
  * Protocol input only reads inp_[lf]addr/port during lookup and is safe.
  */
 
-struct pf_state_key;
+struct pf_state;
 
 union inpaddru {
 	struct in_addr iau_addr;
@@ -168,7 +168,7 @@ struct inpcb {
 
 	int	inp_cksum6;
 	struct	icmp6_filter *inp_icmp6filt;
-	struct	pf_state_key *inp_pf_sk; /* [L] */
+	struct	pf_state *inp_pf_st;	/* [L] */
 	struct	mbuf *(*inp_upcall)(void *, struct mbuf *,
 	    struct ip *, struct ip6_hdr *, void *, int, struct netstack *);
 	void	*inp_upcall_arg;
Index: sys/mbuf.h
===================================================================
RCS file: /cvs/src/sys/sys/mbuf.h,v
diff -u -p -r1.265 mbuf.h
--- sys/mbuf.h	5 Nov 2024 13:15:13 -0000	1.265
+++ sys/mbuf.h	9 May 2025 02:48:26 -0000
@@ -90,11 +90,11 @@ struct m_hdr {
 };
 
 /* pf stuff */
-struct pf_state_key;
+struct pf_state;
 struct inpcb;
 
 struct pkthdr_pf {
-	struct pf_state_key *statekey;	/* pf stackside statekey */
+	struct pf_state	*st;		/* pf state */
 	struct inpcb	*inp;		/* connected pcb for outgoing packet */
 	u_int32_t	 qid;		/* queue id */
 	u_int16_t	 tag;		/* tag id */
@@ -325,7 +325,7 @@ u_int mextfree_register(void (*)(caddr_t
 	(to)->m_pkthdr = (from)->m_pkthdr;				\
 	(from)->m_flags &= ~M_PKTHDR;					\
 	SLIST_INIT(&(from)->m_pkthdr.ph_tags);				\
-	(from)->m_pkthdr.pf.statekey = NULL;				\
+	(from)->m_pkthdr.pf.st = NULL;					\
 } while (/* CONSTCOND */ 0)
 
 /*