Add the rt linux 4.1.3-rt3 as base
[kvmfornfv.git] / kernel / net / ipv4 / inet_diag.c
1 /*
2  * inet_diag.c  Module for monitoring INET transport protocols sockets.
3  *
4  * Authors:     Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
5  *
6  *      This program is free software; you can redistribute it and/or
7  *      modify it under the terms of the GNU General Public License
8  *      as published by the Free Software Foundation; either version
9  *      2 of the License, or (at your option) any later version.
10  */
11
12 #include <linux/kernel.h>
13 #include <linux/module.h>
14 #include <linux/types.h>
15 #include <linux/fcntl.h>
16 #include <linux/random.h>
17 #include <linux/slab.h>
18 #include <linux/cache.h>
19 #include <linux/init.h>
20 #include <linux/time.h>
21
22 #include <net/icmp.h>
23 #include <net/tcp.h>
24 #include <net/ipv6.h>
25 #include <net/inet_common.h>
26 #include <net/inet_connection_sock.h>
27 #include <net/inet_hashtables.h>
28 #include <net/inet_timewait_sock.h>
29 #include <net/inet6_hashtables.h>
30 #include <net/netlink.h>
31
32 #include <linux/inet.h>
33 #include <linux/stddef.h>
34
35 #include <linux/inet_diag.h>
36 #include <linux/sock_diag.h>
37
/* Handler table indexed by protocol number (see inet_diag_lock_handler()
 * and inet_sk_diag_fill()); a slot is NULL when no handler is present.
 */
38 static const struct inet_diag_handler **inet_diag_table;
39
/* Socket description that the filter bytecode in inet_diag_bc_run() is
 * evaluated against.
 *
 * saddr/daddr point at address words stored in the socket itself (see
 * entry_fill_addrs()); nothing is copied.  dport is stored after ntohs(),
 * i.e. in host byte order; sport is taken from inet_num.  userlocks is
 * only tested against SOCK_BINDPORT_LOCK by the INET_DIAG_BC_AUTO opcode.
 */
40 struct inet_diag_entry {
41         const __be32 *saddr;
42         const __be32 *daddr;
43         u16 sport;
44         u16 dport;
45         u16 family;
46         u16 userlocks;
47 };
48
/* Taken by inet_diag_lock_handler() and released by
 * inet_diag_unlock_handler() around each use of a handler from
 * inet_diag_table.
 */
49 static DEFINE_MUTEX(inet_diag_table_mutex);
50
51 static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
52 {
53         if (!inet_diag_table[proto])
54                 request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK,
55                                NETLINK_SOCK_DIAG, AF_INET, proto);
56
57         mutex_lock(&inet_diag_table_mutex);
58         if (!inet_diag_table[proto])
59                 return ERR_PTR(-ENOENT);
60
61         return inet_diag_table[proto];
62 }
63
/* Release inet_diag_table_mutex taken by inet_diag_lock_handler().
 * @handler is not looked at, so an ERR_PTR() value may be passed as well.
 */
