These changes are the raw update to linux-4.4.6-rt14. Kernel sources
[kvmfornfv.git] / kernel / net / mpls / af_mpls.c
index 1f93a59..c32fc41 100644 (file)
 #include <net/ip_fib.h>
 #include <net/netevent.h>
 #include <net/netns/generic.h>
+#if IS_ENABLED(CONFIG_IPV6)
+#include <net/ipv6.h>
+#include <net/addrconf.h>
+#endif
+#include <net/nexthop.h>
 #include "internal.h"
 
-#define LABEL_NOT_SPECIFIED (1<<20)
-#define MAX_NEW_LABELS 2
-
-/* This maximum ha length copied from the definition of struct neighbour */
-#define MAX_VIA_ALEN (ALIGN(MAX_ADDR_LEN, sizeof(unsigned long)))
-
-struct mpls_route { /* next hop label forwarding entry */
-       struct net_device __rcu *rt_dev;
-       struct rcu_head         rt_rcu;
-       u32                     rt_label[MAX_NEW_LABELS];
-       u8                      rt_protocol; /* routing protocol that set this entry */
-       u8                      rt_labels;
-       u8                      rt_via_alen;
-       u8                      rt_via_table;
-       u8                      rt_via[0];
-};
+/* Maximum number of labels to look ahead at when selecting a path of
+ * a multipath route
+ */
+#define MAX_MP_SELECT_LABELS 4
+
+#define MPLS_NEIGH_TABLE_UNSPEC (NEIGH_LINK_TABLE + 1)
 
 static int zero = 0;
 static int label_limit = (1 << 20) - 1;
@@ -58,24 +53,40 @@ static inline struct mpls_dev *mpls_dev_get(const struct net_device *dev)
        return rcu_dereference_rtnl(dev->mpls_ptr);
 }
 
-static bool mpls_output_possible(const struct net_device *dev)
+bool mpls_output_possible(const struct net_device *dev)
 {
        return dev && (dev->flags & IFF_UP) && netif_carrier_ok(dev);
 }
+EXPORT_SYMBOL_GPL(mpls_output_possible);
+
+/* Via addresses are stored after rt_nh[], one rt_max_alen slot per nexthop */
+static u8 *__mpls_nh_via(struct mpls_route *rt, struct mpls_nh *nh)
+{
+       return PTR_ALIGN((u8 *)&rt->rt_nh[rt->rt_nhn], VIA_ALEN_ALIGN) +
+              rt->rt_max_alen * (int)(nh - rt->rt_nh);
+}
+
+/* Read-only variant of __mpls_nh_via() for callers holding const pointers */
+static const u8 *
+mpls_nh_via(const struct mpls_route *rt, const struct mpls_nh *nh)
+{
+       return __mpls_nh_via((struct mpls_route *)rt, (struct mpls_nh *)nh);
+}
 
-static unsigned int mpls_rt_header_size(const struct mpls_route *rt)
+static unsigned int mpls_nh_header_size(const struct mpls_nh *nh)
 {
        /* The size of the layer 2.5 labels to be added for this route */
-       return rt->rt_labels * sizeof(struct mpls_shim_hdr);
+       return nh->nh_labels * sizeof(struct mpls_shim_hdr);
 }
 
-static unsigned int mpls_dev_mtu(const struct net_device *dev)
+unsigned int mpls_dev_mtu(const struct net_device *dev)
 {
        /* The amount of data the layer 2 frame can hold */
        return dev->mtu;
 }
+EXPORT_SYMBOL_GPL(mpls_dev_mtu);
 
-static bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
+bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
 {
        if (skb->len <= mtu)
                return false;
@@ -85,20 +96,87 @@ static bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
 
        return true;
 }
