These changes are the raw update to linux-4.4.6-rt14. Kernel sources
[kvmfornfv.git] / kernel / net / netfilter / xt_socket.c
index e092cb0..2ec08f0 100644 (file)
@@ -143,7 +143,8 @@ static bool xt_socket_sk_is_transparent(struct sock *sk)
        }
 }
 
-static struct sock *xt_socket_lookup_slow_v4(const struct sk_buff *skb,
+static struct sock *xt_socket_lookup_slow_v4(struct net *net,
+                                            const struct sk_buff *skb,
                                             const struct net_device *indev)
 {
        const struct iphdr *iph = ip_hdr(skb);
@@ -197,7 +198,7 @@ static struct sock *xt_socket_lookup_slow_v4(const struct sk_buff *skb,
        }
 #endif
 
-       return xt_socket_get_sock_v4(dev_net(skb->dev), protocol, saddr, daddr,
+       return xt_socket_get_sock_v4(net, protocol, saddr, daddr,
                                     sport, dport, indev);
 }
 
@@ -205,10 +206,11 @@ static bool
 socket_match(const struct sk_buff *skb, struct xt_action_param *par,
             const struct xt_socket_mtinfo1 *info)
 {
+       struct sk_buff *pskb = (struct sk_buff *)skb;
        struct sock *sk = skb->sk;
 
        if (!sk)
-               sk = xt_socket_lookup_slow_v4(skb, par->in);
+               sk = xt_socket_lookup_slow_v4(par->net, skb, par->in);
        if (sk) {
                bool wildcard;
                bool transparent = true;
@@ -226,6 +228,10 @@ socket_match(const struct sk_buff *skb, struct xt_action_param *par,
                if (info->flags & XT_SOCKET_TRANSPARENT)
                        transparent = xt_socket_sk_is_transparent(sk);
 
+               if (info->flags & XT_SOCKET_RESTORESKMARK && !wildcard &&
+                   transparent)
+                       pskb->mark = sk->sk_mark;
+
                if (sk != skb->sk)
                        sock_gen_put(sk);
 
@@ -247,7 +253,7 @@ socket_mt4_v0(const struct sk_buff *skb, struct xt_action_param *par)
 }
 
 static bool
-socket_mt4_v1_v2(const struct sk_buff *skb, struct xt_action_param *par)
+socket_mt4_v1_v2_v3(const struct sk_buff *skb, struct xt_action_param *par)
 {
        return socket_match(skb, par, par->matchinfo);
 }
@@ -330,7 +336,8 @@ xt_socket_get_sock_v6(struct net *net, const u8 protocol,
        return NULL;
 }
 
-static struct sock *xt_socket_lookup_slow_v6(const struct sk_buff *skb,
+static struct sock *xt_socket_lookup_slow_v6(struct net *net,
+                                            const struct sk_buff *skb,
                                             const struct net_device *indev)
 {
        __be16 uninitialized_var(dport), uninitialized_var(sport);
@@ -366,18 +373,19 @@ static struct sock *xt_socket_lookup_slow_v6(const struct sk_buff *skb,
                return NULL;
        }
 
-       return xt_socket_get_sock_v6(dev_net(skb->dev), tproto, saddr, daddr,
+       return xt_socket_get_sock_v6(net, tproto, saddr, daddr,
                                     sport, dport, indev);
 }
 
 static bool
-socket_mt6_v1_v2(const struct sk_buff *skb, struct xt_action_param *par)
+socket_mt6_v1_v2_v3(const struct sk_buff *skb, struct xt_action_param *par)
 {
        const struct xt_socket_mtinfo1 *info = (struct xt_socket_mtinfo1 *) par->matchinfo;
+       struct sk_buff *pskb = (struct sk_buff *)skb;
        struct sock *sk = skb->sk;
 
        if (!sk)
-               sk = xt_socket_lookup_slow_v6(skb, par->in);
+               sk = xt_socket_lookup_slow_v6(par->net, skb, par->in);
        if (sk) {
                bool wildcard;
                bool transparent = true;
@@ -395,6 +403,10 @@ socket_mt6_v1_v2(const struct sk_buff *skb, struct xt_action_param *par)
                if (info->flags & XT_SOCKET_TRANSPARENT)
                        transparent = xt_socket_sk_is_transparent(sk);
 
+               if (info->flags & XT_SOCKET_RESTORESKMARK && !wildcard &&
+                   transparent)
+                       pskb->mark = sk->sk_mark;
+
                if (sk != skb->sk)
                        sock_gen_put(sk);
 
@@ -428,6 +440,19 @@ static int socket_mt_v2_check(const struct xt_mtchk_param *par)
        return 0;
 }
 
+static int socket_mt_v3_check(const struct xt_mtchk_param *par)
+{
+       const struct xt_socket_mtinfo3 *info =
+                                   (struct xt_socket_mtinfo3 *)par->matchinfo;
+
+       if (info->flags & ~XT_SOCKET_FLAGS_V3) {
+               pr_info("unknown flags 0x%x\n",
+                       info->flags & ~XT_SOCKET_FLAGS_V3);
+               return -EINVAL;
+       }
+       return 0;
+}
+
 static struct xt_match socket_mt_reg[] __read_mostly = {
        {
                .name           = "socket",
@@ -442,7 +467,7 @@ static struct xt_match socket_mt_reg[] __read_mostly = {
                .name           = "socket",
                .revision       = 1,
                .family         = NFPROTO_IPV4,
-               .match          = socket_mt4_v1_v2,
+               .match          = socket_mt4_v1_v2_v3,
                .checkentry     = socket_mt_v1_check,
                .matchsize      = sizeof(struct xt_socket_mtinfo1),
                .hooks          = (1 << NF_INET_PRE_ROUTING) |
@@ -454,7 +479,7 @@ static struct xt_match socket_mt_reg[] __read_mostly = {
                .name           = "socket",
                .revision       = 1,
                .family         = NFPROTO_IPV6,
-               .match          = socket_mt6_v1_v2,
+               .match          = socket_mt6_v1_v2_v3,
                .checkentry     = socket_mt_v1_check,
                .matchsize      = sizeof(struct xt_socket_mtinfo1),
                .hooks          = (1 << NF_INET_PRE_ROUTING) |
@@ -466,7 +491,7 @@ static struct xt_match socket_mt_reg[] __read_mostly = {
                .name           = "socket",
                .revision       = 2,
                .family         = NFPROTO_IPV4,
-               .match          = socket_mt4_v1_v2,
+               .match          = socket_mt4_v1_v2_v3,
                .checkentry     = socket_mt_v2_check,
                .matchsize      = sizeof(struct xt_socket_mtinfo1),
                .hooks          = (1 << NF_INET_PRE_ROUTING) |
@@ -478,13 +503,37 @@ static struct xt_match socket_mt_reg[] __read_mostly = {
                .name           = "socket",
                .revision       = 2,
                .family         = NFPROTO_IPV6,
-               .match          = socket_mt6_v1_v2,
+               .match          = socket_mt6_v1_v2_v3,
                .checkentry     = socket_mt_v2_check,
                .matchsize      = sizeof(struct xt_socket_mtinfo1),
                .hooks          = (1 << NF_INET_PRE_ROUTING) |
                                  (1 << NF_INET_LOCAL_IN),
                .me             = THIS_MODULE,
        },
+#endif
+       {
+               .name           = "socket",
+               .revision       = 3,
+               .family         = NFPROTO_IPV4,
+               .match          = socket_mt4_v1_v2_v3,
+               .checkentry     = socket_mt_v3_check,
+               .matchsize      = sizeof(struct xt_socket_mtinfo1),
+               .hooks          = (1 << NF_INET_PRE_ROUTING) |
+                                 (1 << NF_INET_LOCAL_IN),
+               .me             = THIS_MODULE,
+       },
+#ifdef XT_SOCKET_HAVE_IPV6
+       {
+               .name           = "socket",
+               .revision       = 3,
+               .family         = NFPROTO_IPV6,
+               .match          = socket_mt6_v1_v2_v3,
+               .checkentry     = socket_mt_v3_check,
+               .matchsize      = sizeof(struct xt_socket_mtinfo1),
+               .hooks          = (1 << NF_INET_PRE_ROUTING) |
+                                 (1 << NF_INET_LOCAL_IN),
+               .me             = THIS_MODULE,
+       },
 #endif
 };