64 static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
65 {
66         mutex_unlock(&inet_diag_table_mutex);
67 }
68
69 static void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
70 {
71         r->idiag_family = sk->sk_family;
72
73         r->id.idiag_sport = htons(sk->sk_num);
74         r->id.idiag_dport = sk->sk_dport;
75         r->id.idiag_if = sk->sk_bound_dev_if;
76         sock_diag_save_cookie(sk, r->id.idiag_cookie);
77
78 #if IS_ENABLED(CONFIG_IPV6)
79         if (sk->sk_family == AF_INET6) {
80                 *(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr;
81                 *(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr;
82         } else
83 #endif
84         {
85         memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
86         memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));
87
88         r->id.idiag_src[0] = sk->sk_rcv_saddr;
89         r->id.idiag_dst[0] = sk->sk_daddr;
90         }
91 }
92
93 static size_t inet_sk_attr_size(void)
94 {
95         return    nla_total_size(sizeof(struct tcp_info))
96                 + nla_total_size(1) /* INET_DIAG_SHUTDOWN */
97                 + nla_total_size(1) /* INET_DIAG_TOS */
98                 + nla_total_size(1) /* INET_DIAG_TCLASS */
99                 + nla_total_size(sizeof(struct inet_diag_meminfo))
100                 + nla_total_size(sizeof(struct inet_diag_msg))
101                 + nla_total_size(SK_MEMINFO_VARS * sizeof(u32))
102                 + nla_total_size(TCP_CA_NAME_MAX)
103                 + nla_total_size(sizeof(struct tcpvegas_info))
104                 + 64;
105 }
106
107 int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
108                       struct sk_buff *skb, const struct inet_diag_req_v2 *req,
109                       struct user_namespace *user_ns,
110                       u32 portid, u32 seq, u16 nlmsg_flags,
111                       const struct nlmsghdr *unlh)
112 {
113         const struct inet_sock *inet = inet_sk(sk);
114         const struct tcp_congestion_ops *ca_ops;
115         const struct inet_diag_handler *handler;
116         int ext = req->idiag_ext;
117         struct inet_diag_msg *r;
118         struct nlmsghdr  *nlh;
119         struct nlattr *attr;
120         void *info = NULL;
121
122         handler = inet_diag_table[req->sdiag_protocol];
123         BUG_ON(!handler);
124
125         nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
126                         nlmsg_flags);
127         if (!nlh)
128                 return -EMSGSIZE;
129
130         r = nlmsg_data(nlh);
131         BUG_ON(!sk_fullsock(sk));
132
133         inet_diag_msg_common_fill(r, sk);
134         r->idiag_state = sk->sk_state;
135         r->idiag_timer = 0;
136         r->idiag_retrans = 0;
137
138         if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
139                 goto errout;
140
141         /* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
142          * hence this needs to be included regardless of socket family.
143          */
144         if (ext & (1 << (INET_DIAG_TOS - 1)))
145                 if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0)
146                         goto errout;
147
148 #if IS_ENABLED(CONFIG_IPV6)
149         if (r->idiag_family == AF_INET6) {
150                 if (ext & (1 << (INET_DIAG_TCLASS - 1)))
151                         if (nla_put_u8(skb, INET_DIAG_TCLASS,
152                                        inet6_sk(sk)->tclass) < 0)
153                                 goto errout;
154         }
155 #endif
156
157         r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
158         r->idiag_inode = sock_i_ino(sk);
159
160         if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
161                 struct inet_diag_meminfo minfo = {
162                         .idiag_rmem = sk_rmem_alloc_get(sk),
163                         .idiag_wmem = sk->sk_wmem_queued,
164                         .idiag_fmem = sk->sk_forward_alloc,
165                         .idiag_tmem = sk_wmem_alloc_get(sk),
166                 };
167
168                 if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
169                         goto errout;
170         }
171
172         if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
173                 if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
174                         goto errout;
175
176         if (!icsk) {
177                 handler->idiag_get_info(sk, r, NULL);
178                 goto out;
179         }
180
181 #define EXPIRES_IN_MS(tmo)  DIV_ROUND_UP((tmo - jiffies) * 1000, HZ)
182
183         if (icsk->icsk_pending == ICSK_TIME_RETRANS ||
184             icsk->icsk_pending == ICSK_TIME_EARLY_RETRANS ||
185             icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) {
186                 r->idiag_timer = 1;
187                 r->idiag_retrans = icsk->icsk_retransmits;
188                 r->idiag_expires = EXPIRES_IN_MS(icsk->icsk_timeout);
189         } else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
190                 r->idiag_timer = 4;
191                 r->idiag_retrans = icsk->icsk_probes_out;
192                 r->idiag_expires = EXPIRES_IN_MS(icsk->icsk_timeout);
193         } else if (timer_pending(&sk->sk_timer)) {
194                 r->idiag_timer = 2;
195                 r->idiag_retrans = icsk->icsk_probes_out;
196                 r->idiag_expires = EXPIRES_IN_MS(sk->sk_timer.expires);
197         } else {
198                 r->idiag_timer = 0;
199                 r->idiag_expires = 0;
200         }
201 #undef EXPIRES_IN_MS
202
203         if (ext & (1 << (INET_DIAG_INFO - 1))) {
204                 attr = nla_reserve(skb, INET_DIAG_INFO,
205                                    sizeof(struct tcp_info));
206                 if (!attr)
207                         goto errout;
208
209                 info = nla_data(attr);
210         }
211
212         if (ext & (1 << (INET_DIAG_CONG - 1))) {
213                 int err = 0;
214
215                 rcu_read_lock();
216                 ca_ops = READ_ONCE(icsk->icsk_ca_ops);
217                 if (ca_ops)
218                         err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name);
219                 rcu_read_unlock();
220                 if (err < 0)
221                         goto errout;
222         }
223
224         handler->idiag_get_info(sk, r, info);
225
226         if (sk->sk_state < TCP_TIME_WAIT) {
227                 union tcp_cc_info info;
228                 size_t sz = 0;
229                 int attr;
230
231                 rcu_read_lock();
232                 ca_ops = READ_ONCE(icsk->icsk_ca_ops);
233                 if (ca_ops && ca_ops->get_info)
234                         sz = ca_ops->get_info(sk, ext, &attr, &info);
235                 rcu_read_unlock();
236                 if (sz && nla_put(skb, attr, sz, &info) < 0)
237                         goto errout;
238         }
239
240 out:
241         nlmsg_end(skb, nlh);
242         return 0;
243
244 errout:
245         nlmsg_cancel(skb, nlh);
246         return -EMSGSIZE;
247 }
248 EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
249
250 static int inet_csk_diag_fill(struct sock *sk,
251                               struct sk_buff *skb,
252                               const struct inet_diag_req_v2 *req,
253                               struct user_namespace *user_ns,
254                               u32 portid, u32 seq, u16 nlmsg_flags,
255                               const struct nlmsghdr *unlh)
256 {
257         return inet_sk_diag_fill(sk, inet_csk(sk), skb, req,
258                                  user_ns, portid, seq, nlmsg_flags, unlh);
259 }
260
261 static int inet_twsk_diag_fill(struct sock *sk,
262                                struct sk_buff *skb,
263                                u32 portid, u32 seq, u16 nlmsg_flags,
264                                const struct nlmsghdr *unlh)
265 {
266         struct inet_timewait_sock *tw = inet_twsk(sk);
267         struct inet_diag_msg *r;
268         struct nlmsghdr *nlh;
269         long tmo;
270
271         nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
272                         nlmsg_flags);
273         if (!nlh)
274                 return -EMSGSIZE;
275
276         r = nlmsg_data(nlh);
277         BUG_ON(tw->tw_state != TCP_TIME_WAIT);
278
279         tmo = tw->tw_timer.expires - jiffies;
280         if (tmo < 0)
281                 tmo = 0;
282
283         inet_diag_msg_common_fill(r, sk);
284         r->idiag_retrans      = 0;
285
286         r->idiag_state        = tw->tw_substate;
287         r->idiag_timer        = 3;
288         r->idiag_expires      = jiffies_to_msecs(tmo);
289         r->idiag_rqueue       = 0;
290         r->idiag_wqueue       = 0;
291         r->idiag_uid          = 0;
292         r->idiag_inode        = 0;
293
294         nlmsg_end(skb, nlh);
295         return 0;
296 }
297
298 static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
299                               u32 portid, u32 seq, u16 nlmsg_flags,
300                               const struct nlmsghdr *unlh)
301 {
302         struct inet_diag_msg *r;
303         struct nlmsghdr *nlh;
304         long tmo;
305
306         nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
307                         nlmsg_flags);
308         if (!nlh)
309                 return -EMSGSIZE;
310
311         r = nlmsg_data(nlh);
312         inet_diag_msg_common_fill(r, sk);
313         r->idiag_state = TCP_SYN_RECV;
314         r->idiag_timer = 1;
315         r->idiag_retrans = inet_reqsk(sk)->num_retrans;
316
317         BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
318                      offsetof(struct sock, sk_cookie));
319
320         tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies;
321         r->idiag_expires = (tmo >= 0) ? jiffies_to_msecs(tmo) : 0;
322         r->idiag_rqueue = 0;
323         r->idiag_wqueue = 0;
324         r->idiag_uid    = 0;
325         r->idiag_inode  = 0;
326
327         nlmsg_end(skb, nlh);
328         return 0;
329 }
330
331 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
332                         const struct inet_diag_req_v2 *r,
333                         struct user_namespace *user_ns,
334                         u32 portid, u32 seq, u16 nlmsg_flags,
335                         const struct nlmsghdr *unlh)
336 {
337         if (sk->sk_state == TCP_TIME_WAIT)
338                 return inet_twsk_diag_fill(sk, skb, portid, seq,
339                                            nlmsg_flags, unlh);
340
341         if (sk->sk_state == TCP_NEW_SYN_RECV)
342                 return inet_req_diag_fill(sk, skb, portid, seq,
343                                           nlmsg_flags, unlh);
344
345         return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq,
346                                   nlmsg_flags, unlh);
347 }
348
349 int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
350                             struct sk_buff *in_skb,
351                             const struct nlmsghdr *nlh,
352                             const struct inet_diag_req_v2 *req)
353 {
354         struct net *net = sock_net(in_skb->sk);
355         struct sk_buff *rep;
356         struct sock *sk;
357         int err;
358
359         err = -EINVAL;
360         if (req->sdiag_family == AF_INET)
361                 sk = inet_lookup(net, hashinfo, req->id.idiag_dst[0],
362                                  req->id.idiag_dport, req->id.idiag_src[0],
363                                  req->id.idiag_sport, req->id.idiag_if);
364 #if IS_ENABLED(CONFIG_IPV6)
365         else if (req->sdiag_family == AF_INET6)
366                 sk = inet6_lookup(net, hashinfo,
367                                   (struct in6_addr *)req->id.idiag_dst,
368                                   req->id.idiag_dport,
369                                   (struct in6_addr *)req->id.idiag_src,
370                                   req->id.idiag_sport,
371                                   req->id.idiag_if);
372 #endif
373         else
374                 goto out_nosk;
375
376         err = -ENOENT;
377         if (!sk)
378                 goto out_nosk;
379
380         err = sock_diag_check_cookie(sk, req->id.idiag_cookie);
381         if (err)
382                 goto out;
383
384         rep = nlmsg_new(inet_sk_attr_size(), GFP_KERNEL);
385         if (!rep) {
386                 err = -ENOMEM;
387                 goto out;
388         }
389
390         err = sk_diag_fill(sk, rep, req,
391                            sk_user_ns(NETLINK_CB(in_skb).sk),
392                            NETLINK_CB(in_skb).portid,
393                            nlh->nlmsg_seq, 0, nlh);
394         if (err < 0) {
395                 WARN_ON(err == -EMSGSIZE);
396                 nlmsg_free(rep);
397                 goto out;
398         }
399         err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid,
400                               MSG_DONTWAIT);
401         if (err > 0)
402                 err = 0;
403
404 out:
405         if (sk)
406                 sock_gen_put(sk);
407
408 out_nosk:
409         return err;
410 }
411 EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);
412
413 static int inet_diag_get_exact(struct sk_buff *in_skb,
414                                const struct nlmsghdr *nlh,
415                                const struct inet_diag_req_v2 *req)
416 {
417         const struct inet_diag_handler *handler;
418         int err;
419
420         handler = inet_diag_lock_handler(req->sdiag_protocol);
421         if (IS_ERR(handler))
422                 err = PTR_ERR(handler);
423         else
424                 err = handler->dump_one(in_skb, nlh, req);
425         inet_diag_unlock_handler(handler);
426
427         return err;
428 }
429
430 static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
431 {
432         int words = bits >> 5;
433
434         bits &= 0x1f;
435
436         if (words) {
437                 if (memcmp(a1, a2, words << 2))
438                         return 0;
439         }
440         if (bits) {
441                 __be32 w1, w2;
442                 __be32 mask;
443
444                 w1 = a1[words];
445                 w2 = a2[words];
446
447                 mask = htonl((0xffffffff) << (32 - bits));
448
449                 if ((w1 ^ w2) & mask)
450                         return 0;
451         }
452
453         return 1;
454 }
455
456 static int inet_diag_bc_run(const struct nlattr *_bc,
457                             const struct inet_diag_entry *entry)
458 {
459         const void *bc = nla_data(_bc);
460         int len = nla_len(_bc);
461
462         while (len > 0) {
463                 int yes = 1;
464                 const struct inet_diag_bc_op *op = bc;
465
466                 switch (op->code) {
467                 case INET_DIAG_BC_NOP:
468                         break;
469                 case INET_DIAG_BC_JMP:
470                         yes = 0;
471                         break;
472                 case INET_DIAG_BC_S_GE:
473                         yes = entry->sport >= op[1].no;
474                         break;
475                 case INET_DIAG_BC_S_LE:
476                         yes = entry->sport <= op[1].no;
477                         break;
478                 case INET_DIAG_BC_D_GE:
479                         yes = entry->dport >= op[1].no;
480                         break;
481                 case INET_DIAG_BC_D_LE:
482                         yes = entry->dport <= op[1].no;
483                         break;
484                 case INET_DIAG_BC_AUTO:
485                         yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
486                         break;
487                 case INET_DIAG_BC_S_COND:
488                 case INET_DIAG_BC_D_COND: {
489                         const struct inet_diag_hostcond *cond;
490                         const __be32 *addr;
491
492                         cond = (const struct inet_diag_hostcond *)(op + 1);
493                         if (cond->port != -1 &&
494                             cond->port != (op->code == INET_DIAG_BC_S_COND ?
495                                              entry->sport : entry->dport)) {
496                                 yes = 0;
497                                 break;
498                         }
499
500                         if (op->code == INET_DIAG_BC_S_COND)
501                                 addr = entry->saddr;
502                         else
503                                 addr = entry->daddr;
504
505                         if (cond->family != AF_UNSPEC &&
506                             cond->family != entry->family) {
507                                 if (entry->family == AF_INET6 &&
508                                     cond->family == AF_INET) {
509                                         if (addr[0] == 0 && addr[1] == 0 &&
510                                             addr[2] == htonl(0xffff) &&
511                                             bitstring_match(addr + 3,
512                                                             cond->addr,
513                                                             cond->prefix_len))
514                                                 break;
515                                 }
516                                 yes = 0;
517                                 break;
518                         }
519
520                         if (cond->prefix_len == 0)
521                                 break;
522                         if (bitstring_match(addr, cond->addr,
523                                             cond->prefix_len))
524                                 break;
525                         yes = 0;
526                         break;
527                 }
528                 }
529
530                 if (yes) {
531                         len -= op->yes;
532                         bc += op->yes;
533                 } else {
534                         len -= op->no;
535                         bc += op->no;
536                 }
537         }
538         return len == 0;
539 }
540
541 /* This helper is available for all sockets (ESTABLISH, TIMEWAIT, SYN_RECV)
542  */
543 static void entry_fill_addrs(struct inet_diag_entry *entry,
544                              const struct sock *sk)
545 {
546 #if IS_ENABLED(CONFIG_IPV6)
547         if (sk->sk_family == AF_INET6) {
548                 entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32;
549                 entry->daddr = sk->sk_v6_daddr.s6_addr32;
550         } else
551 #endif
552         {
553                 entry->saddr = &sk->sk_rcv_saddr;
554                 entry->daddr = &sk->sk_daddr;
555         }
556 }
557
558 int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
559 {
560         struct inet_sock *inet = inet_sk(sk);
561         struct inet_diag_entry entry;
562
563         if (!bc)
564                 return 1;
565
566         entry.family = sk->sk_family;
567         entry_fill_addrs(&entry, sk);
568         entry.sport = inet->inet_num;
569         entry.dport = ntohs(inet->inet_dport);
570         entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
571
572         return inet_diag_bc_run(bc, &entry);
573 }
574 EXPORT_SYMBOL_GPL(inet_diag_bc_sk);
575
576 static int valid_cc(const void *bc, int len, int cc)
577 {
578         while (len >= 0) {
579                 const struct inet_diag_bc_op *op = bc;
580
581                 if (cc > len)
582                         return 0;
583                 if (cc == len)
584                         return 1;
585                 if (op->yes < 4 || op->yes & 3)
586                         return 0;
587                 len -= op->yes;
588                 bc  += op->yes;
589         }
590         return 0;
591 }
592
593 /* Validate an inet_diag_hostcond. */
594 static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
595                            int *min_len)
596 {
597         struct inet_diag_hostcond *cond;
598         int addr_len;
599
600         /* Check hostcond space. */
601         *min_len += sizeof(struct inet_diag_hostcond);
602         if (len < *min_len)
603                 return false;
604         cond = (struct inet_diag_hostcond *)(op + 1);
605
606         /* Check address family and address length. */
607         switch (cond->family) {
608         case AF_UNSPEC:
609                 addr_len = 0;
610                 break;
611         case AF_INET:
612                 addr_len = sizeof(struct in_addr);
613                 break;
614         case AF_INET6:
615                 addr_len = sizeof(struct in6_addr);
616                 break;
617         default:
618                 return false;
619         }
620         *min_len += addr_len;
621         if (len < *min_len)
622                 return false;
623
624         /* Check prefix length (in bits) vs address length (in bytes). */
625         if (cond->prefix_len > 8 * addr_len)
626                 return false;
627
628         return true;
629 }
630
631 /* Validate a port comparison operator. */
632 static bool valid_port_comparison(const struct inet_diag_bc_op *op,
633                                   int len, int *min_len)
634 {
635         /* Port comparisons put the port in a follow-on inet_diag_bc_op. */
636         *min_len += sizeof(struct inet_diag_bc_op);
637         if (len < *min_len)
638                 return false;
639         return true;
640 }
641
642 static int inet_diag_bc_audit(const void *bytecode, int bytecode_len)
643 {
644         const void *bc = bytecode;
645         int  len = bytecode_len;
646
647         while (len > 0) {
648                 int min_len = sizeof(struct inet_diag_bc_op);
649                 const struct inet_diag_bc_op *op = bc;
650
651                 switch (op->code) {
652                 case INET_DIAG_BC_S_COND:
653                 case INET_DIAG_BC_D_COND:
654                         if (!valid_hostcond(bc, len, &min_len))
655                                 return -EINVAL;
656                         break;
657                 case INET_DIAG_BC_S_GE:
658                 case INET_DIAG_BC_S_LE:
659                 case INET_DIAG_BC_D_GE:
660                 case INET_DIAG_BC_D_LE:
661                         if (!valid_port_comparison(bc, len, &min_len))
662                                 return -EINVAL;
663                         break;
664                 case INET_DIAG_BC_AUTO:
665                 case INET_DIAG_BC_JMP:
666                 case INET_DIAG_BC_NOP:
667                         break;
668                 default:
669                         return -EINVAL;
670                 }
671
672                 if (op->code != INET_DIAG_BC_NOP) {
673                         if (op->no < min_len || op->no > len + 4 || op->no & 3)
674                                 return -EINVAL;
675                         if (op->no < len &&
676                             !valid_cc(bytecode, bytecode_len, len - op->no))
677                                 return -EINVAL;
678                 }
679
680                 if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
681                         return -EINVAL;
682                 bc  += op->yes;
683                 len -= op->yes;
684         }
685         return len == 0 ? 0 : -EINVAL;
686 }
687
688 static int inet_csk_diag_dump(struct sock *sk,
689                               struct sk_buff *skb,
690                               struct netlink_callback *cb,
691                               const struct inet_diag_req_v2 *r,
692                               const struct nlattr *bc)
693 {
694         if (!inet_diag_bc_sk(bc, sk))
695                 return 0;
696
697         return inet_csk_diag_fill(sk, skb, r,
698                                   sk_user_ns(NETLINK_CB(cb->skb).sk),
699                                   NETLINK_CB(cb->skb).portid,
700                                   cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
701 }
702
703 static void twsk_build_assert(void)
704 {
705         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) !=
706                      offsetof(struct sock, sk_family));
707
708         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) !=
709                      offsetof(struct inet_sock, inet_num));
710
711         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) !=
712                      offsetof(struct inet_sock, inet_dport));
713
714         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) !=
715                      offsetof(struct inet_sock, inet_rcv_saddr));
716
717         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) !=
718                      offsetof(struct inet_sock, inet_daddr));
719
720 #if IS_ENABLED(CONFIG_IPV6)
721         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) !=
722                      offsetof(struct sock, sk_v6_rcv_saddr));
723
724         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) !=
725                      offsetof(struct sock, sk_v6_daddr));
726 #endif
727 }
728
729 static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
730                                struct netlink_callback *cb,
731                                const struct inet_diag_req_v2 *r,
732                                const struct nlattr *bc)
733 {
734         struct inet_connection_sock *icsk = inet_csk(sk);
735         struct inet_sock *inet = inet_sk(sk);
736         struct inet_diag_entry entry;
737         int j, s_j, reqnum, s_reqnum;
738         struct listen_sock *lopt;
739         int err = 0;
740
741         s_j = cb->args[3];
742         s_reqnum = cb->args[4];
743
744         if (s_j > 0)
745                 s_j--;
746
747         entry.family = sk->sk_family;
748
749         spin_lock_bh(&icsk->icsk_accept_queue.syn_wait_lock);
750
751         lopt = icsk->icsk_accept_queue.listen_opt;
752         if (!lopt || !listen_sock_qlen(lopt))
753                 goto out;
754
755         if (bc) {
756                 entry.sport = inet->inet_num;
757                 entry.userlocks = sk->sk_userlocks;
758         }
759
760         for (j = s_j; j < lopt->nr_table_entries; j++) {
761                 struct request_sock *req, *head = lopt->syn_table[j];
762
763                 reqnum = 0;
764                 for (req = head; req; reqnum++, req = req->dl_next) {
765                         struct inet_request_sock *ireq = inet_rsk(req);
766
767                         if (reqnum < s_reqnum)
768                                 continue;
769                         if (r->id.idiag_dport != ireq->ir_rmt_port &&
770                             r->id.idiag_dport)
771                                 continue;
772
773                         if (bc) {
774                                 /* Note: entry.sport and entry.userlocks are already set */
775                                 entry_fill_addrs(&entry, req_to_sk(req));
776                                 entry.dport = ntohs(ireq->ir_rmt_port);
777
778                                 if (!inet_diag_bc_run(bc, &entry))
779                                         continue;
780                         }
781
782                         err = inet_req_diag_fill(req_to_sk(req), skb,
783                                                  NETLINK_CB(cb->skb).portid,
784                                                  cb->nlh->nlmsg_seq,
785                                                  NLM_F_MULTI, cb->nlh);
786                         if (err < 0) {
787                                 cb->args[3] = j + 1;
788                                 cb->args[4] = reqnum;
789                                 goto out;
790                         }
791                 }
792
793                 s_reqnum = 0;
794         }
795
796 out:
797         spin_unlock_bh(&icsk->icsk_accept_queue.syn_wait_lock);
798
799         return err;
800 }
801
/*
 * inet_diag_dump_icsk - resumable dump of the sockets in @hashinfo matching @r.
 * @hashinfo: hash tables to walk: listening_hash[] first, then ehash[]
 * @skb:      buffer handed to the fill helpers; skb->sk selects the netns
 * @cb:       netlink dump callback; cb->args[] carries the resume cursor
 * @r:        request holding the state mask, family and port filters
 * @bc:       filter attribute passed through to the fill helpers (may be NULL)
 *
 * Cursor layout, as read and written below:
 *   cb->args[0]  0 while the listening hash still has to be walked, 1 after
 *   cb->args[1]  bucket index to resume from
 *   cb->args[2]  number of sockets to skip inside that bucket
 *   cb->args[3], cb->args[4]
 *                cleared whenever the walk moves on to the next listener or
 *                bucket; a positive args[3] skips the listener's own record
 *                and goes straight to inet_diag_dump_reqs()
 *                (presumably that helper's private cursor - TODO confirm)
 *
 * When a fill helper returns a negative value the bucket lock is released
 * and the current position is stored in cb->args[1] and cb->args[2].
 */
