diff --git a/include/net/ip.h b/include/net/ip.h index 47ed6d23853d..375304bb99f6 100644 --- a/include/net/ip.h +++ b/include/net/ip.h @@ -59,6 +59,7 @@ struct inet_skb_parm { #define IPSKB_L3SLAVE BIT(7) #define IPSKB_NOPOLICY BIT(8) #define IPSKB_MULTIPATH BIT(9) +#define IPSKB_MCROUTE BIT(10) u16 frag_max_size; }; @@ -167,6 +168,7 @@ void ip_list_rcv(struct list_head *head, struct packet_type *pt, int ip_local_deliver(struct sk_buff *skb); void ip_protocol_deliver_rcu(struct net *net, struct sk_buff *skb, int proto); int ip_mr_input(struct sk_buff *skb); +int ip_mr_output(struct net *net, struct sock *sk, struct sk_buff *skb); int ip_output(struct net *net, struct sock *sk, struct sk_buff *skb); int ip_mc_output(struct net *net, struct sock *sk, struct sk_buff *skb); int ip_do_fragment(struct net *net, struct sock *sk, struct sk_buff *skb, diff --git a/net/ipv4/ipmr.c b/net/ipv4/ipmr.c index 74d45fd5d11e..f78c4e53dc8c 100644 --- a/net/ipv4/ipmr.c +++ b/net/ipv4/ipmr.c @@ -1965,6 +1965,19 @@ static void ipmr_queue_fwd_xmit(struct net *net, struct mr_table *mrt, kfree_skb(skb); } +static void ipmr_queue_output_xmit(struct net *net, struct mr_table *mrt, + struct sk_buff *skb, int vifi) +{ + if (ipmr_prepare_xmit(net, mrt, skb, vifi)) + goto out_free; + + ip_mc_output(net, NULL, skb); + return; + +out_free: + kfree_skb(skb); +} + /* Called with mrt_lock or rcu_read_lock() */ static int ipmr_find_vif(const struct mr_table *mrt, struct net_device *dev) { @@ -2224,6 +2237,110 @@ int ip_mr_input(struct sk_buff *skb) return 0; } +static void ip_mr_output_finish(struct net *net, struct mr_table *mrt, + struct net_device *dev, struct sk_buff *skb, + struct mfc_cache *c) +{ + int psend = -1; + int ct; + + atomic_long_inc(&c->_c.mfc_un.res.pkt); + atomic_long_add(skb->len, &c->_c.mfc_un.res.bytes); + WRITE_ONCE(c->_c.mfc_un.res.lastuse, jiffies); + + /* Forward the frame */ + if (c->mfc_origin == htonl(INADDR_ANY) && + c->mfc_mcastgrp == htonl(INADDR_ANY)) { + if (ip_hdr(skb)->ttl > + c->_c.mfc_un.res.ttls[c->_c.mfc_parent]) { + /* It's an (*,*) entry and the packet is not coming from + * the upstream: forward the packet to the upstream + * only. + */ + psend = c->_c.mfc_parent; + goto last_xmit; + } + goto dont_xmit; + } + + for (ct = c->_c.mfc_un.res.maxvif - 1; + ct >= c->_c.mfc_un.res.minvif; ct--) { + if (ip_hdr(skb)->ttl > c->_c.mfc_un.res.ttls[ct]) { + if (psend != -1) { + struct sk_buff *skb2; + + skb2 = skb_clone(skb, GFP_ATOMIC); + if (skb2) + ipmr_queue_output_xmit(net, mrt, + skb2, psend); + } + psend = ct; + } + } + +last_xmit: + if (psend != -1) { + ipmr_queue_output_xmit(net, mrt, skb, psend); + return; + } + +dont_xmit: + kfree_skb(skb); +} + +/* Multicast packets for forwarding arrive here + * Called with rcu_read_lock(); + */ +int ip_mr_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct rtable *rt = skb_rtable(skb); + struct mfc_cache *cache; + struct net_device *dev; + struct mr_table *mrt; + int vif; + + WARN_ON_ONCE(!rcu_read_lock_held()); + dev = rt->dst.dev; + + if (IPCB(skb)->flags & IPSKB_FORWARDED) + goto mc_output; + if (!(IPCB(skb)->flags & IPSKB_MCROUTE)) + goto mc_output; + + skb->dev = dev; + + mrt = ipmr_rt_fib_lookup(net, skb); + if (IS_ERR(mrt)) + goto mc_output; + + /* already under rcu_read_lock() */ + cache = ipmr_cache_find(mrt, ip_hdr(skb)->saddr, ip_hdr(skb)->daddr); + if (!cache) { + vif = ipmr_find_vif(mrt, dev); + if (vif >= 0) + cache = ipmr_cache_find_any(mrt, ip_hdr(skb)->daddr, + vif); + } + + /* No usable cache entry */ + if (!cache) { + vif = ipmr_find_vif(mrt, dev); + if (vif >= 0) + return ipmr_cache_unresolved(mrt, vif, skb, dev); + goto mc_output; + } + + vif = cache->_c.mfc_parent; + if (rcu_access_pointer(mrt->vif_table[vif].dev) != dev) + goto mc_output; + + ip_mr_output_finish(net, mrt, dev, skb, cache); + return 0; + +mc_output: + return ip_mc_output(net, sk, skb); +} + #ifdef CONFIG_IP_PIMSM_V1 /* Handle IGMP messages of PIMv1 */ int pim_rcv_v1(struct sk_buff *skb) diff --git a/net/ipv4/route.c b/net/ipv4/route.c index fccb05fb3a79..3ddf6bf40357 100644 --- a/net/ipv4/route.c +++ b/net/ipv4/route.c @@ -2660,7 +2660,7 @@ static struct rtable *__mkroute_output(const struct fib_result *res, if (IN_DEV_MFORWARD(in_dev) && !ipv4_is_local_multicast(fl4->daddr)) { rth->dst.input = ip_mr_input; - rth->dst.output = ip_mc_output; + rth->dst.output = ip_mr_output; } } #endif