+EXPORT_SYMBOL_GPL(mpls_pkt_too_big);
+
+/* Pick the nexthop of a multipath route by hashing up to
+ * MAX_MP_SELECT_LABELS label stack entries found at mpls_hdr(skb) and, once
+ * the bottom of stack is reached, the IPv4/IPv6 addresses of the payload.
+ * When @bos is already set no label is examined and the hash stays 0.
+ * Returns &rt->rt_nh[hash % rt->rt_nhn].
+ */
+static struct mpls_nh *mpls_select_multipath(struct mpls_route *rt,
+                                            struct sk_buff *skb, bool bos)
+{
+       struct mpls_entry_decoded dec;
+       unsigned int stack_len = 0;
+       bool eli_seen = false;
+       int label_index;
+       u32 hash = 0;
+
+       /* No need to look further into packet if there's only one path */
+       if (rt->rt_nhn == 1)
+               return &rt->rt_nh[0];
+
+       for (label_index = 0; label_index < MAX_MP_SELECT_LABELS && !bos;
+            label_index++) {
+               /* The entry read below must be in the linear area too */
+               stack_len += sizeof(struct mpls_shim_hdr);
+               if (!pskb_may_pull(skb, stack_len))
+                       break;
+
+               /* Read and decode the current label */
+               dec = mpls_entry_decode(mpls_hdr(skb) + label_index);
+
+               /* RFC6790 - reserved labels MUST NOT be used as keys
+                * for the load-balancing function
+                */
+               if (likely(dec.label >= MPLS_LABEL_FIRST_UNRESERVED)) {
+                       hash = jhash_1word(dec.label, hash);
+
+                       /* The entropy label follows the entropy label
+                        * indicator, so the entropy label was just hashed:
+                        * no need to go deeper in the stack or the payload
+                        */
+                       if (eli_seen)
+                               break;
+               } else if (dec.label == MPLS_LABEL_ENTROPY) {
+                       eli_seen = true;
+               }
+
+               bos = dec.bos;
+               if (bos && pskb_may_pull(skb, stack_len +
+                                        sizeof(struct iphdr))) {
+                       const struct iphdr *v4hdr;
+
+                       /* The payload starts after the bottom label */
+                       v4hdr = (const struct iphdr *)(mpls_hdr(skb) +
+                                                      label_index + 1);
+                       if (v4hdr->version == 4) {
+                               hash = jhash_3words(ntohl(v4hdr->saddr),
+                                                   ntohl(v4hdr->daddr),
+                                                   v4hdr->protocol, hash);
+                       } else if (v4hdr->version == 6 &&
+                               pskb_may_pull(skb, stack_len +
+                                             sizeof(struct ipv6hdr))) {
+                               const struct ipv6hdr *v6hdr;
+
+                               v6hdr = (const struct ipv6hdr *)(mpls_hdr(skb) +
+                                                       label_index + 1);
+                               hash = __ipv6_addr_jhash(&v6hdr->saddr, hash);
+                               hash = __ipv6_addr_jhash(&v6hdr->daddr, hash);
+                               hash = jhash_1word(v6hdr->nexthdr, hash);
+                       }
+               }
+       }
+
+       return &rt->rt_nh[hash % rt->rt_nhn];
+}
 
 static bool mpls_egress(struct mpls_route *rt, struct sk_buff *skb,
                        struct mpls_entry_decoded dec)
 {
-       /* RFC4385 and RFC5586 encode other packets in mpls such that
-        * they don't conflict with the ip version number, making
-        * decoding by examining the ip version correct in everything
-        * except for the strangest cases.
-        *
-        * The strange cases if we choose to support them will require
-        * manual configuration.
-        */
-       struct iphdr *hdr4;
-       bool success = true;
+       enum mpls_payload_type payload_type;
+       bool success = false;
 
        /* The IPv4 code below accesses through the IPv4 header
         * checksum, which is 12 bytes into the packet.
@@ -113,23 +191,32 @@ static bool mpls_egress(struct mpls_route *rt, struct sk_buff *skb,
        if (!pskb_may_pull(skb, 12))
                return false;
 
-       /* Use ip_hdr to find the ip protocol version */
-       hdr4 = ip_hdr(skb);
-       if (hdr4->version == 4) {
+       payload_type = rt->rt_payload_type;
+       if (payload_type == MPT_UNSPEC)
+               payload_type = ip_hdr(skb)->version;
+
+       switch (payload_type) {
+       case MPT_IPV4: {
+               struct iphdr *hdr4 = ip_hdr(skb);
                skb->protocol = htons(ETH_P_IP);
                csum_replace2(&hdr4->check,
                              htons(hdr4->ttl << 8),
                              htons(dec.ttl << 8));
                hdr4->ttl = dec.ttl;
+               success = true;
+               break;
        }
-       else if (hdr4->version == 6) {
+       case MPT_IPV6: {
                struct ipv6hdr *hdr6 = ipv6_hdr(skb);
                skb->protocol = htons(ETH_P_IPV6);
                hdr6->hop_limit = dec.ttl;
+               success = true;
+               break;
        }
-       else
-               /* version 0 and version 1 are used by pseudo wires */
-               success = false;
+       case MPT_UNSPEC:
+               break;
+       }
+
        return success;
 }
 
@@ -139,6 +226,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
        struct net *net = dev_net(dev);
        struct mpls_shim_hdr *hdr;
        struct mpls_route *rt;
+       struct mpls_nh *nh;
        struct mpls_entry_decoded dec;
        struct net_device *out_dev;
        struct mpls_dev *mdev;
@@ -176,8 +264,12 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
        if (!rt)
                goto drop;
 
+       nh = mpls_select_multipath(rt, skb, dec.bos);
+       if (!nh)
+               goto drop;
+
        /* Find the output device */
-       out_dev = rcu_dereference(rt->rt_dev);
+       out_dev = rcu_dereference(nh->nh_dev);
        if (!mpls_output_possible(out_dev))
                goto drop;
 
@@ -192,7 +284,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
        dec.ttl -= 1;
 
        /* Verify the destination can hold the packet */
-       new_header_size = mpls_rt_header_size(rt);
+       new_header_size = mpls_nh_header_size(nh);
        mtu = mpls_dev_mtu(out_dev);
        if (mpls_pkt_too_big(skb, mtu - new_header_size))
                goto drop;
@@ -220,13 +312,20 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
                /* Push the new labels */
                hdr = mpls_hdr(skb);
                bos = dec.bos;
-               for (i = rt->rt_labels - 1; i >= 0; i--) {
-                       hdr[i] = mpls_entry_encode(rt->rt_label[i], dec.ttl, 0, bos);
+               for (i = nh->nh_labels - 1; i >= 0; i--) {
+                       hdr[i] = mpls_entry_encode(nh->nh_label[i],
+                                                  dec.ttl, 0, bos);
                        bos = false;
                }
        }
 
-       err = neigh_xmit(rt->rt_via_table, out_dev, rt->rt_via, skb);
+       /* If via wasn't specified then send out using device address */
+       if (nh->nh_via_table == MPLS_NEIGH_TABLE_UNSPEC)
+               err = neigh_xmit(NEIGH_LINK_TABLE, out_dev,
+                                out_dev->dev_addr, skb);
+       else
+               err = neigh_xmit(nh->nh_via_table, out_dev,
+                                mpls_nh_via(rt, nh), skb);
        if (err)
                net_dbg_ratelimited("%s: packet transmission failed: %d\n",
                                    __func__, err);
@@ -248,25 +347,35 @@ static const struct nla_policy rtm_mpls_policy[RTA_MAX+1] = {
 };
 
 struct mpls_route_config {
-       u32             rc_protocol;
-       u32             rc_ifindex;
-       u16             rc_via_table;
-       u16             rc_via_alen;
-       u8              rc_via[MAX_VIA_ALEN];
-       u32             rc_label;
-       u32             rc_output_labels;
-       u32             rc_output_label[MAX_NEW_LABELS];
-       u32             rc_nlflags;
-       struct nl_info  rc_nlinfo;
+       u32                     rc_protocol;    /* rtm_protocol of the request */
+       u32                     rc_ifindex;     /* output interface index */
+       u8                      rc_via_table;   /* neighbour table of rc_via */
+       u8                      rc_via_alen;    /* bytes used in rc_via */
+       u8                      rc_via[MAX_VIA_ALEN];   /* via address */
+       u32                     rc_label;       /* label the route is keyed on */
+       u8                      rc_output_labels; /* entries in rc_output_label */
+       u32                     rc_output_label[MAX_NEW_LABELS];
+       u32                     rc_nlflags;     /* nlmsg_flags of the request */
+       enum mpls_payload_type  rc_payload_type; /* copied to rt_payload_type */
+       struct nl_info          rc_nlinfo;
+       struct rtnexthop        *rc_mp;         /* RTA_MULTIPATH payload, if any */
+       int                     rc_mp_len;      /* length of rc_mp in bytes */
 };
 
-static struct mpls_route *mpls_rt_alloc(size_t alen)
+static struct mpls_route *mpls_rt_alloc(int num_nh, u8 max_alen)
 {
+       u8 max_alen_aligned = ALIGN(max_alen, VIA_ALEN_ALIGN);
        struct mpls_route *rt;
 
-       rt = kzalloc(sizeof(*rt) + alen, GFP_KERNEL);
-       if (rt)
-               rt->rt_via_alen = alen;
+       rt = kzalloc(ALIGN(sizeof(*rt) + num_nh * sizeof(*rt->rt_nh),
+                          VIA_ALEN_ALIGN) +    /* route + nexthop array, padded */
+                    num_nh * max_alen_aligned, /* one via slot per nexthop */
+                    GFP_KERNEL);
+       if (rt) {
+               rt->rt_nhn = num_nh;
+               rt->rt_max_alen = max_alen_aligned;     /* via slot size */
+       }
+
        return rt;
 }
 
@@ -286,30 +395,27 @@ static void mpls_notify_route(struct net *net, unsigned index,
        struct mpls_route *rt = new ? new : old;
        unsigned nlm_flags = (old && new) ? NLM_F_REPLACE : 0;
        /* Ignore reserved labels for now */
-       if (rt && (index >= 16))
+       if (rt && (index >= MPLS_LABEL_FIRST_UNRESERVED))
                rtmsg_lfib(event, index, rt, nlh, net, portid, nlm_flags);
 }
 
 static void mpls_route_update(struct net *net, unsigned index,
-                             struct net_device *dev, struct mpls_route *new,
+                             struct mpls_route *new, /* NULL removes the entry */
                               const struct nl_info *info)
 {
        struct mpls_route __rcu **platform_label;
-       struct mpls_route *rt, *old = NULL;
+       struct mpls_route *rt;  /* entry installed at index before the update */
 
        ASSERT_RTNL();
 
        platform_label = rtnl_dereference(net->mpls.platform_label);
        rt = rtnl_dereference(platform_label[index]);
-       if (!dev || (rt && (rtnl_dereference(rt->rt_dev) == dev))) {
-               rcu_assign_pointer(platform_label[index], new);
-               old = rt;
-       }
+       rcu_assign_pointer(platform_label[index], new);
 
-       mpls_notify_route(net, index, old, new, info);
+       mpls_notify_route(net, index, rt, new, info);   /* args: old, new */
 
        /* If we removed a route free it now */
-       mpls_rt_free(old);
+       mpls_rt_free(rt);
 }
 
 static unsigned find_free_label(struct net *net)
@@ -320,22 +426,300 @@ static unsigned find_free_label(struct net *net)
 
        platform_label = rtnl_dereference(net->mpls.platform_label);
        platform_labels = net->mpls.platform_labels;
-       for (index = 16; index < platform_labels; index++) {
+       for (index = MPLS_LABEL_FIRST_UNRESERVED; index < platform_labels;
+            index++) {
                if (!rtnl_dereference(platform_label[index]))
                        return index;
        }
        return LABEL_NOT_SPECIFIED;
 }
 
+#if IS_ENABLED(CONFIG_INET)
+/* Output device (reference held) of the IPv4 route to @addr, or ERR_PTR() */
+static struct net_device *inet_fib_lookup_dev(struct net *net,
+                                             const void *addr)
+{
+       struct net_device *out_dev;
+       struct in_addr via;
+       struct rtable *route;
+
+       memcpy(&via, addr, sizeof(via));
+       route = ip_route_output(net, via.s_addr, 0, 0, 0);
+       if (IS_ERR(route))
+               return ERR_CAST(route);
+
+       out_dev = route->dst.dev;
+       dev_hold(out_dev);
+       ip_rt_put(route);
+
+       return out_dev;
+}
+#else
+static struct net_device *inet_fib_lookup_dev(struct net *net,
+                                             const void *addr)
+{
+       return ERR_PTR(-EAFNOSUPPORT);
+}
+#endif
+
+#if IS_ENABLED(CONFIG_IPV6)
+/* Output device (reference held) of the IPv6 route to @addr, or ERR_PTR() */
+static struct net_device *inet6_fib_lookup_dev(struct net *net,
+                                              const void *addr)
+{
+       struct net_device *out_dev;
+       struct dst_entry *route;
+       struct flowi6 flow;
+       int rc;
+
+       if (!ipv6_stub)
+               return ERR_PTR(-EAFNOSUPPORT);
+
+       memset(&flow, 0, sizeof(flow));
+       memcpy(&flow.daddr, addr, sizeof(struct in6_addr));
+       rc = ipv6_stub->ipv6_dst_lookup(net, NULL, &route, &flow);
+       if (rc)
+               return ERR_PTR(rc);
+
+       out_dev = route->dev;
+       dev_hold(out_dev);
+       dst_release(route);
+       return out_dev;
+}
+#else
+static struct net_device *inet6_fib_lookup_dev(struct net *net,
+                                              const void *addr)
+{
+       return ERR_PTR(-EAFNOSUPPORT);
+}
+#endif
+
+/* Find the output device of @nh: by ifindex when @oif is set, else by a
+ * FIB lookup of its via address (ARP and ND tables only).  Returns the
+ * device without a reference held, or an ERR_PTR().
+ */
+static struct net_device *find_outdev(struct net *net,
+                                     struct mpls_route *rt,
+                                     struct mpls_nh *nh, int oif)
+{
+       struct net_device *dev = NULL;
+
+       if (oif)
+               dev = dev_get_by_index(net, oif);
+       else if (nh->nh_via_table == NEIGH_ARP_TABLE)
+               dev = inet_fib_lookup_dev(net, mpls_nh_via(rt, nh));
+       else if (nh->nh_via_table == NEIGH_ND_TABLE)
+               dev = inet6_fib_lookup_dev(net, mpls_nh_via(rt, nh));
+
+       if (!dev)
+               return ERR_PTR(-ENODEV);
+
+       /* A failed FIB lookup is an ERR_PTR(): there is no ref to drop */
+       if (IS_ERR(dev))
+               return dev;
+
+       /* The caller is holding rtnl anyways, so release the dev reference */
+       dev_put(dev);
+
+       return dev;
+}
+
+/* Resolve the output device of @nh through find_outdev() and store it in
+ * nh->nh_dev.
+ *
+ * Returns 0 on success, the find_outdev() error when no device is found,
+ * or -EINVAL when mpls_dev_get() yields nothing for the device or when a
+ * NEIGH_LINK_TABLE via address does not have the same length as the
+ * device address.
+ */
+static int mpls_nh_assign_dev(struct net *net, struct mpls_route *rt,
+                             struct mpls_nh *nh, int oif)
+{
+       struct net_device *out_dev = find_outdev(net, rt, nh, oif);
+
+       if (IS_ERR(out_dev))
+               return PTR_ERR(out_dev);
+
+       /* Ensure this is a supported device */
+       if (!mpls_dev_get(out_dev))
+               return -EINVAL;
+
+       /* A link-layer via must be as long as the device's own address */
+       if (nh->nh_via_table == NEIGH_LINK_TABLE &&
+           out_dev->addr_len != nh->nh_via_alen)
+               return -EINVAL;
+
+       RCU_INIT_POINTER(nh->nh_dev, out_dev);
+
+       return 0;
+}
+
+/* Fill the first nexthop of @rt from the single-path fields of @cfg:
+ * output labels, via table/address/length and output interface.
+ *
+ * Returns 0 on success, -ENOMEM if @rt has no nexthop array, -EINVAL if
+ * more than MAX_NEW_LABELS output labels were given, or the error from
+ * mpls_nh_assign_dev().
+ */
+static int mpls_nh_build_from_cfg(struct mpls_route_config *cfg,
+                                 struct mpls_route *rt)
+{
+       struct mpls_nh *nh = rt->rt_nh;
+       int idx;
+
+       if (!nh)
+               return -ENOMEM;
+
+       /* Ensure only a supported number of labels are present */
+       if (cfg->rc_output_labels > MAX_NEW_LABELS)
+               return -EINVAL;
+
+       /* Output label stack */
+       nh->nh_labels = cfg->rc_output_labels;
+       for (idx = 0; idx < nh->nh_labels; idx++)
+               nh->nh_label[idx] = cfg->rc_output_label[idx];
+
+       /* Via address, copied into this nexthop's slot behind rt_nh[] */
+       nh->nh_via_table = cfg->rc_via_table;
+       nh->nh_via_alen = cfg->rc_via_alen;
+       memcpy(__mpls_nh_via(rt, nh), cfg->rc_via, cfg->rc_via_alen);
+
+       return mpls_nh_assign_dev(cfg->rc_nlinfo.nl_net, rt, nh,
+                                 cfg->rc_ifindex);
+}
+
+/* Fill nexthop @nh of @rt from the attributes of one multipath entry:
+ * output labels from @newdst and via address from @via, both optional.
+ * Without @via the nexthop is marked MPLS_NEIGH_TABLE_UNSPEC.
+ *
+ * Returns 0 on success, -ENOMEM if @nh is NULL, or the first error from
+ * nla_get_labels(), nla_get_via() or mpls_nh_assign_dev().
+ */
+static int mpls_nh_build(struct net *net, struct mpls_route *rt,
+                        struct mpls_nh *nh, int oif,
+                        struct nlattr *via, struct nlattr *newdst)
+{
+       int rc;
+
+       if (!nh)
+               return -ENOMEM;
+
+       if (newdst) {
+               rc = nla_get_labels(newdst, MAX_NEW_LABELS,
+                                   &nh->nh_labels, nh->nh_label);
+               if (rc)
+                       return rc;
+       }
+
+       if (!via) {
+               nh->nh_via_table = MPLS_NEIGH_TABLE_UNSPEC;
+       } else {
+               rc = nla_get_via(via, &nh->nh_via_alen, &nh->nh_via_table,
+                                __mpls_nh_via(rt, nh));
+               if (rc)
+                       return rc;
+       }
+
+       return mpls_nh_assign_dev(net, rt, nh, oif);
+}
+
+/* Count the nexthops in the RTA_MULTIPATH payload @rtnh of @len bytes and
+ * store the longest RTA_VIA address length seen in @max_via_alen.  Without
+ * a payload there is a single nexthop using @cfg_via_alen.  Returns the
+ * count, or 0 if bytes are left over after the last well-formed nexthop.
+ */
+static int mpls_count_nexthops(struct rtnexthop *rtnh, int len,
+                              u8 cfg_via_alen, u8 *max_via_alen)
+{
+       int count = 0;
+       int left = len;
+
+       if (!rtnh) {
+               *max_via_alen = cfg_via_alen;
+               return 1;
+       }
+
+       *max_via_alen = 0;
+
+       for (; rtnh_ok(rtnh, left); rtnh = rtnh_next(rtnh, &left)) {
+               struct nlattr *via = nla_find(rtnh_attrs(rtnh),
+                                             rtnh_attrlen(rtnh), RTA_VIA);
+               int alen;
+
+               count++;
+               if (!via || nla_len(via) < offsetof(struct rtvia, rtvia_addr))
+                       continue;
+
+               /* Addresses longer than MAX_VIA_ALEN are not accounted */
+               alen = nla_len(via) - offsetof(struct rtvia, rtvia_addr);
+               if (alen <= MAX_VIA_ALEN)
+                       *max_via_alen = max_t(u16, *max_via_alen, alen);
+       }
+
+       /* leftover implies invalid nexthop configuration, discard it */
+       return left > 0 ? 0 : count;
+}
+
+/* Build every nexthop of @rt from the RTA_MULTIPATH payload saved in @cfg
+ * and set rt->rt_nhn to the number built.  Entries that run past the
+ * payload or carry a weight or flags are rejected.
+ *
+ * Returns 0, -EINVAL for a rejected entry, or the mpls_nh_build() error.
+ */
+static int mpls_nh_build_multi(struct mpls_route_config *cfg,
+                              struct mpls_route *rt)
+{
+       struct net *net = cfg->rc_nlinfo.nl_net;
+       struct rtnexthop *rtnh = cfg->rc_mp;
+       int left = cfg->rc_mp_len;
+       int built = 0;
+
+       change_nexthops(rt) {
+               struct nlattr *via = NULL;
+               struct nlattr *newdst = NULL;
+               int attrlen;
+               int rc;
+
+               if (!rtnh_ok(rtnh, left))
+                       return -EINVAL;
+
+               /* neither weighted multipath nor any flags
+                * are supported
+                */
+               if (rtnh->rtnh_hops || rtnh->rtnh_flags)
+                       return -EINVAL;
+
+               attrlen = rtnh_attrlen(rtnh);
+               if (attrlen > 0) {
+                       struct nlattr *attrs = rtnh_attrs(rtnh);
+
+                       via = nla_find(attrs, attrlen, RTA_VIA);
+                       newdst = nla_find(attrs, attrlen, RTA_NEWDST);
+               }
+
+               rc = mpls_nh_build(net, rt, nh, rtnh->rtnh_ifindex, via,
+                                  newdst);
+               if (rc)
+                       return rc;
+
+               rtnh = rtnh_next(rtnh, &left);
+               built++;
+       } endfor_nexthops(rt);
+
+       rt->rt_nhn = built;
+
+       return 0;
+}
+
 static int mpls_route_add(struct mpls_route_config *cfg)
 {
        struct mpls_route __rcu **platform_label;
        struct net *net = cfg->rc_nlinfo.nl_net;
-       struct net_device *dev = NULL;
        struct mpls_route *rt, *old;
-       unsigned index;
-       int i;
        int err = -EINVAL;
+       u8 max_via_alen;
+       unsigned index;
+       int nhs;
 
        index = cfg->rc_label;
 
@@ -345,33 +729,14 @@ static int mpls_route_add(struct mpls_route_config *cfg)
                index = find_free_label(net);
        }
 
-       /* The first 16 labels are reserved, and may not be set */
-       if (index < 16)
+       /* Reserved labels may not be set */
+       if (index < MPLS_LABEL_FIRST_UNRESERVED)
                goto errout;
 
        /* The full 20 bit range may not be supported. */
        if (index >= net->mpls.platform_labels)
                goto errout;
 
-       /* Ensure only a supported number of labels are present */
-       if (cfg->rc_output_labels > MAX_NEW_LABELS)
-               goto errout;
-
-       err = -ENODEV;
-       dev = dev_get_by_index(net, cfg->rc_ifindex);
-       if (!dev)
-               goto errout;
-
-       /* Ensure this is a supported device */
-       err = -EINVAL;
-       if (!mpls_dev_get(dev))
-               goto errout;
-
-       err = -EINVAL;
-       if ((cfg->rc_via_table == NEIGH_LINK_TABLE) &&
-           (dev->addr_len != cfg->rc_via_alen))
-               goto errout;
-
        /* Append makes no sense with mpls */
        err = -EOPNOTSUPP;
        if (cfg->rc_nlflags & NLM_F_APPEND)
@@ -391,27 +756,34 @@ static int mpls_route_add(struct mpls_route_config *cfg)
        if (!(cfg->rc_nlflags & NLM_F_CREATE) && !old)
                goto errout;
 
+       err = -EINVAL;
+       nhs = mpls_count_nexthops(cfg->rc_mp, cfg->rc_mp_len,
+                                 cfg->rc_via_alen, &max_via_alen);
+       if (nhs == 0)
+               goto errout;
+
        err = -ENOMEM;
-       rt = mpls_rt_alloc(cfg->rc_via_alen);
+       rt = mpls_rt_alloc(nhs, max_via_alen);
        if (!rt)
                goto errout;
 
-       rt->rt_labels = cfg->rc_output_labels;
-       for (i = 0; i < rt->rt_labels; i++)
-               rt->rt_label[i] = cfg->rc_output_label[i];
        rt->rt_protocol = cfg->rc_protocol;
-       RCU_INIT_POINTER(rt->rt_dev, dev);
-       rt->rt_via_table = cfg->rc_via_table;
-       memcpy(rt->rt_via, cfg->rc_via, cfg->rc_via_alen);
+       rt->rt_payload_type = cfg->rc_payload_type;
 
-       mpls_route_update(net, index, NULL, rt, &cfg->rc_nlinfo);
+       if (cfg->rc_mp)
+               err = mpls_nh_build_multi(cfg, rt);
+       else
+               err = mpls_nh_build_from_cfg(cfg, rt);
+       if (err)
+               goto freert;
+
+       mpls_route_update(net, index, rt, &cfg->rc_nlinfo);
 
-       dev_put(dev);
        return 0;
 
+freert:
+       mpls_rt_free(rt);
 errout:
-       if (dev)
-               dev_put(dev);
        return err;
 }
 
@@ -423,15 +795,15 @@ static int mpls_route_del(struct mpls_route_config *cfg)
 
        index = cfg->rc_label;
 
-       /* The first 16 labels are reserved, and may not be removed */
-       if (index < 16)
+       /* Reserved labels may not be removed */
+       if (index < MPLS_LABEL_FIRST_UNRESERVED)
                goto errout;
 
        /* The full 20 bit range may not be supported */
        if (index >= net->mpls.platform_labels)
                goto errout;
 
-       mpls_route_update(net, index, NULL, NULL, &cfg->rc_nlinfo);
+       mpls_route_update(net, index, NULL, &cfg->rc_nlinfo);
 
        err = 0;
 errout:
@@ -528,9 +900,11 @@ static void mpls_ifdown(struct net_device *dev)
                struct mpls_route *rt = rtnl_dereference(platform_label[index]);
                if (!rt)
                        continue;
-               if (rtnl_dereference(rt->rt_dev) != dev)
-                       continue;
-               rt->rt_dev = NULL;
+               for_nexthops(rt) {
+                       if (rtnl_dereference(nh->nh_dev) != dev)
+                               continue;
+                       nh->nh_dev = NULL;
+               } endfor_nexthops(rt);
        }
 
        mdev = mpls_dev_get(dev);