802 void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
803                          struct netlink_callback *cb,
804                          const struct inet_diag_req_v2 *r, struct nlattr *bc)
805 {
806         struct net *net = sock_net(skb->sk);
807         int i, num, s_i, s_num;
808
            /* Pick up the position stored at "done:" by an earlier call. */
809         s_i = cb->args[1];
810         s_num = num = cb->args[2];
811
            /* Phase 0: the listening hash. */
812         if (cb->args[0] == 0) {
                    /* Only walked when LISTEN or SYN_RECV was requested. */
813                 if (!(r->idiag_states & (TCPF_LISTEN | TCPF_SYN_RECV)))
814                         goto skip_listen_ht;
815
816                 for (i = s_i; i < INET_LHTABLE_SIZE; i++) {
817                         struct inet_listen_hashbucket *ilb;
818                         struct hlist_nulls_node *node;
819                         struct sock *sk;
820
821                         num = 0;
822                         ilb = &hashinfo->listening_hash[i];
823                         spin_lock_bh(&ilb->lock);
824                         sk_nulls_for_each(sk, node, &ilb->head) {
825                                 struct inet_sock *inet = inet_sk(sk);
826
                                    /* Foreign-netns sockets are not counted in num. */
827                                 if (!net_eq(sock_net(sk), net))
828                                         continue;
829
                                    /* Skip the first s_num entries of the resumed bucket. */
830                                 if (num < s_num) {
831                                         num++;
832                                         continue;
833                                 }
834
835                                 if (r->sdiag_family != AF_UNSPEC &&
836                                     sk->sk_family != r->sdiag_family)
837                                         goto next_listen;
838
                                    /* A zero idiag_sport in the request matches any port. */
839                                 if (r->id.idiag_sport != inet->inet_sport &&
840                                     r->id.idiag_sport)
841                                         goto next_listen;
842
                                    /*
                                     * The listener itself is not reported when LISTEN was
                                     * not requested, when a destination port filter is set,
                                     * or when cb->args[3] is already positive.
                                     */
843                                 if (!(r->idiag_states & TCPF_LISTEN) ||
844                                     r->id.idiag_dport ||
845                                     cb->args[3] > 0)
846                                         goto syn_recv;
847
848                                 if (inet_csk_diag_dump(sk, skb, cb, r, bc) < 0) {
849                                         spin_unlock_bh(&ilb->lock);
850                                         goto done;
851                                 }
852
853 syn_recv:
854                                 if (!(r->idiag_states & TCPF_SYN_RECV))
855                                         goto next_listen;
856
857                                 if (inet_diag_dump_reqs(skb, sk, cb, r, bc) < 0) {
858                                         spin_unlock_bh(&ilb->lock);
859                                         goto done;
860                                 }
861
862 next_listen:
863                                 cb->args[3] = 0;
864                                 cb->args[4] = 0;
865                                 ++num;
866                         }
867                         spin_unlock_bh(&ilb->lock);
868
                            /* Later buckets are always walked from their first entry. */
869                         s_num = 0;
870                         cb->args[3] = 0;
871                         cb->args[4] = 0;
872                 }
873 skip_listen_ht:
                    /* Listening hash finished: switch to phase 1, cursor at 0. */
874                 cb->args[0] = 1;
875                 s_i = num = s_num = 0;
876         }
877
            /* Phase 1 is skipped when only LISTEN/SYN_RECV states were requested. */
878         if (!(r->idiag_states & ~(TCPF_LISTEN | TCPF_SYN_RECV)))
879                 goto out;
880
881         for (i = s_i; i <= hashinfo->ehash_mask; i++) {
882                 struct inet_ehash_bucket *head = &hashinfo->ehash[i];
883                 spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
884                 struct hlist_nulls_node *node;
885                 struct sock *sk;
886
887                 num = 0;
888
                    /* Emptiness is tested before the bucket lock is taken. */
889                 if (hlist_nulls_empty(&head->chain))
890                         continue;
891
                    /* The skip count only applies to the bucket being resumed. */
892                 if (i > s_i)
893                         s_num = 0;
894
895                 spin_lock_bh(lock);
896                 sk_nulls_for_each(sk, node, &head->chain) {
897                         int state, res;
898
899                         if (!net_eq(sock_net(sk), net))
900                                 continue;
901                         if (num < s_num)
902                                 goto next_normal;
                            /* TIME_WAIT entries are matched on their tw_substate. */
903                         state = (sk->sk_state == TCP_TIME_WAIT) ?
904                                 inet_twsk(sk)->tw_substate : sk->sk_state;
905                         if (!(r->idiag_states & (1 << state)))
906                                 goto next_normal;
907                         if (r->sdiag_family != AF_UNSPEC &&
908                             sk->sk_family != r->sdiag_family)
909                                 goto next_normal;
                            /* Zero port fields in the request match any port. */
910                         if (r->id.idiag_sport != htons(sk->sk_num) &&
911                             r->id.idiag_sport)
912                                 goto next_normal;
913                         if (r->id.idiag_dport != sk->sk_dport &&
914                             r->id.idiag_dport)
915                                 goto next_normal;
916                         twsk_build_assert();
917
918                         if (!inet_diag_bc_sk(bc, sk))
919                                 goto next_normal;
920
921                         res = sk_diag_fill(sk, skb, r,
922                                            sk_user_ns(NETLINK_CB(cb->skb).sk),
923                                            NETLINK_CB(cb->skb).portid,
924                                            cb->nlh->nlmsg_seq, NLM_F_MULTI,
925                                            cb->nlh);
926                         if (res < 0) {
927                                 spin_unlock_bh(lock);
928                                 goto done;
929                         }
930 next_normal:
931                         ++num;
932                 }
933
934                 spin_unlock_bh(lock);
935         }
936
937 done:
            /* Store the cursor for the next invocation. */
938         cb->args[1] = i;
939         cb->args[2] = num;
940 out:
941         ;
942 }
943 EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
944
945 static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
946                             const struct inet_diag_req_v2 *r,
947                             struct nlattr *bc)
948 {
949         const struct inet_diag_handler *handler;
950         int err = 0;
951
952         handler = inet_diag_lock_handler(r->sdiag_protocol);
953         if (!IS_ERR(handler))
954                 handler->dump(skb, cb, r, bc);
955         else
956                 err = PTR_ERR(handler);
957         inet_diag_unlock_handler(handler);
958
959         return err ? : skb->len;
960 }
961
962 static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
963 {
964         int hdrlen = sizeof(struct inet_diag_req_v2);
965         struct nlattr *bc = NULL;
966
967         if (nlmsg_attrlen(cb->nlh, hdrlen))
968                 bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
969
970         return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh), bc);
971 }
972
/*
 * Map a legacy netlink message type onto the IP protocol it addresses.
 *
 * Returns IPPROTO_TCP for TCPDIAG_GETSOCK, IPPROTO_DCCP for
 * DCCPDIAG_GETSOCK and 0 for every other value.
 */
