+static bool validate_ipv4_net_dev(struct net_device *net_dev,
+ const struct sockaddr_in *dst_addr,
+ const struct sockaddr_in *src_addr)
+{
+ __be32 daddr = dst_addr->sin_addr.s_addr,
+ saddr = src_addr->sin_addr.s_addr;
+ struct fib_result res;
+ struct flowi4 fl4;
+ int err;
+ bool ret;
+
+ if (ipv4_is_multicast(saddr) || ipv4_is_lbcast(saddr) ||
+ ipv4_is_lbcast(daddr) || ipv4_is_zeronet(saddr) ||
+ ipv4_is_zeronet(daddr) || ipv4_is_loopback(daddr) ||
+ ipv4_is_loopback(saddr))
+ return false;
+
+ memset(&fl4, 0, sizeof(fl4));
+ fl4.flowi4_iif = net_dev->ifindex;
+ fl4.daddr = daddr;
+ fl4.saddr = saddr;
+
+ rcu_read_lock();
+ err = fib_lookup(dev_net(net_dev), &fl4, &res, 0);
+ ret = err == 0 && FIB_RES_DEV(res) == net_dev;
+ rcu_read_unlock();
+
+ return ret;
+}
+
+static bool validate_ipv6_net_dev(struct net_device *net_dev,
+ const struct sockaddr_in6 *dst_addr,
+ const struct sockaddr_in6 *src_addr)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+ const int strict = ipv6_addr_type(&dst_addr->sin6_addr) &
+ IPV6_ADDR_LINKLOCAL;
+ struct rt6_info *rt = rt6_lookup(dev_net(net_dev), &dst_addr->sin6_addr,
+ &src_addr->sin6_addr, net_dev->ifindex,
+ strict);
+ bool ret;
+
+ if (!rt)
+ return false;
+
+ ret = rt->rt6i_idev->dev == net_dev;
+ ip6_rt_put(rt);
+
+ return ret;
+#else
+ return false;
+#endif
+}
+
+static bool validate_net_dev(struct net_device *net_dev,
+ const struct sockaddr *daddr,
+ const struct sockaddr *saddr)
+{
+ const struct sockaddr_in *daddr4 = (const struct sockaddr_in *)daddr;
+ const struct sockaddr_in *saddr4 = (const struct sockaddr_in *)saddr;
+ const struct sockaddr_in6 *daddr6 = (const struct sockaddr_in6 *)daddr;
+ const struct sockaddr_in6 *saddr6 = (const struct sockaddr_in6 *)saddr;
+
+ switch (daddr->sa_family) {
+ case AF_INET:
+ return saddr->sa_family == AF_INET &&
+ validate_ipv4_net_dev(net_dev, daddr4, saddr4);
+
+ case AF_INET6:
+ return saddr->sa_family == AF_INET6 &&
+ validate_ipv6_net_dev(net_dev, daddr6, saddr6);
+
+ default:
+ return false;
+ }
+}
+
+static struct net_device *cma_get_net_dev(struct ib_cm_event *ib_event,
+ const struct cma_req_info *req)
+{
+ struct sockaddr_storage listen_addr_storage, src_addr_storage;
+ struct sockaddr *listen_addr = (struct sockaddr *)&listen_addr_storage,
+ *src_addr = (struct sockaddr *)&src_addr_storage;
+ struct net_device *net_dev;
+ const union ib_gid *gid = req->has_gid ? &req->local_gid : NULL;
+ int err;
+
+ err = cma_save_ip_info(listen_addr, src_addr, ib_event,
+ req->service_id);
+ if (err)
+ return ERR_PTR(err);
+
+ net_dev = ib_get_net_dev_by_params(req->device, req->port, req->pkey,
+ gid, listen_addr);
+ if (!net_dev)
+ return ERR_PTR(-ENODEV);
+
+ if (!validate_net_dev(net_dev, listen_addr, src_addr)) {
+ dev_put(net_dev);
+ return ERR_PTR(-EHOSTUNREACH);
+ }
+
+ return net_dev;
+}
+
+static enum rdma_port_space rdma_ps_from_service_id(__be64 service_id)
+{
+ return (be64_to_cpu(service_id) >> 16) & 0xffff;
+}
+
+static bool cma_match_private_data(struct rdma_id_private *id_priv,
+ const struct cma_hdr *hdr)
+{
+ struct sockaddr *addr = cma_src_addr(id_priv);
+ __be32 ip4_addr;
+ struct in6_addr ip6_addr;
+
+ if (cma_any_addr(addr) && !id_priv->afonly)
+ return true;
+
+ switch (addr->sa_family) {
+ case AF_INET:
+ ip4_addr = ((struct sockaddr_in *)addr)->sin_addr.s_addr;
+ if (cma_get_ip_ver(hdr) != 4)
+ return false;
+ if (!cma_any_addr(addr) &&
+ hdr->dst_addr.ip4.addr != ip4_addr)
+ return false;
+ break;
+ case AF_INET6:
+ ip6_addr = ((struct sockaddr_in6 *)addr)->sin6_addr;
+ if (cma_get_ip_ver(hdr) != 6)
+ return false;
+ if (!cma_any_addr(addr) &&
+ memcmp(&hdr->dst_addr.ip6, &ip6_addr, sizeof(ip6_addr)))
+ return false;
+ break;
+ case AF_IB:
+ return true;
+ default:
+ return false;
+ }
+
+ return true;
+}
+
+static bool cma_protocol_roce_dev_port(struct ib_device *device, int port_num)
+{
+ enum rdma_link_layer ll = rdma_port_get_link_layer(device, port_num);
+ enum rdma_transport_type transport =
+ rdma_node_get_transport(device->node_type);
+
+ return ll == IB_LINK_LAYER_ETHERNET && transport == RDMA_TRANSPORT_IB;
+}
+
+static bool cma_protocol_roce(const struct rdma_cm_id *id)
+{
+ struct ib_device *device = id->device;
+ const int port_num = id->port_num ?: rdma_start_port(device);
+
+ return cma_protocol_roce_dev_port(device, port_num);
+}
+
+static bool cma_match_net_dev(const struct rdma_cm_id *id,
+ const struct net_device *net_dev,
+ u8 port_num)
+{
+ const struct rdma_addr *addr = &id->route.addr;
+
+ if (!net_dev)
+ /* This request is an AF_IB request or a RoCE request */
+ return (!id->port_num || id->port_num == port_num) &&
+ (addr->src_addr.ss_family == AF_IB ||
+ cma_protocol_roce_dev_port(id->device, port_num));
+
+ return !addr->dev_addr.bound_dev_if ||
+ (net_eq(dev_net(net_dev), addr->dev_addr.net) &&
+ addr->dev_addr.bound_dev_if == net_dev->ifindex);
+}
+
+static struct rdma_id_private *cma_find_listener(
+ const struct rdma_bind_list *bind_list,
+ const struct ib_cm_id *cm_id,
+ const struct ib_cm_event *ib_event,
+ const struct cma_req_info *req,
+ const struct net_device *net_dev)
+{
+ struct rdma_id_private *id_priv, *id_priv_dev;
+
+ if (!bind_list)
+ return ERR_PTR(-EINVAL);
+
+ hlist_for_each_entry(id_priv, &bind_list->owners, node) {
+ if (cma_match_private_data(id_priv, ib_event->private_data)) {
+ if (id_priv->id.device == cm_id->device &&
+ cma_match_net_dev(&id_priv->id, net_dev, req->port))
+ return id_priv;
+ list_for_each_entry(id_priv_dev,
+ &id_priv->listen_list,
+ listen_list) {
+ if (id_priv_dev->id.device == cm_id->device &&
+ cma_match_net_dev(&id_priv_dev->id, net_dev, req->port))
+ return id_priv_dev;
+ }
+ }
+ }
+
+ return ERR_PTR(-EINVAL);
+}
+
+static struct rdma_id_private *cma_id_from_event(struct ib_cm_id *cm_id,
+ struct ib_cm_event *ib_event,
+ struct net_device **net_dev)
+{
+ struct cma_req_info req;
+ struct rdma_bind_list *bind_list;
+ struct rdma_id_private *id_priv;
+ int err;
+
+ err = cma_save_req_info(ib_event, &req);
+ if (err)
+ return ERR_PTR(err);
+
+ *net_dev = cma_get_net_dev(ib_event, &req);
+ if (IS_ERR(*net_dev)) {
+ if (PTR_ERR(*net_dev) == -EAFNOSUPPORT) {
+ /* Assuming the protocol is AF_IB */
+ *net_dev = NULL;
+ } else if (cma_protocol_roce_dev_port(req.device, req.port)) {
+ /* TODO find the net dev matching the request parameters
+ * through the RoCE GID table */
+ *net_dev = NULL;
+ } else {
+ return ERR_CAST(*net_dev);
+ }
+ }
+
+ bind_list = cma_ps_find(*net_dev ? dev_net(*net_dev) : &init_net,
+ rdma_ps_from_service_id(req.service_id),
+ cma_port_from_service_id(req.service_id));
+ id_priv = cma_find_listener(bind_list, cm_id, ib_event, &req, *net_dev);
+ if (IS_ERR(id_priv) && *net_dev) {
+ dev_put(*net_dev);
+ *net_dev = NULL;
+ }
+
+ return id_priv;
+}
+