These changes are the raw update of the kernel sources to linux-4.4.6-rt14.
[kvmfornfv.git] kernel/drivers/infiniband/core/cma.c
index 38ffe09..17a15c5 100644
 #include <linux/module.h>
 #include <net/route.h>
 
+#include <net/net_namespace.h>
+#include <net/netns/generic.h>
 #include <net/tcp.h>
 #include <net/ipv6.h>
+#include <net/ip_fib.h>
+#include <net/ip6_route.h>
 
 #include <rdma/rdma_cm.h>
 #include <rdma/rdma_cm_ib.h>
@@ -65,8 +69,36 @@ MODULE_LICENSE("Dual BSD/GPL");
 #define CMA_CM_MRA_SETTING (IB_CM_MRA_FLAG_DELAY | 24)
 #define CMA_IBOE_PACKET_LIFETIME 18
 
+static const char * const cma_events[] = {
+       [RDMA_CM_EVENT_ADDR_RESOLVED]    = "address resolved",
+       [RDMA_CM_EVENT_ADDR_ERROR]       = "address error",
+       [RDMA_CM_EVENT_ROUTE_RESOLVED]   = "route resolved ",
+       [RDMA_CM_EVENT_ROUTE_ERROR]      = "route error",
+       [RDMA_CM_EVENT_CONNECT_REQUEST]  = "connect request",
+       [RDMA_CM_EVENT_CONNECT_RESPONSE] = "connect response",
+       [RDMA_CM_EVENT_CONNECT_ERROR]    = "connect error",
+       [RDMA_CM_EVENT_UNREACHABLE]      = "unreachable",
+       [RDMA_CM_EVENT_REJECTED]         = "rejected",
+       [RDMA_CM_EVENT_ESTABLISHED]      = "established",
+       [RDMA_CM_EVENT_DISCONNECTED]     = "disconnected",
+       [RDMA_CM_EVENT_DEVICE_REMOVAL]   = "device removal",
+       [RDMA_CM_EVENT_MULTICAST_JOIN]   = "multicast join",
+       [RDMA_CM_EVENT_MULTICAST_ERROR]  = "multicast error",
+       [RDMA_CM_EVENT_ADDR_CHANGE]      = "address change",
+       [RDMA_CM_EVENT_TIMEWAIT_EXIT]    = "timewait exit",
+};
+
+const char *__attribute_const__ rdma_event_msg(enum rdma_cm_event_type event)
+{
+       size_t index = event;
+
+       return (index < ARRAY_SIZE(cma_events) && cma_events[index]) ?
+                       cma_events[index] : "unrecognized event";
+}
+EXPORT_SYMBOL(rdma_event_msg);
+
 static void cma_add_one(struct ib_device *device);
-static void cma_remove_one(struct ib_device *device);
+static void cma_remove_one(struct ib_device *device, void *client_data);
 
 static struct ib_client cma_client = {
        .name   = "cma",
@@ -80,10 +112,37 @@ static LIST_HEAD(dev_list);
 static LIST_HEAD(listen_any_list);
 static DEFINE_MUTEX(lock);
 static struct workqueue_struct *cma_wq;
-static DEFINE_IDR(tcp_ps);
-static DEFINE_IDR(udp_ps);
-static DEFINE_IDR(ipoib_ps);
-static DEFINE_IDR(ib_ps);
+static int cma_pernet_id;
+
+struct cma_pernet {
+       struct idr tcp_ps;
+       struct idr udp_ps;
+       struct idr ipoib_ps;
+       struct idr ib_ps;
+};
+
+static struct cma_pernet *cma_pernet(struct net *net)
+{
+       return net_generic(net, cma_pernet_id);
+}
+
+static struct idr *cma_pernet_idr(struct net *net, enum rdma_port_space ps)
+{
+       struct cma_pernet *pernet = cma_pernet(net);
+
+       switch (ps) {
+       case RDMA_PS_TCP:
+               return &pernet->tcp_ps;
+       case RDMA_PS_UDP:
+               return &pernet->udp_ps;
+       case RDMA_PS_IPOIB:
+               return &pernet->ipoib_ps;
+       case RDMA_PS_IB:
+               return &pernet->ib_ps;
+       default:
+               return NULL;
+       }
+}
 
 struct cma_device {
        struct list_head        list;
@@ -94,11 +153,34 @@ struct cma_device {
 };
 
 struct rdma_bind_list {
-       struct idr              *ps;
+       enum rdma_port_space    ps;
        struct hlist_head       owners;
        unsigned short          port;
 };
 
+static int cma_ps_alloc(struct net *net, enum rdma_port_space ps,
+                       struct rdma_bind_list *bind_list, int snum)
+{
+       struct idr *idr = cma_pernet_idr(net, ps);
+
+       return idr_alloc(idr, bind_list, snum, snum + 1, GFP_KERNEL);
+}
+
+static struct rdma_bind_list *cma_ps_find(struct net *net,
+                                         enum rdma_port_space ps, int snum)
+{
+       struct idr *idr = cma_pernet_idr(net, ps);
+
+       return idr_find(idr, snum);
+}
+
+static void cma_ps_remove(struct net *net, enum rdma_port_space ps, int snum)
+{
+       struct idr *idr = cma_pernet_idr(net, ps);
+
+       idr_remove(idr, snum);
+}
+
 enum {
        CMA_OPTION_AFONLY,
 };
@@ -197,6 +279,15 @@ struct cma_hdr {
 
 #define CMA_VERSION 0x00
 
+struct cma_req_info {
+       struct ib_device *device;
+       int port;
+       union ib_gid local_gid;
+       __be64 service_id;
+       u16 pkey;
+       bool has_gid:1;
+};
+
 static int cma_comp(struct rdma_id_private *id_priv, enum rdma_cm_state comp)
 {
        unsigned long flags;
@@ -234,7 +325,7 @@ static enum rdma_cm_state cma_exch(struct rdma_id_private *id_priv,
        return old;
 }
 
-static inline u8 cma_get_ip_ver(struct cma_hdr *hdr)
+static inline u8 cma_get_ip_ver(const struct cma_hdr *hdr)
 {
        return hdr->ip_version >> 4;
 }
@@ -349,18 +440,40 @@ static int cma_translate_addr(struct sockaddr *addr, struct rdma_dev_addr *dev_a
        return ret;
 }
 
+static inline int cma_validate_port(struct ib_device *device, u8 port,
+                                     union ib_gid *gid, int dev_type,
+                                     int bound_if_index)
+{
+       int ret = -ENODEV;
+       struct net_device *ndev = NULL;
+
+       if ((dev_type == ARPHRD_INFINIBAND) && !rdma_protocol_ib(device, port))
+               return ret;
+
+       if ((dev_type != ARPHRD_INFINIBAND) && rdma_protocol_ib(device, port))
+               return ret;
+
+       if (dev_type == ARPHRD_ETHER && rdma_protocol_roce(device, port))
+               ndev = dev_get_by_index(&init_net, bound_if_index);
+
+       ret = ib_find_cached_gid_by_port(device, gid, port, ndev, NULL);
+
+       if (ndev)
+               dev_put(ndev);
+
+       return ret;
+}
+
 static int cma_acquire_dev(struct rdma_id_private *id_priv,
                           struct rdma_id_private *listen_id_priv)
 {
        struct rdma_dev_addr *dev_addr = &id_priv->id.route.addr.dev_addr;
        struct cma_device *cma_dev;
-       union ib_gid gid, iboe_gid;
+       union ib_gid gid, iboe_gid, *gidp;
        int ret = -ENODEV;
-       u8 port, found_port;
-       enum rdma_link_layer dev_ll = dev_addr->dev_type == ARPHRD_INFINIBAND ?
-               IB_LINK_LAYER_INFINIBAND : IB_LINK_LAYER_ETHERNET;
+       u8 port;
 
-       if (dev_ll != IB_LINK_LAYER_INFINIBAND &&
+       if (dev_addr->dev_type != ARPHRD_INFINIBAND &&
            id_priv->id.ps == RDMA_PS_IPOIB)
                return -EINVAL;
 
@@ -370,41 +483,38 @@ static int cma_acquire_dev(struct rdma_id_private *id_priv,
 
        memcpy(&gid, dev_addr->src_dev_addr +
               rdma_addr_gid_offset(dev_addr), sizeof gid);
-       if (listen_id_priv &&
-           rdma_port_get_link_layer(listen_id_priv->id.device,
-                                    listen_id_priv->id.port_num) == dev_ll) {
+
+       if (listen_id_priv) {
                cma_dev = listen_id_priv->cma_dev;
                port = listen_id_priv->id.port_num;
-               if (rdma_node_get_transport(cma_dev->device->node_type) == RDMA_TRANSPORT_IB &&
-                   rdma_port_get_link_layer(cma_dev->device, port) == IB_LINK_LAYER_ETHERNET)
-                       ret = ib_find_cached_gid(cma_dev->device, &iboe_gid,
-                                                &found_port, NULL);
-               else
-                       ret = ib_find_cached_gid(cma_dev->device, &gid,
-                                                &found_port, NULL);
-
-               if (!ret && (port  == found_port)) {
-                       id_priv->id.port_num = found_port;
+               gidp = rdma_protocol_roce(cma_dev->device, port) ?
+                      &iboe_gid : &gid;
+
+               ret = cma_validate_port(cma_dev->device, port, gidp,
+                                       dev_addr->dev_type,
+                                       dev_addr->bound_dev_if);
+               if (!ret) {
+                       id_priv->id.port_num = port;
                        goto out;
                }
        }
+
        list_for_each_entry(cma_dev, &dev_list, list) {
                for (port = 1; port <= cma_dev->device->phys_port_cnt; ++port) {
                        if (listen_id_priv &&
                            listen_id_priv->cma_dev == cma_dev &&
                            listen_id_priv->id.port_num == port)
                                continue;
-                       if (rdma_port_get_link_layer(cma_dev->device, port) == dev_ll) {
-                               if (rdma_node_get_transport(cma_dev->device->node_type) == RDMA_TRANSPORT_IB &&
-                                   rdma_port_get_link_layer(cma_dev->device, port) == IB_LINK_LAYER_ETHERNET)
-                                       ret = ib_find_cached_gid(cma_dev->device, &iboe_gid, &found_port, NULL);
-                               else
-                                       ret = ib_find_cached_gid(cma_dev->device, &gid, &found_port, NULL);
-
-                               if (!ret && (port == found_port)) {
-                                       id_priv->id.port_num = found_port;
-                                       goto out;
-                               }
+
+                       gidp = rdma_protocol_roce(cma_dev->device, port) ?
+                              &iboe_gid : &gid;
+
+                       ret = cma_validate_port(cma_dev->device, port, gidp,
+                                               dev_addr->dev_type,
+                                               dev_addr->bound_dev_if);
+                       if (!ret) {
+                               id_priv->id.port_num = port;
+                               goto out;
                        }
                }
        }
@@ -435,14 +545,16 @@ static int cma_resolve_ib_dev(struct rdma_id_private *id_priv)
        pkey = ntohs(addr->sib_pkey);
 
        list_for_each_entry(cur_dev, &dev_list, list) {
-               if (rdma_node_get_transport(cur_dev->device->node_type) != RDMA_TRANSPORT_IB)
-                       continue;
-
                for (p = 1; p <= cur_dev->device->phys_port_cnt; ++p) {
+                       if (!rdma_cap_af_ib(cur_dev->device, p))
+                               continue;
+
                        if (ib_find_cached_pkey(cur_dev->device, p, pkey, &index))
                                continue;
 
-                       for (i = 0; !ib_get_cached_gid(cur_dev->device, p, i, &gid); i++) {
+                       for (i = 0; !ib_get_cached_gid(cur_dev->device, p, i,
+                                                      &gid, NULL);
+                            i++) {
                                if (!memcmp(&gid, dgid, sizeof(gid))) {
                                        cma_dev = cur_dev;
                                        sgid = gid;
@@ -488,7 +600,8 @@ static int cma_disable_callback(struct rdma_id_private *id_priv,
        return 0;
 }
 
-struct rdma_cm_id *rdma_create_id(rdma_cm_event_handler event_handler,
+struct rdma_cm_id *rdma_create_id(struct net *net,
+                                 rdma_cm_event_handler event_handler,
                                  void *context, enum rdma_port_space ps,
                                  enum ib_qp_type qp_type)
 {
@@ -512,6 +625,7 @@ struct rdma_cm_id *rdma_create_id(rdma_cm_event_handler event_handler,
        INIT_LIST_HEAD(&id_priv->listen_list);
        INIT_LIST_HEAD(&id_priv->mc_list);
        get_random_bytes(&id_priv->seq_num, sizeof id_priv->seq_num);
+       id_priv->id.route.addr.dev_addr.net = get_net(net);
 
        return &id_priv->id;
 }
@@ -629,19 +743,12 @@ static int cma_modify_qp_rtr(struct rdma_id_private *id_priv,
                goto out;
 
        ret = ib_query_gid(id_priv->id.device, id_priv->id.port_num,
-                          qp_attr.ah_attr.grh.sgid_index, &sgid);
+                          qp_attr.ah_attr.grh.sgid_index, &sgid, NULL);
        if (ret)
                goto out;
 
-       if (rdma_node_get_transport(id_priv->cma_dev->device->node_type)
-           == RDMA_TRANSPORT_IB &&
-           rdma_port_get_link_layer(id_priv->id.device, id_priv->id.port_num)
-           == IB_LINK_LAYER_ETHERNET) {
-               ret = rdma_addr_find_smac_by_sgid(&sgid, qp_attr.smac, NULL);
+       BUG_ON(id_priv->cma_dev->device != id_priv->id.device);
 
-               if (ret)
-                       goto out;
-       }
        if (conn_param)
                qp_attr.max_dest_rd_atomic = conn_param->responder_resources;
        ret = ib_modify_qp(id_priv->id.qp, &qp_attr, qp_attr_mask);
@@ -700,11 +807,10 @@ static int cma_ib_init_qp_attr(struct rdma_id_private *id_priv,
        int ret;
        u16 pkey;
 
-       if (rdma_port_get_link_layer(id_priv->id.device, id_priv->id.port_num) ==
-           IB_LINK_LAYER_INFINIBAND)
-               pkey = ib_addr_get_pkey(dev_addr);
-       else
+       if (rdma_cap_eth_ah(id_priv->id.device, id_priv->id.port_num))
                pkey = 0xffff;
+       else
+               pkey = ib_addr_get_pkey(dev_addr);
 
        ret = ib_find_cached_pkey(id_priv->id.device, id_priv->id.port_num,
                                  pkey, &qp_attr->pkey_index);
@@ -735,8 +841,7 @@ int rdma_init_qp_attr(struct rdma_cm_id *id, struct ib_qp_attr *qp_attr,
        int ret = 0;
 
        id_priv = container_of(id, struct rdma_id_private, id);
-       switch (rdma_node_get_transport(id_priv->id.device->node_type)) {
-       case RDMA_TRANSPORT_IB:
+       if (rdma_cap_ib_cm(id->device, id->port_num)) {
                if (!id_priv->cm_id.ib || (id_priv->id.qp_type == IB_QPT_UD))
                        ret = cma_ib_init_qp_attr(id_priv, qp_attr, qp_attr_mask);
                else
@@ -745,19 +850,15 @@ int rdma_init_qp_attr(struct rdma_cm_id *id, struct ib_qp_attr *qp_attr,
 
                if (qp_attr->qp_state == IB_QPS_RTR)
                        qp_attr->rq_psn = id_priv->seq_num;
-               break;
-       case RDMA_TRANSPORT_IWARP:
+       } else if (rdma_cap_iw_cm(id->device, id->port_num)) {
                if (!id_priv->cm_id.iw) {
                        qp_attr->qp_access_flags = 0;
                        *qp_attr_mask = IB_QP_STATE | IB_QP_ACCESS_FLAGS;
                } else
                        ret = iw_cm_init_qp_attr(id_priv->cm_id.iw, qp_attr,
                                                 qp_attr_mask);
-               break;
-       default:
+       } else
                ret = -ENOSYS;
-               break;
-       }
 
        return ret;
 }
@@ -837,107 +938,419 @@ static inline int cma_any_port(struct sockaddr *addr)
        return !cma_port(addr);
 }
 
-static void cma_save_ib_info(struct rdma_cm_id *id, struct rdma_cm_id *listen_id,
+static void cma_save_ib_info(struct sockaddr *src_addr,
+                            struct sockaddr *dst_addr,
+                            struct rdma_cm_id *listen_id,
                             struct ib_sa_path_rec *path)
 {
        struct sockaddr_ib *listen_ib, *ib;
 
        listen_ib = (struct sockaddr_ib *) &listen_id->route.addr.src_addr;
-       ib = (struct sockaddr_ib *) &id->route.addr.src_addr;
-       ib->sib_family = listen_ib->sib_family;
-       if (path) {
-               ib->sib_pkey = path->pkey;
-               ib->sib_flowinfo = path->flow_label;
-               memcpy(&ib->sib_addr, &path->sgid, 16);
-       } else {
-               ib->sib_pkey = listen_ib->sib_pkey;
-               ib->sib_flowinfo = listen_ib->sib_flowinfo;
-               ib->sib_addr = listen_ib->sib_addr;
+       if (src_addr) {
+               ib = (struct sockaddr_ib *)src_addr;
+               ib->sib_family = AF_IB;
+               if (path) {
+                       ib->sib_pkey = path->pkey;
+                       ib->sib_flowinfo = path->flow_label;
+                       memcpy(&ib->sib_addr, &path->sgid, 16);
+                       ib->sib_sid = path->service_id;
+                       ib->sib_scope_id = 0;
+               } else {
+                       ib->sib_pkey = listen_ib->sib_pkey;
+                       ib->sib_flowinfo = listen_ib->sib_flowinfo;
+                       ib->sib_addr = listen_ib->sib_addr;
+                       ib->sib_sid = listen_ib->sib_sid;
+                       ib->sib_scope_id = listen_ib->sib_scope_id;
+               }
+               ib->sib_sid_mask = cpu_to_be64(0xffffffffffffffffULL);
        }
-       ib->sib_sid = listen_ib->sib_sid;
-       ib->sib_sid_mask = cpu_to_be64(0xffffffffffffffffULL);
-       ib->sib_scope_id = listen_ib->sib_scope_id;
-
-       if (path) {
-               ib = (struct sockaddr_ib *) &id->route.addr.dst_addr;
-               ib->sib_family = listen_ib->sib_family;
-               ib->sib_pkey = path->pkey;
-               ib->sib_flowinfo = path->flow_label;
-               memcpy(&ib->sib_addr, &path->dgid, 16);
+       if (dst_addr) {
+               ib = (struct sockaddr_ib *)dst_addr;
+               ib->sib_family = AF_IB;
+               if (path) {
+                       ib->sib_pkey = path->pkey;
+                       ib->sib_flowinfo = path->flow_label;
+                       memcpy(&ib->sib_addr, &path->dgid, 16);
+               }
        }
 }
 
-static __be16 ss_get_port(const struct sockaddr_storage *ss)
-{
-       if (ss->ss_family == AF_INET)
-               return ((struct sockaddr_in *)ss)->sin_port;
-       else if (ss->ss_family == AF_INET6)
-               return ((struct sockaddr_in6 *)ss)->sin6_port;
-       BUG();
-}
-
-static void cma_save_ip4_info(struct rdma_cm_id *id, struct rdma_cm_id *listen_id,
-                             struct cma_hdr *hdr)
+static void cma_save_ip4_info(struct sockaddr *src_addr,
+                             struct sockaddr *dst_addr,
+                             struct cma_hdr *hdr,
+                             __be16 local_port)
 {
        struct sockaddr_in *ip4;
 
-       ip4 = (struct sockaddr_in *) &id->route.addr.src_addr;
-       ip4->sin_family = AF_INET;
-       ip4->sin_addr.s_addr = hdr->dst_addr.ip4.addr;
-       ip4->sin_port = ss_get_port(&listen_id->route.addr.src_addr);
+       if (src_addr) {
+               ip4 = (struct sockaddr_in *)src_addr;
+               ip4->sin_family = AF_INET;
+               ip4->sin_addr.s_addr = hdr->dst_addr.ip4.addr;
+               ip4->sin_port = local_port;
+       }
 
-       ip4 = (struct sockaddr_in *) &id->route.addr.dst_addr;
-       ip4->sin_family = AF_INET;
-       ip4->sin_addr.s_addr = hdr->src_addr.ip4.addr;
-       ip4->sin_port = hdr->port;
+       if (dst_addr) {
+               ip4 = (struct sockaddr_in *)dst_addr;
+               ip4->sin_family = AF_INET;
+               ip4->sin_addr.s_addr = hdr->src_addr.ip4.addr;
+               ip4->sin_port = hdr->port;
+       }
 }
 
-static void cma_save_ip6_info(struct rdma_cm_id *id, struct rdma_cm_id *listen_id,
-                             struct cma_hdr *hdr)
+static void cma_save_ip6_info(struct sockaddr *src_addr,
+                             struct sockaddr *dst_addr,
+                             struct cma_hdr *hdr,
+                             __be16 local_port)
 {
        struct sockaddr_in6 *ip6;
 
-       ip6 = (struct sockaddr_in6 *) &id->route.addr.src_addr;
-       ip6->sin6_family = AF_INET6;
-       ip6->sin6_addr = hdr->dst_addr.ip6;
-       ip6->sin6_port = ss_get_port(&listen_id->route.addr.src_addr);
+       if (src_addr) {
+               ip6 = (struct sockaddr_in6 *)src_addr;
+               ip6->sin6_family = AF_INET6;
+               ip6->sin6_addr = hdr->dst_addr.ip6;
+               ip6->sin6_port = local_port;
+       }
 
-       ip6 = (struct sockaddr_in6 *) &id->route.addr.dst_addr;
-       ip6->sin6_family = AF_INET6;
-       ip6->sin6_addr = hdr->src_addr.ip6;
-       ip6->sin6_port = hdr->port;
+       if (dst_addr) {
+               ip6 = (struct sockaddr_in6 *)dst_addr;
+               ip6->sin6_family = AF_INET6;
+               ip6->sin6_addr = hdr->src_addr.ip6;
+               ip6->sin6_port = hdr->port;
+       }
 }
 
-static int cma_save_net_info(struct rdma_cm_id *id, struct rdma_cm_id *listen_id,
-                            struct ib_cm_event *ib_event)
+static u16 cma_port_from_service_id(__be64 service_id)
 {
-       struct cma_hdr *hdr;
+       return (u16)be64_to_cpu(service_id);
+}
 
-       if (listen_id->route.addr.src_addr.ss_family == AF_IB) {
-               if (ib_event->event == IB_CM_REQ_RECEIVED)
-                       cma_save_ib_info(id, listen_id, ib_event->param.req_rcvd.primary_path);
-               else if (ib_event->event == IB_CM_SIDR_REQ_RECEIVED)
-                       cma_save_ib_info(id, listen_id, NULL);
-               return 0;
-       }
+static int cma_save_ip_info(struct sockaddr *src_addr,
+                           struct sockaddr *dst_addr,
+                           struct ib_cm_event *ib_event,
+                           __be64 service_id)
+{
+       struct cma_hdr *hdr;
+       __be16 port;
 
        hdr = ib_event->private_data;
        if (hdr->cma_version != CMA_VERSION)
                return -EINVAL;
 
+       port = htons(cma_port_from_service_id(service_id));
+
        switch (cma_get_ip_ver(hdr)) {
        case 4:
-               cma_save_ip4_info(id, listen_id, hdr);
+               cma_save_ip4_info(src_addr, dst_addr, hdr, port);
                break;
        case 6:
-               cma_save_ip6_info(id, listen_id, hdr);
+               cma_save_ip6_info(src_addr, dst_addr, hdr, port);
+               break;
+       default:
+               return -EAFNOSUPPORT;
+       }
+
+       return 0;
+}
+
+static int cma_save_net_info(struct sockaddr *src_addr,
+                            struct sockaddr *dst_addr,
+                            struct rdma_cm_id *listen_id,
+                            struct ib_cm_event *ib_event,
+                            sa_family_t sa_family, __be64 service_id)
+{
+       if (sa_family == AF_IB) {
+               if (ib_event->event == IB_CM_REQ_RECEIVED)
+                       cma_save_ib_info(src_addr, dst_addr, listen_id,
+                                        ib_event->param.req_rcvd.primary_path);
+               else if (ib_event->event == IB_CM_SIDR_REQ_RECEIVED)
+                       cma_save_ib_info(src_addr, dst_addr, listen_id, NULL);
+               return 0;
+       }
+
+       return cma_save_ip_info(src_addr, dst_addr, ib_event, service_id);
+}
+
+static int cma_save_req_info(const struct ib_cm_event *ib_event,
+                            struct cma_req_info *req)
+{
+       const struct ib_cm_req_event_param *req_param =
+               &ib_event->param.req_rcvd;
+       const struct ib_cm_sidr_req_event_param *sidr_param =
+               &ib_event->param.sidr_req_rcvd;
+
+       switch (ib_event->event) {
+       case IB_CM_REQ_RECEIVED:
+               req->device     = req_param->listen_id->device;
+               req->port       = req_param->port;
+               memcpy(&req->local_gid, &req_param->primary_path->sgid,
+                      sizeof(req->local_gid));
+               req->has_gid    = true;
+               req->service_id = req_param->primary_path->service_id;
+               req->pkey       = be16_to_cpu(req_param->primary_path->pkey);
+               break;
+       case IB_CM_SIDR_REQ_RECEIVED:
+               req->device     = sidr_param->listen_id->device;
+               req->port       = sidr_param->port;
+               req->has_gid    = false;
+               req->service_id = sidr_param->service_id;
+               req->pkey       = sidr_param->pkey;
                break;
        default:
                return -EINVAL;
        }
+
        return 0;
 }
 
+static bool validate_ipv4_net_dev(struct net_device *net_dev,
+                                 const struct sockaddr_in *dst_addr,
+                                 const struct sockaddr_in *src_addr)
+{
+       __be32 daddr = dst_addr->sin_addr.s_addr,
+              saddr = src_addr->sin_addr.s_addr;
+       struct fib_result res;
+       struct flowi4 fl4;
+       int err;
+       bool ret;
+
+       if (ipv4_is_multicast(saddr) || ipv4_is_lbcast(saddr) ||
+           ipv4_is_lbcast(daddr) || ipv4_is_zeronet(saddr) ||
+           ipv4_is_zeronet(daddr) || ipv4_is_loopback(daddr) ||
+           ipv4_is_loopback(saddr))
+               return false;
+
+       memset(&fl4, 0, sizeof(fl4));
+       fl4.flowi4_iif = net_dev->ifindex;
+       fl4.daddr = daddr;
+       fl4.saddr = saddr;
+
+       rcu_read_lock();
+       err = fib_lookup(dev_net(net_dev), &fl4, &res, 0);
+       ret = err == 0 && FIB_RES_DEV(res) == net_dev;
+       rcu_read_unlock();
+
+       return ret;
+}
+
+static bool validate_ipv6_net_dev(struct net_device *net_dev,
+                                 const struct sockaddr_in6 *dst_addr,
+                                 const struct sockaddr_in6 *src_addr)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+       const int strict = ipv6_addr_type(&dst_addr->sin6_addr) &
+                          IPV6_ADDR_LINKLOCAL;
+       struct rt6_info *rt = rt6_lookup(dev_net(net_dev), &dst_addr->sin6_addr,
+                                        &src_addr->sin6_addr, net_dev->ifindex,
+                                        strict);
+       bool ret;
+
+       if (!rt)
+               return false;
+
+       ret = rt->rt6i_idev->dev == net_dev;
+       ip6_rt_put(rt);
+
+       return ret;
+#else
+       return false;
+#endif
+}
+
+static bool validate_net_dev(struct net_device *net_dev,
+                            const struct sockaddr *daddr,
+                            const struct sockaddr *saddr)
+{
+       const struct sockaddr_in *daddr4 = (const struct sockaddr_in *)daddr;
+       const struct sockaddr_in *saddr4 = (const struct sockaddr_in *)saddr;
+       const struct sockaddr_in6 *daddr6 = (const struct sockaddr_in6 *)daddr;
+       const struct sockaddr_in6 *saddr6 = (const struct sockaddr_in6 *)saddr;
+
+       switch (daddr->sa_family) {
+       case AF_INET:
+               return saddr->sa_family == AF_INET &&
+                      validate_ipv4_net_dev(net_dev, daddr4, saddr4);
+
+       case AF_INET6:
+               return saddr->sa_family == AF_INET6 &&
+                      validate_ipv6_net_dev(net_dev, daddr6, saddr6);
+
+       default:
+               return false;
+       }
+}
+
+static struct net_device *cma_get_net_dev(struct ib_cm_event *ib_event,
+                                         const struct cma_req_info *req)
+{
+       struct sockaddr_storage listen_addr_storage, src_addr_storage;
+       struct sockaddr *listen_addr = (struct sockaddr *)&listen_addr_storage,
+                       *src_addr = (struct sockaddr *)&src_addr_storage;
+       struct net_device *net_dev;
+       const union ib_gid *gid = req->has_gid ? &req->local_gid : NULL;
+       int err;
+
+       err = cma_save_ip_info(listen_addr, src_addr, ib_event,
+                              req->service_id);
+       if (err)
+               return ERR_PTR(err);
+
+       net_dev = ib_get_net_dev_by_params(req->device, req->port, req->pkey,
+                                          gid, listen_addr);
+       if (!net_dev)
+               return ERR_PTR(-ENODEV);
+
+       if (!validate_net_dev(net_dev, listen_addr, src_addr)) {
+               dev_put(net_dev);
+               return ERR_PTR(-EHOSTUNREACH);
+       }
+
+       return net_dev;
+}
+
+static enum rdma_port_space rdma_ps_from_service_id(__be64 service_id)
+{
+       return (be64_to_cpu(service_id) >> 16) & 0xffff;
+}
+
+static bool cma_match_private_data(struct rdma_id_private *id_priv,
+                                  const struct cma_hdr *hdr)
+{
+       struct sockaddr *addr = cma_src_addr(id_priv);
+       __be32 ip4_addr;
+       struct in6_addr ip6_addr;
+
+       if (cma_any_addr(addr) && !id_priv->afonly)
+               return true;
+
+       switch (addr->sa_family) {
+       case AF_INET:
+               ip4_addr = ((struct sockaddr_in *)addr)->sin_addr.s_addr;
+               if (cma_get_ip_ver(hdr) != 4)
+                       return false;
+               if (!cma_any_addr(addr) &&
+                   hdr->dst_addr.ip4.addr != ip4_addr)
+                       return false;
+               break;
+       case AF_INET6:
+               ip6_addr = ((struct sockaddr_in6 *)addr)->sin6_addr;
+               if (cma_get_ip_ver(hdr) != 6)
+                       return false;
+               if (!cma_any_addr(addr) &&
+                   memcmp(&hdr->dst_addr.ip6, &ip6_addr, sizeof(ip6_addr)))
+                       return false;
+               break;
+       case AF_IB:
+               return true;
+       default:
+               return false;
+       }
+
+       return true;
+}
+
+static bool cma_protocol_roce_dev_port(struct ib_device *device, int port_num)
+{
+       enum rdma_link_layer ll = rdma_port_get_link_layer(device, port_num);
+       enum rdma_transport_type transport =
+               rdma_node_get_transport(device->node_type);
+
+       return ll == IB_LINK_LAYER_ETHERNET && transport == RDMA_TRANSPORT_IB;
+}
+
+static bool cma_protocol_roce(const struct rdma_cm_id *id)
+{
+       struct ib_device *device = id->device;
+       const int port_num = id->port_num ?: rdma_start_port(device);
+
+       return cma_protocol_roce_dev_port(device, port_num);
+}
+
+static bool cma_match_net_dev(const struct rdma_cm_id *id,
+                             const struct net_device *net_dev,
+                             u8 port_num)
+{
+       const struct rdma_addr *addr = &id->route.addr;
+
+       if (!net_dev)
+               /* This request is an AF_IB request or a RoCE request */
+               return (!id->port_num || id->port_num == port_num) &&
+                      (addr->src_addr.ss_family == AF_IB ||
+                       cma_protocol_roce_dev_port(id->device, port_num));
+
+       return !addr->dev_addr.bound_dev_if ||
+              (net_eq(dev_net(net_dev), addr->dev_addr.net) &&
+               addr->dev_addr.bound_dev_if == net_dev->ifindex);
+}
+
+static struct rdma_id_private *cma_find_listener(
+               const struct rdma_bind_list *bind_list,
+               const struct ib_cm_id *cm_id,
+               const struct ib_cm_event *ib_event,
+               const struct cma_req_info *req,
+               const struct net_device *net_dev)
+{
+       struct rdma_id_private *id_priv, *id_priv_dev;
+
+       if (!bind_list)
+               return ERR_PTR(-EINVAL);
+
+       hlist_for_each_entry(id_priv, &bind_list->owners, node) {
+               if (cma_match_private_data(id_priv, ib_event->private_data)) {
+                       if (id_priv->id.device == cm_id->device &&
+                           cma_match_net_dev(&id_priv->id, net_dev, req->port))
+                               return id_priv;
+                       list_for_each_entry(id_priv_dev,
+                                           &id_priv->listen_list,
+                                           listen_list) {
+                               if (id_priv_dev->id.device == cm_id->device &&
+                                   cma_match_net_dev(&id_priv_dev->id, net_dev, req->port))
+                                       return id_priv_dev;
+                       }
+               }
+       }
+
+       return ERR_PTR(-EINVAL);
+}
+
+static struct rdma_id_private *cma_id_from_event(struct ib_cm_id *cm_id,
+                                                struct ib_cm_event *ib_event,
+                                                struct net_device **net_dev)
+{
+       struct cma_req_info req;
+       struct rdma_bind_list *bind_list;
+       struct rdma_id_private *id_priv;
+       int err;
+
+       err = cma_save_req_info(ib_event, &req);
+       if (err)
+               return ERR_PTR(err);
+
+       *net_dev = cma_get_net_dev(ib_event, &req);
+       if (IS_ERR(*net_dev)) {
+               if (PTR_ERR(*net_dev) == -EAFNOSUPPORT) {
+                       /* Assuming the protocol is AF_IB */
+                       *net_dev = NULL;
+               } else if (cma_protocol_roce_dev_port(req.device, req.port)) {
+                       /* TODO find the net dev matching the request parameters
+                        * through the RoCE GID table */
+                       *net_dev = NULL;
+               } else {
+                       return ERR_CAST(*net_dev);
+               }
+       }
+
+       bind_list = cma_ps_find(*net_dev ? dev_net(*net_dev) : &init_net,
+                               rdma_ps_from_service_id(req.service_id),
+                               cma_port_from_service_id(req.service_id));
+       id_priv = cma_find_listener(bind_list, cm_id, ib_event, &req, *net_dev);
+       if (IS_ERR(id_priv) && *net_dev) {
+               dev_put(*net_dev);
+               *net_dev = NULL;
+       }
+
+       return id_priv;
+}
+
 static inline int cma_user_data_offset(struct rdma_id_private *id_priv)
 {
        return cma_family(id_priv) == AF_IB ? 0 : sizeof(struct cma_hdr);
@@ -945,13 +1358,9 @@ static inline int cma_user_data_offset(struct rdma_id_private *id_priv)
 
 static void cma_cancel_route(struct rdma_id_private *id_priv)
 {
-       switch (rdma_port_get_link_layer(id_priv->id.device, id_priv->id.port_num)) {
-       case IB_LINK_LAYER_INFINIBAND:
+       if (rdma_cap_ib_sa(id_priv->id.device, id_priv->id.port_num)) {
                if (id_priv->query)
                        ib_sa_cancel_query(id_priv->query_id, id_priv->query);
-               break;
-       default:
-               break;
        }
 }
 
@@ -1002,6 +1411,7 @@ static void cma_cancel_operation(struct rdma_id_private *id_priv,
 static void cma_release_port(struct rdma_id_private *id_priv)
 {
        struct rdma_bind_list *bind_list = id_priv->bind_list;
+       struct net *net = id_priv->id.route.addr.dev_addr.net;
 
        if (!bind_list)
                return;
@@ -1009,7 +1419,7 @@ static void cma_release_port(struct rdma_id_private *id_priv)
        mutex_lock(&lock);
        hlist_del(&id_priv->node);
        if (hlist_empty(&bind_list->owners)) {
-               idr_remove(bind_list->ps, bind_list->port);
+               cma_ps_remove(net, bind_list->ps, bind_list->port);
                kfree(bind_list);
        }
        mutex_unlock(&lock);
@@ -1023,17 +1433,12 @@ static void cma_leave_mc_groups(struct rdma_id_private *id_priv)
                mc = container_of(id_priv->mc_list.next,
                                  struct cma_multicast, list);
                list_del(&mc->list);
-               switch (rdma_port_get_link_layer(id_priv->cma_dev->device, id_priv->id.port_num)) {
-               case IB_LINK_LAYER_INFINIBAND:
+               if (rdma_cap_ib_mcast(id_priv->cma_dev->device,
+                                     id_priv->id.port_num)) {
                        ib_sa_free_multicast(mc->multicast.ib);
                        kfree(mc);
-                       break;
-               case IB_LINK_LAYER_ETHERNET:
+               } else
                        kref_put(&mc->mcref, release_mc);
-                       break;
-               default:
-                       break;
-               }
        }
 }
 
@@ -1054,17 +1459,12 @@ void rdma_destroy_id(struct rdma_cm_id *id)
        mutex_unlock(&id_priv->handler_mutex);
 
        if (id_priv->cma_dev) {
-               switch (rdma_node_get_transport(id_priv->id.device->node_type)) {
-               case RDMA_TRANSPORT_IB:
+               if (rdma_cap_ib_cm(id_priv->id.device, 1)) {
                        if (id_priv->cm_id.ib)
                                ib_destroy_cm_id(id_priv->cm_id.ib);
-                       break;
-               case RDMA_TRANSPORT_IWARP:
+               } else if (rdma_cap_iw_cm(id_priv->id.device, 1)) {
                        if (id_priv->cm_id.iw)
                                iw_destroy_cm_id(id_priv->cm_id.iw);
-                       break;
-               default:
-                       break;
                }
                cma_leave_mc_groups(id_priv);
                cma_release_dev(id_priv);
@@ -1078,6 +1478,7 @@ void rdma_destroy_id(struct rdma_cm_id *id)
                cma_deref_id(id_priv->id.context);
 
        kfree(id_priv->id.route.path_rec);
+       put_net(id_priv->id.route.addr.dev_addr.net);
        kfree(id_priv);
 }
 EXPORT_SYMBOL(rdma_destroy_id);
@@ -1197,20 +1598,27 @@ out:
 }
 
 static struct rdma_id_private *cma_new_conn_id(struct rdma_cm_id *listen_id,
-                                              struct ib_cm_event *ib_event)
+                                              struct ib_cm_event *ib_event,
+                                              struct net_device *net_dev)
 {
        struct rdma_id_private *id_priv;
        struct rdma_cm_id *id;
        struct rdma_route *rt;
+       const sa_family_t ss_family = listen_id->route.addr.src_addr.ss_family;
+       const __be64 service_id =
+                     ib_event->param.req_rcvd.primary_path->service_id;
        int ret;
 
-       id = rdma_create_id(listen_id->event_handler, listen_id->context,
+       id = rdma_create_id(listen_id->route.addr.dev_addr.net,
+                           listen_id->event_handler, listen_id->context,
                            listen_id->ps, ib_event->param.req_rcvd.qp_type);
        if (IS_ERR(id))
                return NULL;
 
        id_priv = container_of(id, struct rdma_id_private, id);
-       if (cma_save_net_info(id, listen_id, ib_event))
+       if (cma_save_net_info((struct sockaddr *)&id->route.addr.src_addr,
+                             (struct sockaddr *)&id->route.addr.dst_addr,
+                             listen_id, ib_event, ss_family, service_id))
                goto err;
 
        rt = &id->route;
@@ -1224,14 +1632,21 @@ static struct rdma_id_private *cma_new_conn_id(struct rdma_cm_id *listen_id,
        if (rt->num_paths == 2)
                rt->path_rec[1] = *ib_event->param.req_rcvd.alternate_path;
 
-       if (cma_any_addr(cma_src_addr(id_priv))) {
-               rt->addr.dev_addr.dev_type = ARPHRD_INFINIBAND;
-               rdma_addr_set_sgid(&rt->addr.dev_addr, &rt->path_rec[0].sgid);
-               ib_addr_set_pkey(&rt->addr.dev_addr, be16_to_cpu(rt->path_rec[0].pkey));
-       } else {
-               ret = cma_translate_addr(cma_src_addr(id_priv), &rt->addr.dev_addr);
+       if (net_dev) {
+               ret = rdma_copy_addr(&rt->addr.dev_addr, net_dev, NULL);
                if (ret)
                        goto err;
+       } else {
+               if (!cma_protocol_roce(listen_id) &&
+                   cma_any_addr(cma_src_addr(id_priv))) {
+                       rt->addr.dev_addr.dev_type = ARPHRD_INFINIBAND;
+                       rdma_addr_set_sgid(&rt->addr.dev_addr, &rt->path_rec[0].sgid);
+                       ib_addr_set_pkey(&rt->addr.dev_addr, be16_to_cpu(rt->path_rec[0].pkey));
+               } else if (!cma_any_addr(cma_src_addr(id_priv))) {
+                       ret = cma_translate_addr(cma_src_addr(id_priv), &rt->addr.dev_addr);
+                       if (ret)
+                               goto err;
+               }
        }
        rdma_addr_set_dgid(&rt->addr.dev_addr, &rt->path_rec[0].dgid);
 
@@ -1244,25 +1659,38 @@ err:
 }
 
 static struct rdma_id_private *cma_new_udp_id(struct rdma_cm_id *listen_id,
-                                             struct ib_cm_event *ib_event)
+                                             struct ib_cm_event *ib_event,
+                                             struct net_device *net_dev)
 {
        struct rdma_id_private *id_priv;
        struct rdma_cm_id *id;
+       const sa_family_t ss_family = listen_id->route.addr.src_addr.ss_family;
+       struct net *net = listen_id->route.addr.dev_addr.net;
        int ret;
 
-       id = rdma_create_id(listen_id->event_handler, listen_id->context,
+       id = rdma_create_id(net, listen_id->event_handler, listen_id->context,
                            listen_id->ps, IB_QPT_UD);
        if (IS_ERR(id))
                return NULL;
 
        id_priv = container_of(id, struct rdma_id_private, id);
-       if (cma_save_net_info(id, listen_id, ib_event))
+       if (cma_save_net_info((struct sockaddr *)&id->route.addr.src_addr,
+                             (struct sockaddr *)&id->route.addr.dst_addr,
+                             listen_id, ib_event, ss_family,
+                             ib_event->param.sidr_req_rcvd.service_id))
                goto err;
 
-       if (!cma_any_addr((struct sockaddr *) &id->route.addr.src_addr)) {
-               ret = cma_translate_addr(cma_src_addr(id_priv), &id->route.addr.dev_addr);
+       if (net_dev) {
+               ret = rdma_copy_addr(&id->route.addr.dev_addr, net_dev, NULL);
                if (ret)
                        goto err;
+       } else {
+               if (!cma_any_addr(cma_src_addr(id_priv))) {
+                       ret = cma_translate_addr(cma_src_addr(id_priv),
+                                                &id->route.addr.dev_addr);
+                       if (ret)
+                               goto err;
+               }
        }
 
        id_priv->state = RDMA_CM_CONNECT;
@@ -1300,25 +1728,33 @@ static int cma_req_handler(struct ib_cm_id *cm_id, struct ib_cm_event *ib_event)
 {
        struct rdma_id_private *listen_id, *conn_id;
        struct rdma_cm_event event;
+       struct net_device *net_dev;
        int offset, ret;
 
-       listen_id = cm_id->context;
-       if (!cma_check_req_qp_type(&listen_id->id, ib_event))
-               return -EINVAL;
+       listen_id = cma_id_from_event(cm_id, ib_event, &net_dev);
+       if (IS_ERR(listen_id))
+               return PTR_ERR(listen_id);
 
-       if (cma_disable_callback(listen_id, RDMA_CM_LISTEN))
-               return -ECONNABORTED;
+       if (!cma_check_req_qp_type(&listen_id->id, ib_event)) {
+               ret = -EINVAL;
+               goto net_dev_put;
+       }
+
+       if (cma_disable_callback(listen_id, RDMA_CM_LISTEN)) {
+               ret = -ECONNABORTED;
+               goto net_dev_put;
+       }
 
        memset(&event, 0, sizeof event);
        offset = cma_user_data_offset(listen_id);
        event.event = RDMA_CM_EVENT_CONNECT_REQUEST;
        if (ib_event->event == IB_CM_SIDR_REQ_RECEIVED) {
-               conn_id = cma_new_udp_id(&listen_id->id, ib_event);
+               conn_id = cma_new_udp_id(&listen_id->id, ib_event, net_dev);
                event.param.ud.private_data = ib_event->private_data + offset;
                event.param.ud.private_data_len =
                                IB_CM_SIDR_REQ_PRIVATE_DATA_SIZE - offset;
        } else {
-               conn_id = cma_new_conn_id(&listen_id->id, ib_event);
+               conn_id = cma_new_conn_id(&listen_id->id, ib_event, net_dev);
                cma_set_req_event_data(&event, &ib_event->param.req_rcvd,
                                       ib_event->private_data, offset);
        }
@@ -1356,6 +1792,8 @@ static int cma_req_handler(struct ib_cm_id *cm_id, struct ib_cm_event *ib_event)
        mutex_unlock(&conn_id->handler_mutex);
        mutex_unlock(&listen_id->handler_mutex);
        cma_deref_id(conn_id);
+       if (net_dev)
+               dev_put(net_dev);
        return 0;
 
 err3:
@@ -1369,6 +1807,11 @@ err1:
        mutex_unlock(&listen_id->handler_mutex);
        if (conn_id)
                rdma_destroy_id(&conn_id->id);
+
+net_dev_put:
+       if (net_dev)
+               dev_put(net_dev);
+
        return ret;
 }
 
@@ -1381,42 +1824,6 @@ __be64 rdma_get_service_id(struct rdma_cm_id *id, struct sockaddr *addr)
 }
 EXPORT_SYMBOL(rdma_get_service_id);
 
-static void cma_set_compare_data(enum rdma_port_space ps, struct sockaddr *addr,
-                                struct ib_cm_compare_data *compare)
-{
-       struct cma_hdr *cma_data, *cma_mask;
-       __be32 ip4_addr;
-       struct in6_addr ip6_addr;
-
-       memset(compare, 0, sizeof *compare);
-       cma_data = (void *) compare->data;
-       cma_mask = (void *) compare->mask;
-
-       switch (addr->sa_family) {
-       case AF_INET:
-               ip4_addr = ((struct sockaddr_in *) addr)->sin_addr.s_addr;
-               cma_set_ip_ver(cma_data, 4);
-               cma_set_ip_ver(cma_mask, 0xF);
-               if (!cma_any_addr(addr)) {
-                       cma_data->dst_addr.ip4.addr = ip4_addr;
-                       cma_mask->dst_addr.ip4.addr = htonl(~0);
-               }
-               break;
-       case AF_INET6:
-               ip6_addr = ((struct sockaddr_in6 *) addr)->sin6_addr;
-               cma_set_ip_ver(cma_data, 6);
-               cma_set_ip_ver(cma_mask, 0xF);
-               if (!cma_any_addr(addr)) {
-                       cma_data->dst_addr.ip6 = ip6_addr;
-                       memset(&cma_mask->dst_addr.ip6, 0xFF,
-                              sizeof cma_mask->dst_addr.ip6);
-               }
-               break;
-       default:
-               break;
-       }
-}
-
 static int cma_iw_handler(struct iw_cm_id *iw_id, struct iw_cm_event *iw_event)
 {
        struct rdma_id_private *id_priv = iw_id->context;
@@ -1498,7 +1905,8 @@ static int iw_conn_req_handler(struct iw_cm_id *cm_id,
                return -ECONNABORTED;
 
        /* Create a new RDMA id for the new IW CM ID */
-       new_cm_id = rdma_create_id(listen_id->id.event_handler,
+       new_cm_id = rdma_create_id(listen_id->id.route.addr.dev_addr.net,
+                                  listen_id->id.event_handler,
                                   listen_id->id.context,
                                   RDMA_PS_TCP, IB_QPT_RC);
        if (IS_ERR(new_cm_id)) {
@@ -1570,33 +1978,18 @@ out:
 
 static int cma_ib_listen(struct rdma_id_private *id_priv)
 {
-       struct ib_cm_compare_data compare_data;
        struct sockaddr *addr;
        struct ib_cm_id *id;
        __be64 svc_id;
-       int ret;
 
-       id = ib_create_cm_id(id_priv->id.device, cma_req_handler, id_priv);
+       addr = cma_src_addr(id_priv);
+       svc_id = rdma_get_service_id(&id_priv->id, addr);
+       id = ib_cm_insert_listen(id_priv->id.device, cma_req_handler, svc_id);
        if (IS_ERR(id))
                return PTR_ERR(id);
-
        id_priv->cm_id.ib = id;
 
-       addr = cma_src_addr(id_priv);
-       svc_id = rdma_get_service_id(&id_priv->id, addr);
-       if (cma_any_addr(addr) && !id_priv->afonly)
-               ret = ib_cm_listen(id_priv->cm_id.ib, svc_id, 0, NULL);
-       else {
-               cma_set_compare_data(id_priv->id.ps, addr, &compare_data);
-               ret = ib_cm_listen(id_priv->cm_id.ib, svc_id, 0, &compare_data);
-       }
-
-       if (ret) {
-               ib_destroy_cm_id(id_priv->cm_id.ib);
-               id_priv->cm_id.ib = NULL;
-       }
-
-       return ret;
+       return 0;
 }
 
 static int cma_iw_listen(struct rdma_id_private *id_priv, int backlog)
@@ -1610,6 +2003,7 @@ static int cma_iw_listen(struct rdma_id_private *id_priv, int backlog)
        if (IS_ERR(id))
                return PTR_ERR(id);
 
+       id->tos = id_priv->tos;
        id_priv->cm_id.iw = id;
 
        memcpy(&id_priv->cm_id.iw->local_addr, cma_src_addr(id_priv),
@@ -1640,13 +2034,13 @@ static void cma_listen_on_dev(struct rdma_id_private *id_priv,
 {
        struct rdma_id_private *dev_id_priv;
        struct rdma_cm_id *id;
+       struct net *net = id_priv->id.route.addr.dev_addr.net;
        int ret;
 
-       if (cma_family(id_priv) == AF_IB &&
-           rdma_node_get_transport(cma_dev->device->node_type) != RDMA_TRANSPORT_IB)
+       if (cma_family(id_priv) == AF_IB && !rdma_cap_ib_cm(cma_dev->device, 1))
                return;
 
-       id = rdma_create_id(cma_listen_handler, id_priv, id_priv->id.ps,
+       id = rdma_create_id(net, cma_listen_handler, id_priv, id_priv->id.ps,
                            id_priv->id.qp_type);
        if (IS_ERR(id))
                return;
@@ -1925,16 +2319,17 @@ static int cma_resolve_iboe_route(struct rdma_id_private *id_priv)
 
        route->num_paths = 1;
 
-       if (addr->dev_addr.bound_dev_if)
+       if (addr->dev_addr.bound_dev_if) {
                ndev = dev_get_by_index(&init_net, addr->dev_addr.bound_dev_if);
+               route->path_rec->net = &init_net;
+               route->path_rec->ifindex = addr->dev_addr.bound_dev_if;
+       }
        if (!ndev) {
                ret = -ENODEV;
                goto err2;
        }
 
-       route->path_rec->vlan_id = rdma_vlan_dev_vlan_id(ndev);
        memcpy(route->path_rec->dmac, addr->dev_addr.dst_dev_addr, ETH_ALEN);
-       memcpy(route->path_rec->smac, ndev->dev_addr, ndev->addr_len);
 
        rdma_ip2gid((struct sockaddr *)&id_priv->id.route.addr.src_addr,
                    &route->path_rec->sgid);
@@ -1984,26 +2379,15 @@ int rdma_resolve_route(struct rdma_cm_id *id, int timeout_ms)
                return -EINVAL;
 
        atomic_inc(&id_priv->refcount);
-       switch (rdma_node_get_transport(id->device->node_type)) {
-       case RDMA_TRANSPORT_IB:
-               switch (rdma_port_get_link_layer(id->device, id->port_num)) {
-               case IB_LINK_LAYER_INFINIBAND:
-                       ret = cma_resolve_ib_route(id_priv, timeout_ms);
-                       break;
-               case IB_LINK_LAYER_ETHERNET:
-                       ret = cma_resolve_iboe_route(id_priv);
-                       break;
-               default:
-                       ret = -ENOSYS;
-               }
-               break;
-       case RDMA_TRANSPORT_IWARP:
+       if (rdma_cap_ib_sa(id->device, id->port_num))
+               ret = cma_resolve_ib_route(id_priv, timeout_ms);
+       else if (rdma_protocol_roce(id->device, id->port_num))
+               ret = cma_resolve_iboe_route(id_priv);
+       else if (rdma_protocol_iwarp(id->device, id->port_num))
                ret = cma_resolve_iw_route(id_priv, timeout_ms);
-               break;
-       default:
+       else
                ret = -ENOSYS;
-               break;
-       }
+
        if (ret)
                goto err;
 
@@ -2045,7 +2429,7 @@ static int cma_bind_loopback(struct rdma_id_private *id_priv)
        mutex_lock(&lock);
        list_for_each_entry(cur_dev, &dev_list, list) {
                if (cma_family(id_priv) == AF_IB &&
-                   rdma_node_get_transport(cur_dev->device->node_type) != RDMA_TRANSPORT_IB)
+                   !rdma_cap_ib_cm(cur_dev->device, 1))
                        continue;
 
                if (!cma_dev)
@@ -2068,7 +2452,7 @@ static int cma_bind_loopback(struct rdma_id_private *id_priv)
        p = 1;
 
 port_found:
-       ret = ib_get_cached_gid(cma_dev->device, p, 0, &gid);
+       ret = ib_get_cached_gid(cma_dev->device, p, 0, &gid, NULL);
        if (ret)
                goto out;
 
@@ -2077,7 +2461,7 @@ port_found:
                goto out;
 
        id_priv->id.route.addr.dev_addr.dev_type =
-               (rdma_port_get_link_layer(cma_dev->device, p) == IB_LINK_LAYER_INFINIBAND) ?
+               (rdma_protocol_ib(cma_dev->device, p)) ?
                ARPHRD_INFINIBAND : ARPHRD_ETHER;
 
        rdma_addr_set_sgid(&id_priv->id.route.addr.dev_addr, &gid);
@@ -2195,8 +2579,11 @@ static int cma_bind_addr(struct rdma_cm_id *id, struct sockaddr *src_addr,
                src_addr = (struct sockaddr *) &id->route.addr.src_addr;
                src_addr->sa_family = dst_addr->sa_family;
                if (dst_addr->sa_family == AF_INET6) {
-                       ((struct sockaddr_in6 *) src_addr)->sin6_scope_id =
-                               ((struct sockaddr_in6 *) dst_addr)->sin6_scope_id;
+                       struct sockaddr_in6 *src_addr6 = (struct sockaddr_in6 *) src_addr;
+                       struct sockaddr_in6 *dst_addr6 = (struct sockaddr_in6 *) dst_addr;
+                       src_addr6->sin6_scope_id = dst_addr6->sin6_scope_id;
+                       if (ipv6_addr_type(&dst_addr6->sin6_addr) & IPV6_ADDR_LINKLOCAL)
+                               id->route.addr.dev_addr.bound_dev_if = dst_addr6->sin6_scope_id;
                } else if (dst_addr->sa_family == AF_IB) {
                        ((struct sockaddr_ib *) src_addr)->sib_pkey =
                                ((struct sockaddr_ib *) dst_addr)->sib_pkey;
@@ -2317,8 +2704,8 @@ static void cma_bind_port(struct rdma_bind_list *bind_list,
        hlist_add_head(&id_priv->node, &bind_list->owners);
 }
 
-static int cma_alloc_port(struct idr *ps, struct rdma_id_private *id_priv,
-                         unsigned short snum)
+static int cma_alloc_port(enum rdma_port_space ps,
+                         struct rdma_id_private *id_priv, unsigned short snum)
 {
        struct rdma_bind_list *bind_list;
        int ret;
@@ -2327,7 +2714,8 @@ static int cma_alloc_port(struct idr *ps, struct rdma_id_private *id_priv,
        if (!bind_list)
                return -ENOMEM;
 
-       ret = idr_alloc(ps, bind_list, snum, snum + 1, GFP_KERNEL);
+       ret = cma_ps_alloc(id_priv->id.route.addr.dev_addr.net, ps, bind_list,
+                          snum);
        if (ret < 0)
                goto err;
 
@@ -2340,18 +2728,20 @@ err:
        return ret == -ENOSPC ? -EADDRNOTAVAIL : ret;
 }
 
-static int cma_alloc_any_port(struct idr *ps, struct rdma_id_private *id_priv)
+static int cma_alloc_any_port(enum rdma_port_space ps,
+                             struct rdma_id_private *id_priv)
 {
        static unsigned int last_used_port;
        int low, high, remaining;
        unsigned int rover;
+       struct net *net = id_priv->id.route.addr.dev_addr.net;
 
-       inet_get_local_port_range(&init_net, &low, &high);
+       inet_get_local_port_range(net, &low, &high);
        remaining = (high - low) + 1;
        rover = prandom_u32() % remaining + low;
 retry:
        if (last_used_port != rover &&
-           !idr_find(ps, (unsigned short) rover)) {
+           !cma_ps_find(net, ps, (unsigned short)rover)) {
                int ret = cma_alloc_port(ps, id_priv, rover);
                /*
                 * Remember previously used port number in order to avoid
@@ -2406,7 +2796,8 @@ static int cma_check_port(struct rdma_bind_list *bind_list,
        return 0;
 }
 
-static int cma_use_port(struct idr *ps, struct rdma_id_private *id_priv)
+static int cma_use_port(enum rdma_port_space ps,
+                       struct rdma_id_private *id_priv)
 {
        struct rdma_bind_list *bind_list;
        unsigned short snum;
@@ -2416,7 +2807,7 @@ static int cma_use_port(struct idr *ps, struct rdma_id_private *id_priv)
        if (snum < PROT_SOCK && !capable(CAP_NET_BIND_SERVICE))
                return -EACCES;
 
-       bind_list = idr_find(ps, snum);
+       bind_list = cma_ps_find(id_priv->id.route.addr.dev_addr.net, ps, snum);
        if (!bind_list) {
                ret = cma_alloc_port(ps, id_priv, snum);
        } else {
@@ -2439,25 +2830,24 @@ static int cma_bind_listen(struct rdma_id_private *id_priv)
        return ret;
 }
 
-static struct idr *cma_select_inet_ps(struct rdma_id_private *id_priv)
+static enum rdma_port_space cma_select_inet_ps(
+               struct rdma_id_private *id_priv)
 {
        switch (id_priv->id.ps) {
        case RDMA_PS_TCP:
-               return &tcp_ps;
        case RDMA_PS_UDP:
-               return &udp_ps;
        case RDMA_PS_IPOIB:
-               return &ipoib_ps;
        case RDMA_PS_IB:
-               return &ib_ps;
+               return id_priv->id.ps;
        default:
-               return NULL;
+
+               return 0;
        }
 }
 
-static struct idr *cma_select_ib_ps(struct rdma_id_private *id_priv)
+static enum rdma_port_space cma_select_ib_ps(struct rdma_id_private *id_priv)
 {
-       struct idr *ps = NULL;
+       enum rdma_port_space ps = 0;
        struct sockaddr_ib *sib;
        u64 sid_ps, mask, sid;
 
@@ -2467,15 +2857,15 @@ static struct idr *cma_select_ib_ps(struct rdma_id_private *id_priv)
 
        if ((id_priv->id.ps == RDMA_PS_IB) && (sid == (RDMA_IB_IP_PS_IB & mask))) {
                sid_ps = RDMA_IB_IP_PS_IB;
-               ps = &ib_ps;
+               ps = RDMA_PS_IB;
        } else if (((id_priv->id.ps == RDMA_PS_IB) || (id_priv->id.ps == RDMA_PS_TCP)) &&
                   (sid == (RDMA_IB_IP_PS_TCP & mask))) {
                sid_ps = RDMA_IB_IP_PS_TCP;
-               ps = &tcp_ps;
+               ps = RDMA_PS_TCP;
        } else if (((id_priv->id.ps == RDMA_PS_IB) || (id_priv->id.ps == RDMA_PS_UDP)) &&
                   (sid == (RDMA_IB_IP_PS_UDP & mask))) {
                sid_ps = RDMA_IB_IP_PS_UDP;
-               ps = &udp_ps;
+               ps = RDMA_PS_UDP;
        }
 
        if (ps) {
@@ -2488,7 +2878,7 @@ static struct idr *cma_select_ib_ps(struct rdma_id_private *id_priv)
 
 static int cma_get_port(struct rdma_id_private *id_priv)
 {
-       struct idr *ps;
+       enum rdma_port_space ps;
        int ret;
 
        if (cma_family(id_priv) != AF_IB)
@@ -2554,18 +2944,15 @@ int rdma_listen(struct rdma_cm_id *id, int backlog)
 
        id_priv->backlog = backlog;
        if (id->device) {
-               switch (rdma_node_get_transport(id->device->node_type)) {
-               case RDMA_TRANSPORT_IB:
+               if (rdma_cap_ib_cm(id->device, 1)) {
                        ret = cma_ib_listen(id_priv);
                        if (ret)
                                goto err;
-                       break;
-               case RDMA_TRANSPORT_IWARP:
+               } else if (rdma_cap_iw_cm(id->device, 1)) {
                        ret = cma_iw_listen(id_priv, backlog);
                        if (ret)
                                goto err;
-                       break;
-               default:
+               } else {
                        ret = -ENOSYS;
                        goto err;
                }
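
This hunk, like the rewrites of rdma_connect(), rdma_accept(), rdma_reject() and rdma_disconnect() further down, replaces the node-type/transport switch with the per-port capability helpers from include/rdma/ib_verbs.h. Paraphrasing those helpers from memory (they are not part of this patch), they reduce to per-port capability-bit tests along these lines:

static inline bool rdma_cap_ib_cm(const struct ib_device *device, u8 port_num)
{
        return device->port_immutable[port_num].core_cap_flags &
               RDMA_CORE_CAP_IB_CM;
}

static inline bool rdma_cap_iw_cm(const struct ib_device *device, u8 port_num)
{
        return device->port_immutable[port_num].core_cap_flags &
               RDMA_CORE_CAP_IW_CM;
}

Testing the capability per port, instead of deriving it from the device's node type, allows a single device to expose ports with different management requirements.
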
@@ -2612,8 +2999,11 @@ int rdma_bind_addr(struct rdma_cm_id *id, struct sockaddr *addr)
                if (addr->sa_family == AF_INET)
                        id_priv->afonly = 1;
 #if IS_ENABLED(CONFIG_IPV6)
-               else if (addr->sa_family == AF_INET6)
-                       id_priv->afonly = init_net.ipv6.sysctl.bindv6only;
+               else if (addr->sa_family == AF_INET6) {
+                       struct net *net = id_priv->id.route.addr.dev_addr.net;
+
+                       id_priv->afonly = net->ipv6.sysctl.bindv6only;
+               }
 #endif
        }
        ret = cma_get_port(id_priv);
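
The IPv6 bindv6only policy is now read from the namespace the rdma_cm_id was created in rather than from init_net. A minimal, hypothetical usage sketch (the handler name and port number are placeholders, not from this file):

        struct sockaddr_in6 sin6 = {
                .sin6_family = AF_INET6,
                .sin6_addr   = IN6ADDR_ANY_INIT,
                .sin6_port   = htons(7471),             /* arbitrary */
        };
        struct rdma_cm_id *id;
        int ret;

        /* The id remembers the namespace passed here; rdma_bind_addr()
         * later consults that namespace's bindv6only sysctl. */
        id = rdma_create_id(&init_net, my_cm_handler, NULL,
                            RDMA_PS_TCP, IB_QPT_RC);
        if (IS_ERR(id))
                return PTR_ERR(id);

        ret = rdma_bind_addr(id, (struct sockaddr *)&sin6);
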
@@ -2857,6 +3247,7 @@ static int cma_connect_iw(struct rdma_id_private *id_priv,
        if (IS_ERR(cm_id))
                return PTR_ERR(cm_id);
 
+       cm_id->tos = id_priv->tos;
        id_priv->cm_id.iw = cm_id;
 
        memcpy(&cm_id->local_addr, cma_src_addr(id_priv),
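
Copying id_priv->tos onto the iw_cm_id means a type-of-service value set by the ULP now reaches the iWARP provider as well; for example (a sketch, not from this file):

        struct rdma_conn_param conn_param = { };
        int ret;

        /* Ask for a specific ToS before connecting; with this patch the
         * value is propagated to the iWARP cm_id too. */
        rdma_set_service_type(id, 0x18);        /* arbitrary ToS value */
        ret = rdma_connect(id, &conn_param);
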
@@ -2901,20 +3292,15 @@ int rdma_connect(struct rdma_cm_id *id, struct rdma_conn_param *conn_param)
                id_priv->srq = conn_param->srq;
        }
 
-       switch (rdma_node_get_transport(id->device->node_type)) {
-       case RDMA_TRANSPORT_IB:
+       if (rdma_cap_ib_cm(id->device, id->port_num)) {
                if (id->qp_type == IB_QPT_UD)
                        ret = cma_resolve_ib_udp(id_priv, conn_param);
                else
                        ret = cma_connect_ib(id_priv, conn_param);
-               break;
-       case RDMA_TRANSPORT_IWARP:
+       } else if (rdma_cap_iw_cm(id->device, id->port_num))
                ret = cma_connect_iw(id_priv, conn_param);
-               break;
-       default:
+       else
                ret = -ENOSYS;
-               break;
-       }
        if (ret)
                goto err;
 
@@ -3017,8 +3403,7 @@ int rdma_accept(struct rdma_cm_id *id, struct rdma_conn_param *conn_param)
                id_priv->srq = conn_param->srq;
        }
 
-       switch (rdma_node_get_transport(id->device->node_type)) {
-       case RDMA_TRANSPORT_IB:
+       if (rdma_cap_ib_cm(id->device, id->port_num)) {
                if (id->qp_type == IB_QPT_UD) {
                        if (conn_param)
                                ret = cma_send_sidr_rep(id_priv, IB_SIDR_SUCCESS,
@@ -3034,14 +3419,10 @@ int rdma_accept(struct rdma_cm_id *id, struct rdma_conn_param *conn_param)
                        else
                                ret = cma_rep_recv(id_priv);
                }
-               break;
-       case RDMA_TRANSPORT_IWARP:
+       } else if (rdma_cap_iw_cm(id->device, id->port_num))
                ret = cma_accept_iw(id_priv, conn_param);
-               break;
-       default:
+       else
                ret = -ENOSYS;
-               break;
-       }
 
        if (ret)
                goto reject;
@@ -3085,8 +3466,7 @@ int rdma_reject(struct rdma_cm_id *id, const void *private_data,
        if (!id_priv->cm_id.ib)
                return -EINVAL;
 
-       switch (rdma_node_get_transport(id->device->node_type)) {
-       case RDMA_TRANSPORT_IB:
+       if (rdma_cap_ib_cm(id->device, id->port_num)) {
                if (id->qp_type == IB_QPT_UD)
                        ret = cma_send_sidr_rep(id_priv, IB_SIDR_REJECT, 0,
                                                private_data, private_data_len);
@@ -3094,15 +3474,12 @@ int rdma_reject(struct rdma_cm_id *id, const void *private_data,
                        ret = ib_send_cm_rej(id_priv->cm_id.ib,
                                             IB_CM_REJ_CONSUMER_DEFINED, NULL,
                                             0, private_data, private_data_len);
-               break;
-       case RDMA_TRANSPORT_IWARP:
+       } else if (rdma_cap_iw_cm(id->device, id->port_num)) {
                ret = iw_cm_reject(id_priv->cm_id.iw,
                                   private_data, private_data_len);
-               break;
-       default:
+       } else
                ret = -ENOSYS;
-               break;
-       }
+
        return ret;
 }
 EXPORT_SYMBOL(rdma_reject);
@@ -3116,22 +3493,18 @@ int rdma_disconnect(struct rdma_cm_id *id)
        if (!id_priv->cm_id.ib)
                return -EINVAL;
 
-       switch (rdma_node_get_transport(id->device->node_type)) {
-       case RDMA_TRANSPORT_IB:
+       if (rdma_cap_ib_cm(id->device, id->port_num)) {
                ret = cma_modify_qp_err(id_priv);
                if (ret)
                        goto out;
                /* Initiate or respond to a disconnect. */
                if (ib_send_cm_dreq(id_priv->cm_id.ib, NULL, 0))
                        ib_send_cm_drep(id_priv->cm_id.ib, NULL, 0);
-               break;
-       case RDMA_TRANSPORT_IWARP:
+       } else if (rdma_cap_iw_cm(id->device, id->port_num)) {
                ret = iw_cm_disconnect(id_priv->cm_id.iw, 0);
-               break;
-       default:
+       } else
                ret = -EINVAL;
-               break;
-       }
+
 out:
        return ret;
 }
@@ -3377,24 +3750,13 @@ int rdma_join_multicast(struct rdma_cm_id *id, struct sockaddr *addr,
        list_add(&mc->list, &id_priv->mc_list);
        spin_unlock(&id_priv->lock);
 
-       switch (rdma_node_get_transport(id->device->node_type)) {
-       case RDMA_TRANSPORT_IB:
-               switch (rdma_port_get_link_layer(id->device, id->port_num)) {
-               case IB_LINK_LAYER_INFINIBAND:
-                       ret = cma_join_ib_multicast(id_priv, mc);
-                       break;
-               case IB_LINK_LAYER_ETHERNET:
-                       kref_init(&mc->mcref);
-                       ret = cma_iboe_join_multicast(id_priv, mc);
-                       break;
-               default:
-                       ret = -EINVAL;
-               }
-               break;
-       default:
+       if (rdma_protocol_roce(id->device, id->port_num)) {
+               kref_init(&mc->mcref);
+               ret = cma_iboe_join_multicast(id_priv, mc);
+       } else if (rdma_cap_ib_mcast(id->device, id->port_num))
+               ret = cma_join_ib_multicast(id_priv, mc);
+       else
                ret = -ENOSYS;
-               break;
-       }
 
        if (ret) {
                spin_lock_irq(&id_priv->lock);
@@ -3422,19 +3784,15 @@ void rdma_leave_multicast(struct rdma_cm_id *id, struct sockaddr *addr)
                                ib_detach_mcast(id->qp,
                                                &mc->multicast.ib->rec.mgid,
                                                be16_to_cpu(mc->multicast.ib->rec.mlid));
-                       if (rdma_node_get_transport(id_priv->cma_dev->device->node_type) == RDMA_TRANSPORT_IB) {
-                               switch (rdma_port_get_link_layer(id->device, id->port_num)) {
-                               case IB_LINK_LAYER_INFINIBAND:
-                                       ib_sa_free_multicast(mc->multicast.ib);
-                                       kfree(mc);
-                                       break;
-                               case IB_LINK_LAYER_ETHERNET:
-                                       kref_put(&mc->mcref, release_mc);
-                                       break;
-                               default:
-                                       break;
-                               }
-                       }
+
+                       BUG_ON(id_priv->cma_dev->device != id->device);
+
+                       if (rdma_cap_ib_mcast(id->device, id->port_num)) {
+                               ib_sa_free_multicast(mc->multicast.ib);
+                               kfree(mc);
+                       } else if (rdma_protocol_roce(id->device, id->port_num))
+                               kref_put(&mc->mcref, release_mc);
+
                        return;
                }
        }
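
On RoCE ports the multicast entry is reference-counted instead of being handed to the SA layer, so teardown goes through kref_put(). release_mc() is defined elsewhere in this file and is not part of this hunk; presumably it frees the record once the last reference drops, along the lines of:

static void release_mc(struct kref *kref)
{
        struct cma_multicast *mc = container_of(kref, struct cma_multicast,
                                                mcref);

        kfree(mc->multicast.ib);
        kfree(mc);
}
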
@@ -3450,6 +3808,7 @@ static int cma_netdev_change(struct net_device *ndev, struct rdma_id_private *id
        dev_addr = &id_priv->id.route.addr.dev_addr;
 
        if ((dev_addr->bound_dev_if == ndev->ifindex) &&
+           (net_eq(dev_net(ndev), dev_addr->net)) &&
            memcmp(dev_addr->src_dev_addr, ndev->dev_addr, ndev->addr_len)) {
                printk(KERN_INFO "RDMA CM addr change for ndev %s used by id %p\n",
                       ndev->name, &id_priv->id);
@@ -3475,9 +3834,6 @@ static int cma_netdev_callback(struct notifier_block *self, unsigned long event,
        struct rdma_id_private *id_priv;
        int ret = NOTIFY_DONE;
 
-       if (dev_net(ndev) != &init_net)
-               return NOTIFY_DONE;
-
        if (event != NETDEV_BONDING_FAILOVER)
                return NOTIFY_DONE;
 
@@ -3578,11 +3934,10 @@ static void cma_process_remove(struct cma_device *cma_dev)
        wait_for_completion(&cma_dev->comp);
 }
 
-static void cma_remove_one(struct ib_device *device)
+static void cma_remove_one(struct ib_device *device, void *client_data)
 {
-       struct cma_device *cma_dev;
+       struct cma_device *cma_dev = client_data;
 
-       cma_dev = ib_get_client_data(device, &cma_client);
        if (!cma_dev)
                return;
 
@@ -3673,6 +4028,35 @@ static const struct ibnl_client_cbs cma_cb_table[] = {
                                       .module = THIS_MODULE },
 };
 
+static int cma_init_net(struct net *net)
+{
+       struct cma_pernet *pernet = cma_pernet(net);
+
+       idr_init(&pernet->tcp_ps);
+       idr_init(&pernet->udp_ps);
+       idr_init(&pernet->ipoib_ps);
+       idr_init(&pernet->ib_ps);
+
+       return 0;
+}
+
+static void cma_exit_net(struct net *net)
+{
+       struct cma_pernet *pernet = cma_pernet(net);
+
+       idr_destroy(&pernet->tcp_ps);
+       idr_destroy(&pernet->udp_ps);
+       idr_destroy(&pernet->ipoib_ps);
+       idr_destroy(&pernet->ib_ps);
+}
+
+static struct pernet_operations cma_pernet_operations = {
+       .init = cma_init_net,
+       .exit = cma_exit_net,
+       .id = &cma_pernet_id,
+       .size = sizeof(struct cma_pernet),
+};
+
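
register_pernet_subsys() allocates .size bytes of zeroed storage for every present and future network namespace, records the slot number in cma_pernet_id, and invokes cma_init_net()/cma_exit_net() as namespaces come and go. A hypothetical helper (not part of the patch) showing how that storage could be reached at runtime:

/* Hypothetical debug helper: count TCP port-space entries bound in one
 * namespace, walking the per-net idr set up above. */
static int cma_count_tcp_ports(struct net *net)
{
        struct cma_pernet *pernet = cma_pernet(net);
        struct rdma_bind_list *bind_list;
        int id, count = 0;

        idr_for_each_entry(&pernet->tcp_ps, bind_list, id)
                count++;
        return count;
}
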
 static int __init cma_init(void)
 {
        int ret;
@@ -3681,6 +4065,10 @@ static int __init cma_init(void)
        if (!cma_wq)
                return -ENOMEM;
 
+       ret = register_pernet_subsys(&cma_pernet_operations);
+       if (ret)
+               goto err_wq;
+
        ib_sa_register_client(&sa_client);
        rdma_addr_register_client(&addr_client);
        register_netdevice_notifier(&cma_nb);
@@ -3698,6 +4086,7 @@ err:
        unregister_netdevice_notifier(&cma_nb);
        rdma_addr_unregister_client(&addr_client);
        ib_sa_unregister_client(&sa_client);
+err_wq:
        destroy_workqueue(cma_wq);
        return ret;
 }
@@ -3709,11 +4098,8 @@ static void __exit cma_cleanup(void)
        unregister_netdevice_notifier(&cma_nb);
        rdma_addr_unregister_client(&addr_client);
        ib_sa_unregister_client(&sa_client);
+       unregister_pernet_subsys(&cma_pernet_operations);
        destroy_workqueue(cma_wq);
-       idr_destroy(&tcp_ps);
-       idr_destroy(&udp_ps);
-       idr_destroy(&ipoib_ps);
-       idr_destroy(&ib_ps);
 }
 
 module_init(cma_init);