static int inet_diag_type2proto(int type)
{
	if (type == TCPDIAG_GETSOCK)
		return IPPROTO_TCP;

	if (type == DCCPDIAG_GETSOCK)
		return IPPROTO_DCCP;

	return 0;
}
984
985 static int inet_diag_dump_compat(struct sk_buff *skb,
986                                  struct netlink_callback *cb)
987 {
988         struct inet_diag_req *rc = nlmsg_data(cb->nlh);
989         int hdrlen = sizeof(struct inet_diag_req);
990         struct inet_diag_req_v2 req;
991         struct nlattr *bc = NULL;
992
993         req.sdiag_family = AF_UNSPEC; /* compatibility */
994         req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
995         req.idiag_ext = rc->idiag_ext;
996         req.idiag_states = rc->idiag_states;
997         req.id = rc->id;
998
999         if (nlmsg_attrlen(cb->nlh, hdrlen))
1000                 bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
1001
1002         return __inet_diag_dump(skb, cb, &req, bc);
1003 }
1004
1005 static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
1006                                       const struct nlmsghdr *nlh)
1007 {
1008         struct inet_diag_req *rc = nlmsg_data(nlh);
1009         struct inet_diag_req_v2 req;
1010
1011         req.sdiag_family = rc->idiag_family;
1012         req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
1013         req.idiag_ext = rc->idiag_ext;
1014         req.idiag_states = rc->idiag_states;
1015         req.id = rc->id;
1016
1017         return inet_diag_get_exact(in_skb, nlh, &req);
1018 }
1019
1020 static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
1021 {
1022         int hdrlen = sizeof(struct inet_diag_req);
1023         struct net *net = sock_net(skb->sk);
1024
1025         if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
1026             nlmsg_len(nlh) < hdrlen)
1027                 return -EINVAL;
1028
1029         if (nlh->nlmsg_flags & NLM_F_DUMP) {
1030                 if (nlmsg_attrlen(nlh, hdrlen)) {
1031                         struct nlattr *attr;
1032
1033                         attr = nlmsg_find_attr(nlh, hdrlen,
1034                                                INET_DIAG_REQ_BYTECODE);
1035                         if (!attr ||
1036                             nla_len(attr) < sizeof(struct inet_diag_bc_op) ||
1037                             inet_diag_bc_audit(nla_data(attr), nla_len(attr)))
1038                                 return -EINVAL;
1039                 }
1040                 {
1041                         struct netlink_dump_control c = {
1042                                 .dump = inet_diag_dump_compat,
1043                         };
1044                         return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
1045                 }
1046         }
1047
1048         return inet_diag_get_exact_compat(skb, nlh);
1049 }
1050
1051 static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
1052 {
1053         int hdrlen = sizeof(struct inet_diag_req_v2);
1054         struct net *net = sock_net(skb->sk);
1055
1056         if (nlmsg_len(h) < hdrlen)
1057                 return -EINVAL;
1058
1059         if (h->nlmsg_flags & NLM_F_DUMP) {
1060                 if (nlmsg_attrlen(h, hdrlen)) {
1061                         struct nlattr *attr;
1062
1063                         attr = nlmsg_find_attr(h, hdrlen,
1064                                                INET_DIAG_REQ_BYTECODE);
1065                         if (!attr ||
1066                             nla_len(attr) < sizeof(struct inet_diag_bc_op) ||
1067                             inet_diag_bc_audit(nla_data(attr), nla_len(attr)))
1068                                 return -EINVAL;
1069                 }
1070                 {
1071                         struct netlink_dump_control c = {
1072                                 .dump = inet_diag_dump,
1073                         };
1074                         return netlink_dump_start(net->diag_nlsk, skb, h, &c);
1075                 }
1076         }
1077
1078         return inet_diag_get_exact(skb, h, nlmsg_data(h));
1079 }
1080
/* sock_diag handler for AF_INET; requests go to inet_diag_handler_dump(). */
1081 static const struct sock_diag_handler inet_diag_handler = {
1082         .family = AF_INET,
1083         .dump = inet_diag_handler_dump,
1084 };
1085
/* sock_diag handler for AF_INET6; shares inet_diag_handler_dump() with AF_INET. */
1086 static const struct sock_diag_handler inet6_diag_handler = {
1087         .family = AF_INET6,
1088         .dump = inet_diag_handler_dump,
1089 };
1090
1091 int inet_diag_register(const struct inet_diag_handler *h)
1092 {
1093         const __u16 type = h->idiag_type;
1094         int err = -EINVAL;
1095
1096         if (type >= IPPROTO_MAX)
1097                 goto out;
1098
1099         mutex_lock(&inet_diag_table_mutex);
1100         err = -EEXIST;
1101         if (!inet_diag_table[type]) {
1102                 inet_diag_table[type] = h;
1103                 err = 0;
1104         }
1105         mutex_unlock(&inet_diag_table_mutex);
1106 out:
1107         return err;
1108 }
1109 EXPORT_SYMBOL_GPL(inet_diag_register);
1110
1111 void inet_diag_unregister(const struct inet_diag_handler *h)
1112 {
1113         const __u16 type = h->idiag_type;
1114
1115         if (type >= IPPROTO_MAX)
1116                 return;
1117
1118         mutex_lock(&inet_diag_table_mutex);
1119         inet_diag_table[type] = NULL;
1120         mutex_unlock(&inet_diag_table_mutex);
1121 }
1122 EXPORT_SYMBOL_GPL(inet_diag_unregister);
1123
1124 static int __init inet_diag_init(void)
1125 {
1126         const int inet_diag_table_size = (IPPROTO_MAX *
1127                                           sizeof(struct inet_diag_handler *));
1128         int err = -ENOMEM;
1129
1130         inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
1131         if (!inet_diag_table)
1132                 goto out;
1133
1134         err = sock_diag_register(&inet_diag_handler);
1135         if (err)
1136                 goto out_free_nl;
1137
1138         err = sock_diag_register(&inet6_diag_handler);
1139         if (err)
1140                 goto out_free_inet;
1141
1142         sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
1143 out:
1144         return err;
1145
1146 out_free_inet:
1147         sock_diag_unregister(&inet_diag_handler);
1148 out_free_nl:
1149         kfree(inet_diag_table);
1150         goto out;
1151 }
1152
/*
 * Module exit: unregister the AF_INET6 and AF_INET sock_diag handlers and
 * the legacy message receiver, then free the handler table allocated in
 * inet_diag_init().
 */
1153 static void __exit inet_diag_exit(void)
1154 {
1155         sock_diag_unregister(&inet6_diag_handler);
1156         sock_diag_unregister(&inet_diag_handler);
1157         sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
             /* Freed last, after all three unregister calls above. */
1158         kfree(inet_diag_table);
1159 }
1160
1161 module_init(inet_diag_init);
1162 module_exit(inet_diag_exit);
1163 MODULE_LICENSE("GPL");
1164 MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
1165 MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);