@@ -626,9 +1000,10 @@ int nla_put_labels(struct sk_buff *skb, int attrtype,
 
        return 0;
 }
+EXPORT_SYMBOL_GPL(nla_put_labels);
 
 int nla_get_labels(const struct nlattr *nla,
-                  u32 max_labels, u32 *labels, u32 label[])
+                  u32 max_labels, u8 *labels, u32 label[])
 {
        unsigned len = nla_len(nla);
        unsigned nla_labels;
@@ -671,6 +1046,49 @@ int nla_get_labels(const struct nlattr *nla,
        *labels = nla_labels;
        return 0;
 }
+EXPORT_SYMBOL_GPL(nla_get_labels);
+
+/* Parse an RTA_VIA attribute into a neighbour table id, address and length.
+ * Returns 0, or -EINVAL for a truncated or oversized attribute, an
+ * unsupported family or an IPv4/IPv6 address that is not 4/16 bytes long.
+ */
+int nla_get_via(const struct nlattr *nla, u8 *via_alen,
+               u8 *via_table, u8 via_addr[])
+{
+       struct rtvia *via = nla_data(nla);
+       int alen;
+
+       if (nla_len(nla) < offsetof(struct rtvia, rtvia_addr))
+               return -EINVAL;
+       alen = nla_len(nla) - offsetof(struct rtvia, rtvia_addr);
+       if (alen > MAX_VIA_ALEN)
+               return -EINVAL;
+
+       /* Validate the address family */
+       switch (via->rtvia_family) {
+       case AF_PACKET:
+               *via_table = NEIGH_LINK_TABLE;
+               break;
+       case AF_INET:
+               *via_table = NEIGH_ARP_TABLE;
+               if (alen != 4)
+                       return -EINVAL;
+               break;
+       case AF_INET6:
+               *via_table = NEIGH_ND_TABLE;
+               if (alen != 16)
+                       return -EINVAL;
+               break;
+       default:
+               /* Unsupported address family */
+               return -EINVAL;
+       }
+
+       memcpy(via_addr, via->rtvia_addr, alen);
+       *via_alen = alen;
+
+       return 0;
+}
 
 static int rtm_to_route_config(struct sk_buff *skb,  struct nlmsghdr *nlh,
                               struct mpls_route_config *cfg)
@@ -713,6 +1131,7 @@ static int rtm_to_route_config(struct sk_buff *skb,  struct nlmsghdr *nlh,
 
        cfg->rc_label           = LABEL_NOT_SPECIFIED;
        cfg->rc_protocol        = rtm->rtm_protocol;
+       cfg->rc_via_table       = MPLS_NEIGH_TABLE_UNSPEC;
        cfg->rc_nlflags         = nlh->nlmsg_flags;
        cfg->rc_nlinfo.portid   = NETLINK_CB(skb).portid;
        cfg->rc_nlinfo.nlh      = nlh;
@@ -735,48 +1154,28 @@ static int rtm_to_route_config(struct sk_buff *skb,  struct nlmsghdr *nlh,
                        break;
                case RTA_DST:
                {
-                       u32 label_count;
+                       u8 label_count;
                        if (nla_get_labels(nla, 1, &label_count,
                                           &cfg->rc_label))
                                goto errout;
 
-                       /* The first 16 labels are reserved, and may not be set */
-                       if (cfg->rc_label < 16)
+                       /* Reserved labels may not be set */
+                       if (cfg->rc_label < MPLS_LABEL_FIRST_UNRESERVED)
                                goto errout;
 
                        break;
                }
                case RTA_VIA:
                {
-                       struct rtvia *via = nla_data(nla);
-                       if (nla_len(nla) < offsetof(struct rtvia, rtvia_addr))
-                               goto errout;
-                       cfg->rc_via_alen   = nla_len(nla) -
-                               offsetof(struct rtvia, rtvia_addr);
-                       if (cfg->rc_via_alen > MAX_VIA_ALEN)
-                               goto errout;
-
-                       /* Validate the address family */
-                       switch(via->rtvia_family) {
-                       case AF_PACKET:
-                               cfg->rc_via_table = NEIGH_LINK_TABLE;
-                               break;
-                       case AF_INET:
-                               cfg->rc_via_table = NEIGH_ARP_TABLE;
-                               if (cfg->rc_via_alen != 4)
-                                       goto errout;
-                               break;
-                       case AF_INET6:
-                               cfg->rc_via_table = NEIGH_ND_TABLE;
-                               if (cfg->rc_via_alen != 16)
-                                       goto errout;
-                               break;
-                       default:
-                               /* Unsupported address family */
+                       if (nla_get_via(nla, &cfg->rc_via_alen,
+                                       &cfg->rc_via_table, cfg->rc_via))
                                goto errout;
-                       }
-
-                       memcpy(cfg->rc_via, via->rtvia_addr, cfg->rc_via_alen);
+                       break;
+               }
+               case RTA_MULTIPATH:
+               {
+                       cfg->rc_mp = nla_data(nla);
+                       cfg->rc_mp_len = nla_len(nla);
                        break;
                }
                default:
@@ -837,16 +1236,54 @@ static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event,
        rtm->rtm_type = RTN_UNICAST;
        rtm->rtm_flags = 0;
 
-       if (rt->rt_labels &&
-           nla_put_labels(skb, RTA_NEWDST, rt->rt_labels, rt->rt_label))
-               goto nla_put_failure;
-       if (nla_put_via(skb, rt->rt_via_table, rt->rt_via, rt->rt_via_alen))
-               goto nla_put_failure;
-       dev = rtnl_dereference(rt->rt_dev);
-       if (dev && nla_put_u32(skb, RTA_OIF, dev->ifindex))
-               goto nla_put_failure;
        if (nla_put_labels(skb, RTA_DST, 1, &label))
                goto nla_put_failure;
+       if (rt->rt_nhn == 1) {
+               const struct mpls_nh *nh = rt->rt_nh;
+
+               if (nh->nh_labels &&
+                   nla_put_labels(skb, RTA_NEWDST, nh->nh_labels,
+                                  nh->nh_label))
+                       goto nla_put_failure;
+               if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC &&
+                   nla_put_via(skb, nh->nh_via_table, mpls_nh_via(rt, nh),
+                               nh->nh_via_alen))
+                       goto nla_put_failure;
+               dev = rtnl_dereference(nh->nh_dev);
+               if (dev && nla_put_u32(skb, RTA_OIF, dev->ifindex))
+                       goto nla_put_failure;
+       } else {
+               struct rtnexthop *rtnh;
+               struct nlattr *mp;
+
+               mp = nla_nest_start(skb, RTA_MULTIPATH);
+               if (!mp)
+                       goto nla_put_failure;
+
+               for_nexthops(rt) {
+                       rtnh = nla_reserve_nohdr(skb, sizeof(*rtnh));
+                       if (!rtnh)
+                               goto nla_put_failure;
+
+                       dev = rtnl_dereference(nh->nh_dev);
+                       if (dev)
+                               rtnh->rtnh_ifindex = dev->ifindex;
+                       if (nh->nh_labels && nla_put_labels(skb, RTA_NEWDST,
+                                                           nh->nh_labels,
+                                                           nh->nh_label))
+                               goto nla_put_failure;
+                       if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC &&
+                           nla_put_via(skb, nh->nh_via_table,
+                                       mpls_nh_via(rt, nh),
+                                       nh->nh_via_alen))
+                               goto nla_put_failure;
+
+                       /* length of rtnetlink header + attributes */
+                       rtnh->rtnh_len = nlmsg_get_pos(skb) - (void *)rtnh;
+               } endfor_nexthops(rt);
+
+               nla_nest_end(skb, mp);
+       }
 
        nlmsg_end(skb, nlh);
        return 0;
@@ -866,8 +1303,8 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
        ASSERT_RTNL();
 
        index = cb->args[0];
-       if (index < 16)
-               index = 16;
+       if (index < MPLS_LABEL_FIRST_UNRESERVED)
+               index = MPLS_LABEL_FIRST_UNRESERVED;
 
        platform_label = rtnl_dereference(net->mpls.platform_label);
        platform_labels = net->mpls.platform_labels;
@@ -891,12 +1328,33 @@ static inline size_t lfib_nlmsg_size(struct mpls_route *rt)
 {
        size_t payload =
                NLMSG_ALIGN(sizeof(struct rtmsg))
-               + nla_total_size(2 + rt->rt_via_alen)   /* RTA_VIA */
                + nla_total_size(4);                    /* RTA_DST */
-       if (rt->rt_labels)                              /* RTA_NEWDST */
-               payload += nla_total_size(rt->rt_labels * 4);
-       if (rt->rt_dev)                                 /* RTA_OIF */
-               payload += nla_total_size(4);
+
+       if (rt->rt_nhn == 1) {
+               struct mpls_nh *nh = rt->rt_nh;
+
+               if (nh->nh_dev)
+                       payload += nla_total_size(4); /* RTA_OIF */
+               if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC) /* RTA_VIA */
+                       payload += nla_total_size(2 + nh->nh_via_alen);
+               if (nh->nh_labels) /* RTA_NEWDST */
+                       payload += nla_total_size(nh->nh_labels * 4);
+       } else {
+               /* each nexthop is packed in an attribute */
+               size_t nhsize = 0;
+
+               for_nexthops(rt) {
+                       nhsize += nla_total_size(sizeof(struct rtnexthop));
+                       /* RTA_VIA */
+                       if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC)
+                               nhsize += nla_total_size(2 + nh->nh_via_alen);
+                       if (nh->nh_labels) /* RTA_NEWDST */
+                               nhsize += nla_total_size(nh->nh_labels * 4);
+               } endfor_nexthops(rt);
+               /* nested RTA_MULTIPATH attribute wrapping all nexthops */
+               payload += nla_total_size(nhsize);
+       }
+
        return payload;
 }
 
