Kernel bump from 4.1.3-rt to 4.1.7-rt.
[kvmfornfv.git] / kernel / arch / x86 / mm / mpx.c
1 /*
2  * mpx.c - Memory Protection eXtensions
3  *
4  * Copyright (c) 2014, Intel Corporation.
5  * Qiaowei Ren <qiaowei.ren@intel.com>
6  * Dave Hansen <dave.hansen@intel.com>
7  */
8 #include <linux/kernel.h>
9 #include <linux/slab.h>
10 #include <linux/syscalls.h>
11 #include <linux/sched/sysctl.h>
12
13 #include <asm/i387.h>
14 #include <asm/insn.h>
15 #include <asm/mman.h>
16 #include <asm/mmu_context.h>
17 #include <asm/mpx.h>
18 #include <asm/processor.h>
19 #include <asm/fpu-internal.h>
20
21 /*
22  * This is really a simplified "vm_mmap". it only handles MPX
23  * bounds tables (the bounds directory is user-allocated).
24  */
25 static unsigned long mpx_mmap(unsigned long len)
26 {
27         unsigned long ret;
28         unsigned long addr, pgoff;
29         struct mm_struct *mm = current->mm;
30         vm_flags_t vm_flags;
31         struct vm_area_struct *vma;
32
33         /* Only bounds table and bounds directory can be allocated here */
34         if (len != MPX_BD_SIZE_BYTES && len != MPX_BT_SIZE_BYTES)
35                 return -EINVAL;
36
37         down_write(&mm->mmap_sem);
38
39         /* Too many mappings? */
40         if (mm->map_count > sysctl_max_map_count) {
41                 ret = -ENOMEM;
42                 goto out;
43         }
44
45         /* Obtain the address to map to. we verify (or select) it and ensure
46          * that it represents a valid section of the address space.
47          */
48         addr = get_unmapped_area(NULL, 0, len, 0, MAP_ANONYMOUS | MAP_PRIVATE);
49         if (addr & ~PAGE_MASK) {
50                 ret = addr;
51                 goto out;
52         }
53
54         vm_flags = VM_READ | VM_WRITE | VM_MPX |
55                         mm->def_flags | VM_MAYREAD | VM_MAYWRITE | VM_MAYEXEC;
56
57         /* Set pgoff according to addr for anon_vma */
58         pgoff = addr >> PAGE_SHIFT;
59
60         ret = mmap_region(NULL, addr, len, vm_flags, pgoff);
61         if (IS_ERR_VALUE(ret))
62                 goto out;
63
64         vma = find_vma(mm, ret);
65         if (!vma) {
66                 ret = -ENOMEM;
67                 goto out;
68         }
69
70         if (vm_flags & VM_LOCKED) {
71                 up_write(&mm->mmap_sem);
72                 mm_populate(ret, len);
73                 return ret;
74         }
75
76 out:
77         up_write(&mm->mmap_sem);
78         return ret;
79 }
80
81 enum reg_type {
82         REG_TYPE_RM = 0,
83         REG_TYPE_INDEX,
84         REG_TYPE_BASE,
85 };
86
87 static int get_reg_offset(struct insn *insn, struct pt_regs *regs,
88                           enum reg_type type)
89 {
90         int regno = 0;
91
92         static const int regoff[] = {
93                 offsetof(struct pt_regs, ax),
94                 offsetof(struct pt_regs, cx),
95                 offsetof(struct pt_regs, dx),
96                 offsetof(struct pt_regs, bx),
97                 offsetof(struct pt_regs, sp),
98                 offsetof(struct pt_regs, bp),
99                 offsetof(struct pt_regs, si),
100                 offsetof(struct pt_regs, di),
101 #ifdef CONFIG_X86_64
102                 offsetof(struct pt_regs, r8),
103                 offsetof(struct pt_regs, r9),
104                 offsetof(struct pt_regs, r10),
105                 offsetof(struct pt_regs, r11),
106                 offsetof(struct pt_regs, r12),
107                 offsetof(struct pt_regs, r13),
108                 offsetof(struct pt_regs, r14),
109                 offsetof(struct pt_regs, r15),
110 #endif
111         };
112         int nr_registers = ARRAY_SIZE(regoff);
113         /*
114          * Don't possibly decode a 32-bit instructions as
115          * reading a 64-bit-only register.
116          */
117         if (IS_ENABLED(CONFIG_X86_64) && !insn->x86_64)
118                 nr_registers -= 8;
119
120         switch (type) {
121         case REG_TYPE_RM:
122                 regno = X86_MODRM_RM(insn->modrm.value);
123                 if (X86_REX_B(insn->rex_prefix.value) == 1)
124                         regno += 8;
125                 break;
126
127         case REG_TYPE_INDEX:
128                 regno = X86_SIB_INDEX(insn->sib.value);
129                 if (X86_REX_X(insn->rex_prefix.value) == 1)
130                         regno += 8;
131                 break;
132
133         case REG_TYPE_BASE:
134                 regno = X86_SIB_BASE(insn->sib.value);
135                 if (X86_REX_B(insn->rex_prefix.value) == 1)
136                         regno += 8;
137                 break;
138
139         default:
140                 pr_err("invalid register type");
141                 BUG();
142                 break;
143         }
144
145         if (regno > nr_registers) {
146                 WARN_ONCE(1, "decoded an instruction with an invalid register");
147                 return -EINVAL;
148         }
149         return regoff[regno];
150 }
151
152 /*
153  * return the address being referenced be instruction
154  * for rm=3 returning the content of the rm reg
155  * for rm!=3 calculates the address using SIB and Disp
156  */
157 static void __user *mpx_get_addr_ref(struct insn *insn, struct pt_regs *regs)
158 {
159         unsigned long addr, base, indx;
160         int addr_offset, base_offset, indx_offset;
161         insn_byte_t sib;
162
163         insn_get_modrm(insn);
164         insn_get_sib(insn);
165         sib = insn->sib.value;
166
167         if (X86_MODRM_MOD(insn->modrm.value) == 3) {
168                 addr_offset = get_reg_offset(insn, regs, REG_TYPE_RM);
169                 if (addr_offset < 0)
170                         goto out_err;
171                 addr = regs_get_register(regs, addr_offset);
172         } else {
173                 if (insn->sib.nbytes) {
174                         base_offset = get_reg_offset(insn, regs, REG_TYPE_BASE);
175                         if (base_offset < 0)
176                                 goto out_err;
177
178                         indx_offset = get_reg_offset(insn, regs, REG_TYPE_INDEX);
179                         if (indx_offset < 0)
180                                 goto out_err;
181
182                         base = regs_get_register(regs, base_offset);
183                         indx = regs_get_register(regs, indx_offset);
184                         addr = base + indx * (1 << X86_SIB_SCALE(sib));
185                 } else {
186                         addr_offset = get_reg_offset(insn, regs, REG_TYPE_RM);
187                         if (addr_offset < 0)
188                                 goto out_err;
189                         addr = regs_get_register(regs, addr_offset);
190                 }
191                 addr += insn->displacement.value;
192         }
193         return (void __user *)addr;
194 out_err:
195         return (void __user *)-1;
196 }
197
198 static int mpx_insn_decode(struct insn *insn,
199                            struct pt_regs *regs)
200 {
201         unsigned char buf[MAX_INSN_SIZE];
202         int x86_64 = !test_thread_flag(TIF_IA32);
203         int not_copied;
204         int nr_copied;
205
206         not_copied = copy_from_user(buf, (void __user *)regs->ip, sizeof(buf));
207         nr_copied = sizeof(buf) - not_copied;
208         /*
209          * The decoder _should_ fail nicely if we pass it a short buffer.
210          * But, let's not depend on that implementation detail.  If we
211          * did not get anything, just error out now.
212          */
213         if (!nr_copied)
214                 return -EFAULT;
215         insn_init(insn, buf, nr_copied, x86_64);
216         insn_get_length(insn);
217         /*
218          * copy_from_user() tries to get as many bytes as we could see in
219          * the largest possible instruction.  If the instruction we are
220          * after is shorter than that _and_ we attempt to copy from
221          * something unreadable, we might get a short read.  This is OK
222          * as long as the read did not stop in the middle of the
223          * instruction.  Check to see if we got a partial instruction.
224          */
225         if (nr_copied < insn->length)
226                 return -EFAULT;
227
228         insn_get_opcode(insn);
229         /*
230          * We only _really_ need to decode bndcl/bndcn/bndcu
231          * Error out on anything else.
232          */
233         if (insn->opcode.bytes[0] != 0x0f)
234                 goto bad_opcode;
235         if ((insn->opcode.bytes[1] != 0x1a) &&
236             (insn->opcode.bytes[1] != 0x1b))
237                 goto bad_opcode;
238
239         return 0;
240 bad_opcode:
241         return -EINVAL;
242 }
243
244 /*
245  * If a bounds overflow occurs then a #BR is generated. This
246  * function decodes MPX instructions to get violation address
247  * and set this address into extended struct siginfo.
248  *
249  * Note that this is not a super precise way of doing this.
250  * Userspace could have, by the time we get here, written
251  * anything it wants in to the instructions.  We can not
252  * trust anything about it.  They might not be valid
253  * instructions or might encode invalid registers, etc...
254  *
255  * The caller is expected to kfree() the returned siginfo_t.
256  */
257 siginfo_t *mpx_generate_siginfo(struct pt_regs *regs,
258                                 struct xsave_struct *xsave_buf)
259 {
260         struct bndreg *bndregs, *bndreg;
261         siginfo_t *info = NULL;
262         struct insn insn;
263         uint8_t bndregno;
264         int err;
265
266         err = mpx_insn_decode(&insn, regs);
267         if (err)
268                 goto err_out;
269
270         /*
271          * We know at this point that we are only dealing with
272          * MPX instructions.
273          */
274         insn_get_modrm(&insn);
275         bndregno = X86_MODRM_REG(insn.modrm.value);
276         if (bndregno > 3) {
277                 err = -EINVAL;
278                 goto err_out;
279         }
280         /* get the bndregs _area_ of the xsave structure */
281         bndregs = get_xsave_addr(xsave_buf, XSTATE_BNDREGS);
282         if (!bndregs) {
283                 err = -EINVAL;
284                 goto err_out;
285         }
286         /* now go select the individual register in the set of 4 */
287         bndreg = &bndregs[bndregno];
288
289         info = kzalloc(sizeof(*info), GFP_KERNEL);
290         if (!info) {
291                 err = -ENOMEM;
292                 goto err_out;
293         }
294         /*
295          * The registers are always 64-bit, but the upper 32
296          * bits are ignored in 32-bit mode.  Also, note that the
297          * upper bounds are architecturally represented in 1's
298          * complement form.
299          *
300          * The 'unsigned long' cast is because the compiler
301          * complains when casting from integers to different-size
302          * pointers.
303          */
304         info->si_lower = (void __user *)(unsigned long)bndreg->lower_bound;
305         info->si_upper = (void __user *)(unsigned long)~bndreg->upper_bound;
306         info->si_addr_lsb = 0;
307         info->si_signo = SIGSEGV;
308         info->si_errno = 0;
309         info->si_code = SEGV_BNDERR;
310         info->si_addr = mpx_get_addr_ref(&insn, regs);
311         /*
312          * We were not able to extract an address from the instruction,
313          * probably because there was something invalid in it.
314          */
315         if (info->si_addr == (void *)-1) {
316                 err = -EINVAL;
317                 goto err_out;
318         }
319         return info;
320 err_out:
321         /* info might be NULL, but kfree() handles that */
322         kfree(info);
323         return ERR_PTR(err);
324 }
325
326 static __user void *task_get_bounds_dir(struct task_struct *tsk)
327 {
328         struct bndcsr *bndcsr;
329
330         if (!cpu_feature_enabled(X86_FEATURE_MPX))
331                 return MPX_INVALID_BOUNDS_DIR;
332
333         /*
334          * 32-bit binaries on 64-bit kernels are currently
335          * unsupported.
336          */
337         if (IS_ENABLED(CONFIG_X86_64) && test_thread_flag(TIF_IA32))
338                 return MPX_INVALID_BOUNDS_DIR;
339         /*
340          * The bounds directory pointer is stored in a register
341          * only accessible if we first do an xsave.
342          */
343         fpu_save_init(&tsk->thread.fpu);
344         bndcsr = get_xsave_addr(&tsk->thread.fpu.state->xsave, XSTATE_BNDCSR);
345         if (!bndcsr)
346                 return MPX_INVALID_BOUNDS_DIR;
347
348         /*
349          * Make sure the register looks valid by checking the
350          * enable bit.
351          */
352         if (!(bndcsr->bndcfgu & MPX_BNDCFG_ENABLE_FLAG))
353                 return MPX_INVALID_BOUNDS_DIR;
354
355         /*
356          * Lastly, mask off the low bits used for configuration
357          * flags, and return the address of the bounds table.
358          */
359         return (void __user *)(unsigned long)
360                 (bndcsr->bndcfgu & MPX_BNDCFG_ADDR_MASK);
361 }
362
363 int mpx_enable_management(struct task_struct *tsk)
364 {
365         void __user *bd_base = MPX_INVALID_BOUNDS_DIR;
366         struct mm_struct *mm = tsk->mm;
367         int ret = 0;
368
369         /*
370          * runtime in the userspace will be responsible for allocation of
371          * the bounds directory. Then, it will save the base of the bounds
372          * directory into XSAVE/XRSTOR Save Area and enable MPX through
373          * XRSTOR instruction.
374          *
375          * fpu_xsave() is expected to be very expensive. Storing the bounds
376          * directory here means that we do not have to do xsave in the unmap
377          * path; we can just use mm->bd_addr instead.
378          */
379         bd_base = task_get_bounds_dir(tsk);
380         down_write(&mm->mmap_sem);
381         mm->bd_addr = bd_base;
382         if (mm->bd_addr == MPX_INVALID_BOUNDS_DIR)
383                 ret = -ENXIO;
384
385         up_write(&mm->mmap_sem);
386         return ret;
387 }
388
389 int mpx_disable_management(struct task_struct *tsk)
390 {
391         struct mm_struct *mm = current->mm;
392
393         if (!cpu_feature_enabled(X86_FEATURE_MPX))
394                 return -ENXIO;
395
396         down_write(&mm->mmap_sem);
397         mm->bd_addr = MPX_INVALID_BOUNDS_DIR;
398         up_write(&mm->mmap_sem);
399         return 0;
400 }
401
402 /*
403  * With 32-bit mode, MPX_BT_SIZE_BYTES is 4MB, and the size of each
404  * bounds table is 16KB. With 64-bit mode, MPX_BT_SIZE_BYTES is 2GB,
405  * and the size of each bounds table is 4MB.
406  */
407 static int allocate_bt(long __user *bd_entry)
408 {
409         unsigned long expected_old_val = 0;
410         unsigned long actual_old_val = 0;
411         unsigned long bt_addr;
412         int ret = 0;
413
414         /*
415          * Carve the virtual space out of userspace for the new
416          * bounds table:
417          */
418         bt_addr = mpx_mmap(MPX_BT_SIZE_BYTES);
419         if (IS_ERR((void *)bt_addr))
420                 return PTR_ERR((void *)bt_addr);
421         /*
422          * Set the valid flag (kinda like _PAGE_PRESENT in a pte)
423          */
424         bt_addr = bt_addr | MPX_BD_ENTRY_VALID_FLAG;
425
426         /*
427          * Go poke the address of the new bounds table in to the
428          * bounds directory entry out in userspace memory.  Note:
429          * we may race with another CPU instantiating the same table.
430          * In that case the cmpxchg will see an unexpected
431          * 'actual_old_val'.
432          *
433          * This can fault, but that's OK because we do not hold
434          * mmap_sem at this point, unlike some of the other part
435          * of the MPX code that have to pagefault_disable().
436          */
437         ret = user_atomic_cmpxchg_inatomic(&actual_old_val, bd_entry,
438                                            expected_old_val, bt_addr);
439         if (ret)
440                 goto out_unmap;
441
442         /*
443          * The user_atomic_cmpxchg_inatomic() will only return nonzero
444          * for faults, *not* if the cmpxchg itself fails.  Now we must
445          * verify that the cmpxchg itself completed successfully.
446          */
447         /*
448          * We expected an empty 'expected_old_val', but instead found
449          * an apparently valid entry.  Assume we raced with another
450          * thread to instantiate this table and desclare succecss.
451          */
452         if (actual_old_val & MPX_BD_ENTRY_VALID_FLAG) {
453                 ret = 0;
454                 goto out_unmap;
455         }
456         /*
457          * We found a non-empty bd_entry but it did not have the
458          * VALID_FLAG set.  Return an error which will result in
459          * a SEGV since this probably means that somebody scribbled
460          * some invalid data in to a bounds table.
461          */
462         if (expected_old_val != actual_old_val) {
463                 ret = -EINVAL;
464                 goto out_unmap;
465         }
466         return 0;
467 out_unmap:
468         vm_munmap(bt_addr & MPX_BT_ADDR_MASK, MPX_BT_SIZE_BYTES);
469         return ret;
470 }
471
472 /*
473  * When a BNDSTX instruction attempts to save bounds to a bounds
474  * table, it will first attempt to look up the table in the
475  * first-level bounds directory.  If it does not find a table in
476  * the directory, a #BR is generated and we get here in order to
477  * allocate a new table.
478  *
479  * With 32-bit mode, the size of BD is 4MB, and the size of each
480  * bound table is 16KB. With 64-bit mode, the size of BD is 2GB,
481  * and the size of each bound table is 4MB.
482  */
483 static int do_mpx_bt_fault(struct xsave_struct *xsave_buf)
484 {
485         unsigned long bd_entry, bd_base;
486         struct bndcsr *bndcsr;
487
488         bndcsr = get_xsave_addr(xsave_buf, XSTATE_BNDCSR);
489         if (!bndcsr)
490                 return -EINVAL;
491         /*
492          * Mask off the preserve and enable bits
493          */
494         bd_base = bndcsr->bndcfgu & MPX_BNDCFG_ADDR_MASK;
495         /*
496          * The hardware provides the address of the missing or invalid
497          * entry via BNDSTATUS, so we don't have to go look it up.
498          */
499         bd_entry = bndcsr->bndstatus & MPX_BNDSTA_ADDR_MASK;
500         /*
501          * Make sure the directory entry is within where we think
502          * the directory is.
503          */
504         if ((bd_entry < bd_base) ||
505             (bd_entry >= bd_base + MPX_BD_SIZE_BYTES))
506                 return -EINVAL;
507
508         return allocate_bt((long __user *)bd_entry);
509 }
510
511 int mpx_handle_bd_fault(struct xsave_struct *xsave_buf)
512 {
513         /*
514          * Userspace never asked us to manage the bounds tables,
515          * so refuse to help.
516          */
517         if (!kernel_managing_mpx_tables(current->mm))
518                 return -EINVAL;
519
520         if (do_mpx_bt_fault(xsave_buf)) {
521                 force_sig(SIGSEGV, current);
522                 /*
523                  * The force_sig() is essentially "handling" this
524                  * exception, so we do not pass up the error
525                  * from do_mpx_bt_fault().
526                  */
527         }
528         return 0;
529 }
530
531 /*
532  * A thin wrapper around get_user_pages().  Returns 0 if the
533  * fault was resolved or -errno if not.
534  */
535 static int mpx_resolve_fault(long __user *addr, int write)
536 {
537         long gup_ret;
538         int nr_pages = 1;
539         int force = 0;
540
541         gup_ret = get_user_pages(current, current->mm, (unsigned long)addr,
542                                  nr_pages, write, force, NULL, NULL);
543         /*
544          * get_user_pages() returns number of pages gotten.
545          * 0 means we failed to fault in and get anything,
546          * probably because 'addr' is bad.
547          */
548         if (!gup_ret)
549                 return -EFAULT;
550         /* Other error, return it */
551         if (gup_ret < 0)
552                 return gup_ret;
553         /* must have gup'd a page and gup_ret>0, success */
554         return 0;
555 }
556
557 /*
558  * Get the base of bounds tables pointed by specific bounds
559  * directory entry.
560  */
561 static int get_bt_addr(struct mm_struct *mm,
562                         long __user *bd_entry, unsigned long *bt_addr)
563 {
564         int ret;
565         int valid_bit;
566
567         if (!access_ok(VERIFY_READ, (bd_entry), sizeof(*bd_entry)))
568                 return -EFAULT;
569
570         while (1) {
571                 int need_write = 0;
572
573                 pagefault_disable();
574                 ret = get_user(*bt_addr, bd_entry);
575                 pagefault_enable();
576                 if (!ret)
577                         break;
578                 if (ret == -EFAULT)
579                         ret = mpx_resolve_fault(bd_entry, need_write);
580                 /*
581                  * If we could not resolve the fault, consider it
582                  * userspace's fault and error out.
583                  */
584                 if (ret)
585                         return ret;
586         }
587
588         valid_bit = *bt_addr & MPX_BD_ENTRY_VALID_FLAG;
589         *bt_addr &= MPX_BT_ADDR_MASK;
590
591         /*
592          * When the kernel is managing bounds tables, a bounds directory
593          * entry will either have a valid address (plus the valid bit)
594          * *OR* be completely empty. If we see a !valid entry *and* some
595          * data in the address field, we know something is wrong. This
596          * -EINVAL return will cause a SIGSEGV.
597          */
598         if (!valid_bit && *bt_addr)
599                 return -EINVAL;
600         /*
601          * Do we have an completely zeroed bt entry?  That is OK.  It
602          * just means there was no bounds table for this memory.  Make
603          * sure to distinguish this from -EINVAL, which will cause
604          * a SEGV.
605          */
606         if (!valid_bit)
607                 return -ENOENT;
608
609         return 0;
610 }
611
612 /*
613  * Free the backing physical pages of bounds table 'bt_addr'.
614  * Assume start...end is within that bounds table.
615  */
616 static int zap_bt_entries(struct mm_struct *mm,
617                 unsigned long bt_addr,
618                 unsigned long start, unsigned long end)
619 {
620         struct vm_area_struct *vma;
621         unsigned long addr, len;
622
623         /*
624          * Find the first overlapping vma. If vma->vm_start > start, there
625          * will be a hole in the bounds table. This -EINVAL return will
626          * cause a SIGSEGV.
627          */
628         vma = find_vma(mm, start);
629         if (!vma || vma->vm_start > start)
630                 return -EINVAL;
631
632         /*
633          * A NUMA policy on a VM_MPX VMA could cause this bouds table to
634          * be split. So we need to look across the entire 'start -> end'
635          * range of this bounds table, find all of the VM_MPX VMAs, and
636          * zap only those.
637          */
638         addr = start;
639         while (vma && vma->vm_start < end) {
640                 /*
641                  * We followed a bounds directory entry down
642                  * here.  If we find a non-MPX VMA, that's bad,
643                  * so stop immediately and return an error.  This
644                  * probably results in a SIGSEGV.
645                  */
646                 if (!(vma->vm_flags & VM_MPX))
647                         return -EINVAL;
648
649                 len = min(vma->vm_end, end) - addr;
650                 zap_page_range(vma, addr, len, NULL);
651
652                 vma = vma->vm_next;
653                 addr = vma->vm_start;
654         }
655
656         return 0;
657 }
658
659 static int unmap_single_bt(struct mm_struct *mm,
660                 long __user *bd_entry, unsigned long bt_addr)
661 {
662         unsigned long expected_old_val = bt_addr | MPX_BD_ENTRY_VALID_FLAG;
663         unsigned long actual_old_val = 0;
664         int ret;
665
666         while (1) {
667                 int need_write = 1;
668
669                 pagefault_disable();
670                 ret = user_atomic_cmpxchg_inatomic(&actual_old_val, bd_entry,
671                                                    expected_old_val, 0);
672                 pagefault_enable();
673                 if (!ret)
674                         break;
675                 if (ret == -EFAULT)
676                         ret = mpx_resolve_fault(bd_entry, need_write);
677                 /*
678                  * If we could not resolve the fault, consider it
679                  * userspace's fault and error out.
680                  */
681                 if (ret)
682                         return ret;
683         }
684         /*
685          * The cmpxchg was performed, check the results.
686          */
687         if (actual_old_val != expected_old_val) {
688                 /*
689                  * Someone else raced with us to unmap the table.
690                  * There was no bounds table pointed to by the
691                  * directory, so declare success.  Somebody freed
692                  * it.
693                  */
694                 if (!actual_old_val)
695                         return 0;
696                 /*
697                  * Something messed with the bounds directory
698                  * entry.  We hold mmap_sem for read or write
699                  * here, so it could not be a _new_ bounds table
700                  * that someone just allocated.  Something is
701                  * wrong, so pass up the error and SIGSEGV.
702                  */
703                 return -EINVAL;
704         }
705
706         /*
707          * Note, we are likely being called under do_munmap() already. To
708          * avoid recursion, do_munmap() will check whether it comes
709          * from one bounds table through VM_MPX flag.
710          */
711         return do_munmap(mm, bt_addr, MPX_BT_SIZE_BYTES);
712 }
713
714 /*
715  * If the bounds table pointed by bounds directory 'bd_entry' is
716  * not shared, unmap this whole bounds table. Otherwise, only free
717  * those backing physical pages of bounds table entries covered
718  * in this virtual address region start...end.
719  */
720 static int unmap_shared_bt(struct mm_struct *mm,
721                 long __user *bd_entry, unsigned long start,
722                 unsigned long end, bool prev_shared, bool next_shared)
723 {
724         unsigned long bt_addr;
725         int ret;
726
727         ret = get_bt_addr(mm, bd_entry, &bt_addr);
728         /*
729          * We could see an "error" ret for not-present bounds
730          * tables (not really an error), or actual errors, but
731          * stop unmapping either way.
732          */
733         if (ret)
734                 return ret;
735
736         if (prev_shared && next_shared)
737                 ret = zap_bt_entries(mm, bt_addr,
738                                 bt_addr+MPX_GET_BT_ENTRY_OFFSET(start),
739                                 bt_addr+MPX_GET_BT_ENTRY_OFFSET(end));
740         else if (prev_shared)
741                 ret = zap_bt_entries(mm, bt_addr,
742                                 bt_addr+MPX_GET_BT_ENTRY_OFFSET(start),
743                                 bt_addr+MPX_BT_SIZE_BYTES);
744         else if (next_shared)
745                 ret = zap_bt_entries(mm, bt_addr, bt_addr,
746                                 bt_addr+MPX_GET_BT_ENTRY_OFFSET(end));
747         else
748                 ret = unmap_single_bt(mm, bd_entry, bt_addr);
749
750         return ret;
751 }
752
753 /*
754  * A virtual address region being munmap()ed might share bounds table
755  * with adjacent VMAs. We only need to free the backing physical
756  * memory of these shared bounds tables entries covered in this virtual
757  * address region.
758  */
759 static int unmap_edge_bts(struct mm_struct *mm,
760                 unsigned long start, unsigned long end)
761 {
762         int ret;
763         long __user *bde_start, *bde_end;
764         struct vm_area_struct *prev, *next;
765         bool prev_shared = false, next_shared = false;
766
767         bde_start = mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(start);
768         bde_end = mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(end-1);
769
770         /*
771          * Check whether bde_start and bde_end are shared with adjacent
772          * VMAs.
773          *
774          * We already unliked the VMAs from the mm's rbtree so 'start'
775          * is guaranteed to be in a hole. This gets us the first VMA
776          * before the hole in to 'prev' and the next VMA after the hole
777          * in to 'next'.
778          */
779         next = find_vma_prev(mm, start, &prev);
780         if (prev && (mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(prev->vm_end-1))
781                         == bde_start)
782                 prev_shared = true;
783         if (next && (mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(next->vm_start))
784                         == bde_end)
785                 next_shared = true;
786
787         /*
788          * This virtual address region being munmap()ed is only
789          * covered by one bounds table.
790          *
791          * In this case, if this table is also shared with adjacent
792          * VMAs, only part of the backing physical memory of the bounds
793          * table need be freeed. Otherwise the whole bounds table need
794          * be unmapped.
795          */
796         if (bde_start == bde_end) {
797                 return unmap_shared_bt(mm, bde_start, start, end,
798                                 prev_shared, next_shared);
799         }
800
801         /*
802          * If more than one bounds tables are covered in this virtual
803          * address region being munmap()ed, we need to separately check
804          * whether bde_start and bde_end are shared with adjacent VMAs.
805          */
806         ret = unmap_shared_bt(mm, bde_start, start, end, prev_shared, false);
807         if (ret)
808                 return ret;
809         ret = unmap_shared_bt(mm, bde_end, start, end, false, next_shared);
810         if (ret)
811                 return ret;
812
813         return 0;
814 }
815
816 static int mpx_unmap_tables(struct mm_struct *mm,
817                 unsigned long start, unsigned long end)
818 {
819         int ret;
820         long __user *bd_entry, *bde_start, *bde_end;
821         unsigned long bt_addr;
822
823         /*
824          * "Edge" bounds tables are those which are being used by the region
825          * (start -> end), but that may be shared with adjacent areas.  If they
826          * turn out to be completely unshared, they will be freed.  If they are
827          * shared, we will free the backing store (like an MADV_DONTNEED) for
828          * areas used by this region.
829          */
830         ret = unmap_edge_bts(mm, start, end);
831         switch (ret) {
832                 /* non-present tables are OK */
833                 case 0:
834                 case -ENOENT:
835                         /* Success, or no tables to unmap */
836                         break;
837                 case -EINVAL:
838                 case -EFAULT:
839                 default:
840                         return ret;
841         }
842
843         /*
844          * Only unmap the bounds table that are
845          *   1. fully covered
846          *   2. not at the edges of the mapping, even if full aligned
847          */
848         bde_start = mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(start);
849         bde_end = mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(end-1);
850         for (bd_entry = bde_start + 1; bd_entry < bde_end; bd_entry++) {
851                 ret = get_bt_addr(mm, bd_entry, &bt_addr);
852                 switch (ret) {
853                         case 0:
854                                 break;
855                         case -ENOENT:
856                                 /* No table here, try the next one */
857                                 continue;
858                         case -EINVAL:
859                         case -EFAULT:
860                         default:
861                                 /*
862                                  * Note: we are being strict here.
863                                  * Any time we run in to an issue
864                                  * unmapping tables, we stop and
865                                  * SIGSEGV.
866                                  */
867                                 return ret;
868                 }
869
870                 ret = unmap_single_bt(mm, bd_entry, bt_addr);
871                 if (ret)
872                         return ret;
873         }
874
875         return 0;
876 }
877
878 /*
879  * Free unused bounds tables covered in a virtual address region being
880  * munmap()ed. Assume end > start.
881  *
882  * This function will be called by do_munmap(), and the VMAs covering
883  * the virtual address region start...end have already been split if
884  * necessary, and the 'vma' is the first vma in this range (start -> end).
885  */
886 void mpx_notify_unmap(struct mm_struct *mm, struct vm_area_struct *vma,
887                 unsigned long start, unsigned long end)
888 {
889         int ret;
890
891         /*
892          * Refuse to do anything unless userspace has asked
893          * the kernel to help manage the bounds tables,
894          */
895         if (!kernel_managing_mpx_tables(current->mm))
896                 return;
897         /*
898          * This will look across the entire 'start -> end' range,
899          * and find all of the non-VM_MPX VMAs.
900          *
901          * To avoid recursion, if a VM_MPX vma is found in the range
902          * (start->end), we will not continue follow-up work. This
903          * recursion represents having bounds tables for bounds tables,
904          * which should not occur normally. Being strict about it here
905          * helps ensure that we do not have an exploitable stack overflow.
906          */
907         do {
908                 if (vma->vm_flags & VM_MPX)
909                         return;
910                 vma = vma->vm_next;
911         } while (vma && vma->vm_start < end);
912
913         ret = mpx_unmap_tables(mm, start, end);
914         if (ret)
915                 force_sig(SIGSEGV, current);
916 }