@@ -948,23 +1406,29 @@ static int resize_platform_label_table(struct net *net, size_t limit)
        /* In case the predefined labels need to be populated */
        if (limit > MPLS_LABEL_IPV4NULL) {
                struct net_device *lo = net->loopback_dev;
-               rt0 = mpls_rt_alloc(lo->addr_len);
+               rt0 = mpls_rt_alloc(1, lo->addr_len);
                if (!rt0)
                        goto nort0;
-               RCU_INIT_POINTER(rt0->rt_dev, lo);
+               RCU_INIT_POINTER(rt0->rt_nh->nh_dev, lo);
                rt0->rt_protocol = RTPROT_KERNEL;
-               rt0->rt_via_table = NEIGH_LINK_TABLE;
-               memcpy(rt0->rt_via, lo->dev_addr, lo->addr_len);
+               rt0->rt_payload_type = MPT_IPV4;
+               rt0->rt_nh->nh_via_table = NEIGH_LINK_TABLE;
+               rt0->rt_nh->nh_via_alen = lo->addr_len;
+               memcpy(__mpls_nh_via(rt0, rt0->rt_nh), lo->dev_addr,
+                      lo->addr_len);
        }
        if (limit > MPLS_LABEL_IPV6NULL) {
                struct net_device *lo = net->loopback_dev;
-               rt2 = mpls_rt_alloc(lo->addr_len);
+               rt2 = mpls_rt_alloc(1, lo->addr_len);
                if (!rt2)
                        goto nort2;
-               RCU_INIT_POINTER(rt2->rt_dev, lo);
+               RCU_INIT_POINTER(rt2->rt_nh->nh_dev, lo);
                rt2->rt_protocol = RTPROT_KERNEL;
-               rt2->rt_via_table = NEIGH_LINK_TABLE;
-               memcpy(rt2->rt_via, lo->dev_addr, lo->addr_len);
+               rt2->rt_payload_type = MPT_IPV6;
+               rt2->rt_nh->nh_via_table = NEIGH_LINK_TABLE;
+               rt2->rt_nh->nh_via_alen = lo->addr_len;
+               memcpy(__mpls_nh_via(rt2, rt2->rt_nh), lo->dev_addr,
+                      lo->addr_len);
        }
 
        rtnl_lock();
@@ -974,7 +1438,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
 
        /* Free any labels beyond the new table */
        for (index = limit; index < old_limit; index++)
-               mpls_route_update(net, index, NULL, NULL, NULL);
+               mpls_route_update(net, index, NULL, NULL);
 
        /* Copy over the old labels */
        cp_size = size;
@@ -1066,8 +1530,10 @@ static int mpls_net_init(struct net *net)
 
        table[0].data = net;
        net->mpls.ctl = register_net_sysctl(net, "net/mpls", table);
-       if (net->mpls.ctl == NULL)
+       if (net->mpls.ctl == NULL) {
+               kfree(table);
                return -ENOMEM;
+       }
 
        return 0;
 }