These changes are the raw update to linux-4.4.6-rt14. Kernel sources
[kvmfornfv.git] / kernel / arch / x86 / mm / mpx.c
index 4d1c11c..ef05755 100644 (file)
 #include <linux/syscalls.h>
 #include <linux/sched/sysctl.h>
 
-#include <asm/i387.h>
 #include <asm/insn.h>
 #include <asm/mman.h>
 #include <asm/mmu_context.h>
 #include <asm/mpx.h>
 #include <asm/processor.h>
-#include <asm/fpu-internal.h>
+#include <asm/fpu/internal.h>
+
+#define CREATE_TRACE_POINTS
+#include <asm/trace/mpx.h>
+
+static inline unsigned long mpx_bd_size_bytes(struct mm_struct *mm)
+{
+       if (is_64bit_mm(mm))
+               return MPX_BD_SIZE_BYTES_64;
+       else
+               return MPX_BD_SIZE_BYTES_32;
+}
+
+static inline unsigned long mpx_bt_size_bytes(struct mm_struct *mm)
+{
+       if (is_64bit_mm(mm))
+               return MPX_BT_SIZE_BYTES_64;
+       else
+               return MPX_BT_SIZE_BYTES_32;
+}
 
 /*
  * This is really a simplified "vm_mmap". it only handles MPX
  */
 static unsigned long mpx_mmap(unsigned long len)
 {
-       unsigned long ret;
-       unsigned long addr, pgoff;
        struct mm_struct *mm = current->mm;
-       vm_flags_t vm_flags;
-       struct vm_area_struct *vma;
+       unsigned long addr, populate;
 
-       /* Only bounds table and bounds directory can be allocated here */
-       if (len != MPX_BD_SIZE_BYTES && len != MPX_BT_SIZE_BYTES)
+       /* Only bounds table can be allocated here */
+       if (len != mpx_bt_size_bytes(mm))
                return -EINVAL;
 
        down_write(&mm->mmap_sem);
-
-       /* Too many mappings? */
-       if (mm->map_count > sysctl_max_map_count) {
-               ret = -ENOMEM;
-               goto out;
-       }
-
-       /* Obtain the address to map to. we verify (or select) it and ensure
-        * that it represents a valid section of the address space.
-        */
-       addr = get_unmapped_area(NULL, 0, len, 0, MAP_ANONYMOUS | MAP_PRIVATE);
-       if (addr & ~PAGE_MASK) {
-               ret = addr;
-               goto out;
-       }
-
-       vm_flags = VM_READ | VM_WRITE | VM_MPX |
-                       mm->def_flags | VM_MAYREAD | VM_MAYWRITE | VM_MAYEXEC;
-
-       /* Set pgoff according to addr for anon_vma */
-       pgoff = addr >> PAGE_SHIFT;
-
-       ret = mmap_region(NULL, addr, len, vm_flags, pgoff);
-       if (IS_ERR_VALUE(ret))
-               goto out;
-
-       vma = find_vma(mm, ret);
-       if (!vma) {
-               ret = -ENOMEM;
-               goto out;
-       }
-
-       if (vm_flags & VM_LOCKED) {
-               up_write(&mm->mmap_sem);
-               mm_populate(ret, len);
-               return ret;
-       }
-
-out:
+       addr = do_mmap(NULL, 0, len, PROT_READ | PROT_WRITE,
+                       MAP_ANONYMOUS | MAP_PRIVATE, VM_MPX, 0, &populate);
        up_write(&mm->mmap_sem);
-       return ret;
+       if (populate)
+               mm_populate(addr, populate);
+
+       return addr;
 }
 
 enum reg_type {
@@ -120,19 +101,19 @@ static int get_reg_offset(struct insn *insn, struct pt_regs *regs,
        switch (type) {
        case REG_TYPE_RM:
                regno = X86_MODRM_RM(insn->modrm.value);
-               if (X86_REX_B(insn->rex_prefix.value) == 1)
+               if (X86_REX_B(insn->rex_prefix.value))
                        regno += 8;
                break;
 
        case REG_TYPE_INDEX:
                regno = X86_SIB_INDEX(insn->sib.value);
-               if (X86_REX_X(insn->rex_prefix.value) == 1)
+               if (X86_REX_X(insn->rex_prefix.value))
                        regno += 8;
                break;
 
        case REG_TYPE_BASE:
                regno = X86_SIB_BASE(insn->sib.value);
-               if (X86_REX_B(insn->rex_prefix.value) == 1)
+               if (X86_REX_B(insn->rex_prefix.value))
                        regno += 8;
                break;
 
@@ -142,7 +123,7 @@ static int get_reg_offset(struct insn *insn, struct pt_regs *regs,
                break;
        }
 
-       if (regno > nr_registers) {
+       if (regno >= nr_registers) {
                WARN_ONCE(1, "decoded an instruction with an invalid register");
                return -EINVAL;
        }
@@ -254,10 +235,10 @@ bad_opcode:
  *
  * The caller is expected to kfree() the returned siginfo_t.
  */
-siginfo_t *mpx_generate_siginfo(struct pt_regs *regs,
-                               struct xsave_struct *xsave_buf)
+siginfo_t *mpx_generate_siginfo(struct pt_regs *regs)
 {
-       struct bndreg *bndregs, *bndreg;
+       const struct mpx_bndreg_state *bndregs;
+       const struct mpx_bndreg *bndreg;
        siginfo_t *info = NULL;
        struct insn insn;
        uint8_t bndregno;
@@ -277,14 +258,14 @@ siginfo_t *mpx_generate_siginfo(struct pt_regs *regs,
                err = -EINVAL;
                goto err_out;
        }
-       /* get the bndregs _area_ of the xsave structure */
-       bndregs = get_xsave_addr(xsave_buf, XSTATE_BNDREGS);
+       /* get bndregs field from current task's xsave area */
+       bndregs = get_xsave_field_ptr(XFEATURE_MASK_BNDREGS);
        if (!bndregs) {
                err = -EINVAL;
                goto err_out;
        }
        /* now go select the individual register in the set of 4 */
-       bndreg = &bndregs[bndregno];
+       bndreg = &bndregs->bndreg[bndregno];
 
        info = kzalloc(sizeof(*info), GFP_KERNEL);
        if (!info) {
@@ -316,6 +297,7 @@ siginfo_t *mpx_generate_siginfo(struct pt_regs *regs,
                err = -EINVAL;
                goto err_out;
        }
+       trace_mpx_bounds_register_exception(info->si_addr, bndreg);
        return info;
 err_out:
        /* info might be NULL, but kfree() handles that */
@@ -323,25 +305,18 @@ err_out:
        return ERR_PTR(err);
 }
 
-static __user void *task_get_bounds_dir(struct task_struct *tsk)
+static __user void *mpx_get_bounds_dir(void)
 {
-       struct bndcsr *bndcsr;
+       const struct mpx_bndcsr *bndcsr;
 
        if (!cpu_feature_enabled(X86_FEATURE_MPX))
                return MPX_INVALID_BOUNDS_DIR;
 
-       /*
-        * 32-bit binaries on 64-bit kernels are currently
-        * unsupported.
-        */
-       if (IS_ENABLED(CONFIG_X86_64) && test_thread_flag(TIF_IA32))
-               return MPX_INVALID_BOUNDS_DIR;
        /*
         * The bounds directory pointer is stored in a register
         * only accessible if we first do an xsave.
         */
-       fpu_save_init(&tsk->thread.fpu);
-       bndcsr = get_xsave_addr(&tsk->thread.fpu.state->xsave, XSTATE_BNDCSR);
+       bndcsr = get_xsave_field_ptr(XFEATURE_MASK_BNDCSR);
        if (!bndcsr)
                return MPX_INVALID_BOUNDS_DIR;
 
@@ -360,10 +335,10 @@ static __user void *task_get_bounds_dir(struct task_struct *tsk)
                (bndcsr->bndcfgu & MPX_BNDCFG_ADDR_MASK);
 }
 
-int mpx_enable_management(struct task_struct *tsk)
+int mpx_enable_management(void)
 {
        void __user *bd_base = MPX_INVALID_BOUNDS_DIR;
-       struct mm_struct *mm = tsk->mm;
+       struct mm_struct *mm = current->mm;
        int ret = 0;
 
        /*
@@ -372,11 +347,12 @@ int mpx_enable_management(struct task_struct *tsk)
         * directory into XSAVE/XRSTOR Save Area and enable MPX through
         * XRSTOR instruction.
         *
-        * fpu_xsave() is expected to be very expensive. Storing the bounds
-        * directory here means that we do not have to do xsave in the unmap
-        * path; we can just use mm->bd_addr instead.
+        * The copy_xregs_to_kernel() beneath get_xsave_field_ptr() is
+        * expected to be relatively expensive. Storing the bounds
+        * directory here means that we do not have to do xsave in the
+        * unmap path; we can just use mm->bd_addr instead.
         */
-       bd_base = task_get_bounds_dir(tsk);
+       bd_base = mpx_get_bounds_dir();
        down_write(&mm->mmap_sem);
        mm->bd_addr = bd_base;
        if (mm->bd_addr == MPX_INVALID_BOUNDS_DIR)
@@ -386,7 +362,7 @@ int mpx_enable_management(struct task_struct *tsk)
        return ret;
 }
 
-int mpx_disable_management(struct task_struct *tsk)
+int mpx_disable_management(void)
 {
        struct mm_struct *mm = current->mm;
 
@@ -399,29 +375,59 @@ int mpx_disable_management(struct task_struct *tsk)
        return 0;
 }
 
+static int mpx_cmpxchg_bd_entry(struct mm_struct *mm,
+               unsigned long *curval,
+               unsigned long __user *addr,
+               unsigned long old_val, unsigned long new_val)
+{
+       int ret;
+       /*
+        * user_atomic_cmpxchg_inatomic() actually uses sizeof()
+        * the pointer that we pass to it to figure out how much
+        * data to cmpxchg.  We have to be careful here not to
+        * pass a pointer to a 64-bit data type when we only want
+        * a 32-bit copy.
+        */
+       if (is_64bit_mm(mm)) {
+               ret = user_atomic_cmpxchg_inatomic(curval,
+                               addr, old_val, new_val);
+       } else {
+               u32 uninitialized_var(curval_32);
+               u32 old_val_32 = old_val;
+               u32 new_val_32 = new_val;
+               u32 __user *addr_32 = (u32 __user *)addr;
+
+               ret = user_atomic_cmpxchg_inatomic(&curval_32,
+                               addr_32, old_val_32, new_val_32);
+               *curval = curval_32;
+       }
+       return ret;
+}
+
 /*
- * With 32-bit mode, MPX_BT_SIZE_BYTES is 4MB, and the size of each
- * bounds table is 16KB. With 64-bit mode, MPX_BT_SIZE_BYTES is 2GB,
+ * With 32-bit mode, a bounds directory is 4MB, and the size of each
+ * bounds table is 16KB. With 64-bit mode, a bounds directory is 2GB,
  * and the size of each bounds table is 4MB.
  */
-static int allocate_bt(long __user *bd_entry)
+static int allocate_bt(struct mm_struct *mm, long __user *bd_entry)
 {
        unsigned long expected_old_val = 0;
        unsigned long actual_old_val = 0;
        unsigned long bt_addr;
+       unsigned long bd_new_entry;
        int ret = 0;
 
        /*
         * Carve the virtual space out of userspace for the new
         * bounds table:
         */
-       bt_addr = mpx_mmap(MPX_BT_SIZE_BYTES);
+       bt_addr = mpx_mmap(mpx_bt_size_bytes(mm));
        if (IS_ERR((void *)bt_addr))
                return PTR_ERR((void *)bt_addr);
        /*
         * Set the valid flag (kinda like _PAGE_PRESENT in a pte)
         */
-       bt_addr = bt_addr | MPX_BD_ENTRY_VALID_FLAG;
+       bd_new_entry = bt_addr | MPX_BD_ENTRY_VALID_FLAG;
 
        /*
         * Go poke the address of the new bounds table in to the
@@ -434,8 +440,8 @@ static int allocate_bt(long __user *bd_entry)
         * mmap_sem at this point, unlike some of the other part
         * of the MPX code that have to pagefault_disable().
         */
-       ret = user_atomic_cmpxchg_inatomic(&actual_old_val, bd_entry,
-                                          expected_old_val, bt_addr);
+       ret = mpx_cmpxchg_bd_entry(mm, &actual_old_val, bd_entry,
+                                  expected_old_val, bd_new_entry);
        if (ret)
                goto out_unmap;
 
@@ -463,9 +469,10 @@ static int allocate_bt(long __user *bd_entry)
                ret = -EINVAL;
                goto out_unmap;
        }
+       trace_mpx_new_bounds_table(bt_addr);
        return 0;
 out_unmap:
-       vm_munmap(bt_addr & MPX_BT_ADDR_MASK, MPX_BT_SIZE_BYTES);
+       vm_munmap(bt_addr, mpx_bt_size_bytes(mm));
        return ret;
 }
 
@@ -480,12 +487,13 @@ out_unmap:
  * bound table is 16KB. With 64-bit mode, the size of BD is 2GB,
  * and the size of each bound table is 4MB.
  */
-static int do_mpx_bt_fault(struct xsave_struct *xsave_buf)
+static int do_mpx_bt_fault(void)
 {
        unsigned long bd_entry, bd_base;
-       struct bndcsr *bndcsr;
+       const struct mpx_bndcsr *bndcsr;
+       struct mm_struct *mm = current->mm;
 
-       bndcsr = get_xsave_addr(xsave_buf, XSTATE_BNDCSR);
+       bndcsr = get_xsave_field_ptr(XFEATURE_MASK_BNDCSR);
        if (!bndcsr)
                return -EINVAL;
        /*
@@ -502,13 +510,13 @@ static int do_mpx_bt_fault(struct xsave_struct *xsave_buf)
         * the directory is.
         */
        if ((bd_entry < bd_base) ||
-           (bd_entry >= bd_base + MPX_BD_SIZE_BYTES))
+           (bd_entry >= bd_base + mpx_bd_size_bytes(mm)))
                return -EINVAL;
 
-       return allocate_bt((long __user *)bd_entry);
+       return allocate_bt(mm, (long __user *)bd_entry);
 }
 
-int mpx_handle_bd_fault(struct xsave_struct *xsave_buf)
+int mpx_handle_bd_fault(void)
 {
        /*
         * Userspace never asked us to manage the bounds tables,
@@ -517,7 +525,7 @@ int mpx_handle_bd_fault(struct xsave_struct *xsave_buf)
        if (!kernel_managing_mpx_tables(current->mm))
                return -EINVAL;
 
-       if (do_mpx_bt_fault(xsave_buf)) {
+       if (do_mpx_bt_fault()) {
                force_sig(SIGSEGV, current);
                /*
                 * The force_sig() is essentially "handling" this
@@ -554,29 +562,78 @@ static int mpx_resolve_fault(long __user *addr, int write)
        return 0;
 }
 
+static unsigned long mpx_bd_entry_to_bt_addr(struct mm_struct *mm,
+                                            unsigned long bd_entry)
+{
+       unsigned long bt_addr = bd_entry;
+       int align_to_bytes;
+       /*
+        * Bit 0 in a bt_entry is always the valid bit.
+        */
+       bt_addr &= ~MPX_BD_ENTRY_VALID_FLAG;
+       /*
+        * Tables are naturally aligned at 8-byte boundaries
+        * on 64-bit and 4-byte boundaries on 32-bit.  The
+        * documentation makes it appear that the low bits
+        * are ignored by the hardware, so we do the same.
+        */
+       if (is_64bit_mm(mm))
+               align_to_bytes = 8;
+       else
+               align_to_bytes = 4;
+       bt_addr &= ~(align_to_bytes-1);
+       return bt_addr;
+}
+
+/*
+ * We only want to do a 4-byte get_user() on 32-bit.  Otherwise,
+ * we might run off the end of the bounds table if we are on
+ * a 64-bit kernel and try to get 8 bytes.
+ */
+int get_user_bd_entry(struct mm_struct *mm, unsigned long *bd_entry_ret,
+               long __user *bd_entry_ptr)
+{
+       u32 bd_entry_32;
+       int ret;
+
+       if (is_64bit_mm(mm))
+               return get_user(*bd_entry_ret, bd_entry_ptr);
+
+       /*
+        * Note that get_user() uses the type of the *pointer* to
+        * establish the size of the get, not the destination.
+        */
+       ret = get_user(bd_entry_32, (u32 __user *)bd_entry_ptr);
+       *bd_entry_ret = bd_entry_32;
+       return ret;
+}
+
 /*
  * Get the base of bounds tables pointed by specific bounds
  * directory entry.
  */
 static int get_bt_addr(struct mm_struct *mm,
-                       long __user *bd_entry, unsigned long *bt_addr)
+                       long __user *bd_entry_ptr,
+                       unsigned long *bt_addr_result)
 {
        int ret;
        int valid_bit;
+       unsigned long bd_entry;
+       unsigned long bt_addr;
 
-       if (!access_ok(VERIFY_READ, (bd_entry), sizeof(*bd_entry)))
+       if (!access_ok(VERIFY_READ, (bd_entry_ptr), sizeof(*bd_entry_ptr)))
                return -EFAULT;
 
        while (1) {
                int need_write = 0;
 
                pagefault_disable();
-               ret = get_user(*bt_addr, bd_entry);
+               ret = get_user_bd_entry(mm, &bd_entry, bd_entry_ptr);
                pagefault_enable();
                if (!ret)
                        break;
                if (ret == -EFAULT)
-                       ret = mpx_resolve_fault(bd_entry, need_write);
+                       ret = mpx_resolve_fault(bd_entry_ptr, need_write);
                /*
                 * If we could not resolve the fault, consider it
                 * userspace's fault and error out.
@@ -585,8 +642,8 @@ static int get_bt_addr(struct mm_struct *mm,
                        return ret;
        }
 
-       valid_bit = *bt_addr & MPX_BD_ENTRY_VALID_FLAG;
-       *bt_addr &= MPX_BT_ADDR_MASK;
+       valid_bit = bd_entry & MPX_BD_ENTRY_VALID_FLAG;
+       bt_addr = mpx_bd_entry_to_bt_addr(mm, bd_entry);
 
        /*
         * When the kernel is managing bounds tables, a bounds directory
@@ -595,7 +652,7 @@ static int get_bt_addr(struct mm_struct *mm,
         * data in the address field, we know something is wrong. This
         * -EINVAL return will cause a SIGSEGV.
         */
-       if (!valid_bit && *bt_addr)
+       if (!valid_bit && bt_addr)
                return -EINVAL;
        /*
         * Do we have an completely zeroed bt entry?  That is OK.  It
@@ -606,19 +663,112 @@ static int get_bt_addr(struct mm_struct *mm,
        if (!valid_bit)
                return -ENOENT;
 
+       *bt_addr_result = bt_addr;
        return 0;
 }
 
+static inline int bt_entry_size_bytes(struct mm_struct *mm)
+{
+       if (is_64bit_mm(mm))
+               return MPX_BT_ENTRY_BYTES_64;
+       else
+               return MPX_BT_ENTRY_BYTES_32;
+}
+
+/*
+ * Take a virtual address and turns it in to the offset in bytes
+ * inside of the bounds table where the bounds table entry
+ * controlling 'addr' can be found.
+ */
+static unsigned long mpx_get_bt_entry_offset_bytes(struct mm_struct *mm,
+               unsigned long addr)
+{
+       unsigned long bt_table_nr_entries;
+       unsigned long offset = addr;
+
+       if (is_64bit_mm(mm)) {
+               /* Bottom 3 bits are ignored on 64-bit */
+               offset >>= 3;
+               bt_table_nr_entries = MPX_BT_NR_ENTRIES_64;
+       } else {
+               /* Bottom 2 bits are ignored on 32-bit */
+               offset >>= 2;
+               bt_table_nr_entries = MPX_BT_NR_ENTRIES_32;
+       }
+       /*
+        * We know the size of the table in to which we are
+        * indexing, and we have eliminated all the low bits
+        * which are ignored for indexing.
+        *
+        * Mask out all the high bits which we do not need
+        * to index in to the table.  Note that the tables
+        * are always powers of two so this gives us a proper
+        * mask.
+        */
+       offset &= (bt_table_nr_entries-1);
+       /*
+        * We now have an entry offset in terms of *entries* in
+        * the table.  We need to scale it back up to bytes.
+        */
+       offset *= bt_entry_size_bytes(mm);
+       return offset;
+}
+
+/*
+ * How much virtual address space does a single bounds
+ * directory entry cover?
+ *
+ * Note, we need a long long because 4GB doesn't fit in
+ * to a long on 32-bit.
+ */
+static inline unsigned long bd_entry_virt_space(struct mm_struct *mm)
+{
+       unsigned long long virt_space;
+       unsigned long long GB = (1ULL << 30);
+
+       /*
+        * This covers 32-bit emulation as well as 32-bit kernels
+        * running on 64-bit harware.
+        */
+       if (!is_64bit_mm(mm))
+               return (4ULL * GB) / MPX_BD_NR_ENTRIES_32;
+
+       /*
+        * 'x86_virt_bits' returns what the hardware is capable
+        * of, and returns the full >32-bit adddress space when
+        * running 32-bit kernels on 64-bit hardware.
+        */
+       virt_space = (1ULL << boot_cpu_data.x86_virt_bits);
+       return virt_space / MPX_BD_NR_ENTRIES_64;
+}
+
 /*
  * Free the backing physical pages of bounds table 'bt_addr'.
  * Assume start...end is within that bounds table.
  */
-static int zap_bt_entries(struct mm_struct *mm,
+static noinline int zap_bt_entries_mapping(struct mm_struct *mm,
                unsigned long bt_addr,
-               unsigned long start, unsigned long end)
+               unsigned long start_mapping, unsigned long end_mapping)
 {
        struct vm_area_struct *vma;
        unsigned long addr, len;
+       unsigned long start;
+       unsigned long end;
+
+       /*
+        * if we 'end' on a boundary, the offset will be 0 which
+        * is not what we want.  Back it up a byte to get the
+        * last bt entry.  Then once we have the entry itself,
+        * move 'end' back up by the table entry size.
+        */
+       start = bt_addr + mpx_get_bt_entry_offset_bytes(mm, start_mapping);
+       end   = bt_addr + mpx_get_bt_entry_offset_bytes(mm, end_mapping - 1);
+       /*
+        * Move end back up by one entry.  Among other things
+        * this ensures that it remains page-aligned and does
+        * not screw up zap_page_range()
+        */
+       end += bt_entry_size_bytes(mm);
 
        /*
         * Find the first overlapping vma. If vma->vm_start > start, there
@@ -630,7 +780,7 @@ static int zap_bt_entries(struct mm_struct *mm,
                return -EINVAL;
 
        /*
-        * A NUMA policy on a VM_MPX VMA could cause this bouds table to
+        * A NUMA policy on a VM_MPX VMA could cause this bounds table to
         * be split. So we need to look across the entire 'start -> end'
         * range of this bounds table, find all of the VM_MPX VMAs, and
         * zap only those.
@@ -648,27 +798,65 @@ static int zap_bt_entries(struct mm_struct *mm,
 
                len = min(vma->vm_end, end) - addr;
                zap_page_range(vma, addr, len, NULL);
+               trace_mpx_unmap_zap(addr, addr+len);
 
                vma = vma->vm_next;
                addr = vma->vm_start;
        }
-
        return 0;
 }
 
-static int unmap_single_bt(struct mm_struct *mm,
+static unsigned long mpx_get_bd_entry_offset(struct mm_struct *mm,
+               unsigned long addr)
+{
+       /*
+        * There are several ways to derive the bd offsets.  We
+        * use the following approach here:
+        * 1. We know the size of the virtual address space
+        * 2. We know the number of entries in a bounds table
+        * 3. We know that each entry covers a fixed amount of
+        *    virtual address space.
+        * So, we can just divide the virtual address by the
+        * virtual space used by one entry to determine which
+        * entry "controls" the given virtual address.
+        */
+       if (is_64bit_mm(mm)) {
+               int bd_entry_size = 8; /* 64-bit pointer */
+               /*
+                * Take the 64-bit addressing hole in to account.
+                */
+               addr &= ((1UL << boot_cpu_data.x86_virt_bits) - 1);
+               return (addr / bd_entry_virt_space(mm)) * bd_entry_size;
+       } else {
+               int bd_entry_size = 4; /* 32-bit pointer */
+               /*
+                * 32-bit has no hole so this case needs no mask
+                */
+               return (addr / bd_entry_virt_space(mm)) * bd_entry_size;
+       }
+       /*
+        * The two return calls above are exact copies.  If we
+        * pull out a single copy and put it in here, gcc won't
+        * realize that we're doing a power-of-2 divide and use
+        * shifts.  It uses a real divide.  If we put them up
+        * there, it manages to figure it out (gcc 4.8.3).
+        */
+}
+
+static int unmap_entire_bt(struct mm_struct *mm,
                long __user *bd_entry, unsigned long bt_addr)
 {
        unsigned long expected_old_val = bt_addr | MPX_BD_ENTRY_VALID_FLAG;
-       unsigned long actual_old_val = 0;
+       unsigned long uninitialized_var(actual_old_val);
        int ret;
 
        while (1) {
                int need_write = 1;
+               unsigned long cleared_bd_entry = 0;
 
                pagefault_disable();
-               ret = user_atomic_cmpxchg_inatomic(&actual_old_val, bd_entry,
-                                                  expected_old_val, 0);
+               ret = mpx_cmpxchg_bd_entry(mm, &actual_old_val,
+                               bd_entry, expected_old_val, cleared_bd_entry);
                pagefault_enable();
                if (!ret)
                        break;
@@ -687,9 +875,8 @@ static int unmap_single_bt(struct mm_struct *mm,
        if (actual_old_val != expected_old_val) {
                /*
                 * Someone else raced with us to unmap the table.
-                * There was no bounds table pointed to by the
-                * directory, so declare success.  Somebody freed
-                * it.
+                * That is OK, since we were both trying to do
+                * the same thing.  Declare success.
                 */
                if (!actual_old_val)
                        return 0;
@@ -702,176 +889,113 @@ static int unmap_single_bt(struct mm_struct *mm,
                 */
                return -EINVAL;
        }
-
        /*
         * Note, we are likely being called under do_munmap() already. To
         * avoid recursion, do_munmap() will check whether it comes
         * from one bounds table through VM_MPX flag.
         */
-       return do_munmap(mm, bt_addr, MPX_BT_SIZE_BYTES);
+       return do_munmap(mm, bt_addr, mpx_bt_size_bytes(mm));
 }
 
-/*
- * If the bounds table pointed by bounds directory 'bd_entry' is
- * not shared, unmap this whole bounds table. Otherwise, only free
- * those backing physical pages of bounds table entries covered
- * in this virtual address region start...end.
- */
-static int unmap_shared_bt(struct mm_struct *mm,
-               long __user *bd_entry, unsigned long start,
-               unsigned long end, bool prev_shared, bool next_shared)
+static int try_unmap_single_bt(struct mm_struct *mm,
+              unsigned long start, unsigned long end)
 {
-       unsigned long bt_addr;
-       int ret;
-
-       ret = get_bt_addr(mm, bd_entry, &bt_addr);
+       struct vm_area_struct *next;
+       struct vm_area_struct *prev;
        /*
-        * We could see an "error" ret for not-present bounds
-        * tables (not really an error), or actual errors, but
-        * stop unmapping either way.
+        * "bta" == Bounds Table Area: the area controlled by the
+        * bounds table that we are unmapping.
         */
-       if (ret)
-               return ret;
-
-       if (prev_shared && next_shared)
-               ret = zap_bt_entries(mm, bt_addr,
-                               bt_addr+MPX_GET_BT_ENTRY_OFFSET(start),
-                               bt_addr+MPX_GET_BT_ENTRY_OFFSET(end));
-       else if (prev_shared)
-               ret = zap_bt_entries(mm, bt_addr,
-                               bt_addr+MPX_GET_BT_ENTRY_OFFSET(start),
-                               bt_addr+MPX_BT_SIZE_BYTES);
-       else if (next_shared)
-               ret = zap_bt_entries(mm, bt_addr, bt_addr,
-                               bt_addr+MPX_GET_BT_ENTRY_OFFSET(end));
-       else
-               ret = unmap_single_bt(mm, bd_entry, bt_addr);
-
-       return ret;
-}
-
-/*
- * A virtual address region being munmap()ed might share bounds table
- * with adjacent VMAs. We only need to free the backing physical
- * memory of these shared bounds tables entries covered in this virtual
- * address region.
- */
-static int unmap_edge_bts(struct mm_struct *mm,
-               unsigned long start, unsigned long end)
-{
+       unsigned long bta_start_vaddr = start & ~(bd_entry_virt_space(mm)-1);
+       unsigned long bta_end_vaddr = bta_start_vaddr + bd_entry_virt_space(mm);
+       unsigned long uninitialized_var(bt_addr);
+       void __user *bde_vaddr;
        int ret;
-       long __user *bde_start, *bde_end;
-       struct vm_area_struct *prev, *next;
-       bool prev_shared = false, next_shared = false;
-
-       bde_start = mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(start);
-       bde_end = mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(end-1);
-
        /*
-        * Check whether bde_start and bde_end are shared with adjacent
-        * VMAs.
-        *
-        * We already unliked the VMAs from the mm's rbtree so 'start'
+        * We already unlinked the VMAs from the mm's rbtree so 'start'
         * is guaranteed to be in a hole. This gets us the first VMA
         * before the hole in to 'prev' and the next VMA after the hole
         * in to 'next'.
         */
        next = find_vma_prev(mm, start, &prev);
-       if (prev && (mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(prev->vm_end-1))
-                       == bde_start)
-               prev_shared = true;
-       if (next && (mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(next->vm_start))
-                       == bde_end)
-               next_shared = true;
-
        /*
-        * This virtual address region being munmap()ed is only
-        * covered by one bounds table.
-        *
-        * In this case, if this table is also shared with adjacent
-        * VMAs, only part of the backing physical memory of the bounds
-        * table need be freeed. Otherwise the whole bounds table need
-        * be unmapped.
-        */
-       if (bde_start == bde_end) {
-               return unmap_shared_bt(mm, bde_start, start, end,
-                               prev_shared, next_shared);
+        * Do not count other MPX bounds table VMAs as neighbors.
+        * Although theoretically possible, we do not allow bounds
+        * tables for bounds tables so our heads do not explode.
+        * If we count them as neighbors here, we may end up with
+        * lots of tables even though we have no actual table
+        * entries in use.
+        */
+       while (next && (next->vm_flags & VM_MPX))
+               next = next->vm_next;
+       while (prev && (prev->vm_flags & VM_MPX))
+               prev = prev->vm_prev;
+       /*
+        * We know 'start' and 'end' lie within an area controlled
+        * by a single bounds table.  See if there are any other
+        * VMAs controlled by that bounds table.  If there are not
+        * then we can "expand" the are we are unmapping to possibly
+        * cover the entire table.
+        */
+       next = find_vma_prev(mm, start, &prev);
+       if ((!prev || prev->vm_end <= bta_start_vaddr) &&
+           (!next || next->vm_start >= bta_end_vaddr)) {
+               /*
+                * No neighbor VMAs controlled by same bounds
+                * table.  Try to unmap the whole thing
+                */
+               start = bta_start_vaddr;
+               end = bta_end_vaddr;
        }
 
+       bde_vaddr = mm->bd_addr + mpx_get_bd_entry_offset(mm, start);
+       ret = get_bt_addr(mm, bde_vaddr, &bt_addr);
        /*
-        * If more than one bounds tables are covered in this virtual
-        * address region being munmap()ed, we need to separately check
-        * whether bde_start and bde_end are shared with adjacent VMAs.
+        * No bounds table there, so nothing to unmap.
         */
-       ret = unmap_shared_bt(mm, bde_start, start, end, prev_shared, false);
-       if (ret)
-               return ret;
-       ret = unmap_shared_bt(mm, bde_end, start, end, false, next_shared);
+       if (ret == -ENOENT) {
+               ret = 0;
+               return 0;
+       }
        if (ret)
                return ret;
-
-       return 0;
+       /*
+        * We are unmapping an entire table.  Either because the
+        * unmap that started this whole process was large enough
+        * to cover an entire table, or that the unmap was small
+        * but was the area covered by a bounds table.
+        */
+       if ((start == bta_start_vaddr) &&
+           (end == bta_end_vaddr))
+               return unmap_entire_bt(mm, bde_vaddr, bt_addr);
+       return zap_bt_entries_mapping(mm, bt_addr, start, end);
 }
 
 static int mpx_unmap_tables(struct mm_struct *mm,
                unsigned long start, unsigned long end)
 {
-       int ret;
-       long __user *bd_entry, *bde_start, *bde_end;
-       unsigned long bt_addr;
-
-       /*
-        * "Edge" bounds tables are those which are being used by the region
-        * (start -> end), but that may be shared with adjacent areas.  If they
-        * turn out to be completely unshared, they will be freed.  If they are
-        * shared, we will free the backing store (like an MADV_DONTNEED) for
-        * areas used by this region.
-        */
-       ret = unmap_edge_bts(mm, start, end);
-       switch (ret) {
-               /* non-present tables are OK */
-               case 0:
-               case -ENOENT:
-                       /* Success, or no tables to unmap */
-                       break;
-               case -EINVAL:
-               case -EFAULT:
-               default:
-                       return ret;
-       }
-
-       /*
-        * Only unmap the bounds table that are
-        *   1. fully covered
-        *   2. not at the edges of the mapping, even if full aligned
-        */
-       bde_start = mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(start);
-       bde_end = mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(end-1);
-       for (bd_entry = bde_start + 1; bd_entry < bde_end; bd_entry++) {
-               ret = get_bt_addr(mm, bd_entry, &bt_addr);
-               switch (ret) {
-                       case 0:
-                               break;
-                       case -ENOENT:
-                               /* No table here, try the next one */
-                               continue;
-                       case -EINVAL:
-                       case -EFAULT:
-                       default:
-                               /*
-                                * Note: we are being strict here.
-                                * Any time we run in to an issue
-                                * unmapping tables, we stop and
-                                * SIGSEGV.
-                                */
-                               return ret;
-               }
-
-               ret = unmap_single_bt(mm, bd_entry, bt_addr);
+       unsigned long one_unmap_start;
+       trace_mpx_unmap_search(start, end);
+
+       one_unmap_start = start;
+       while (one_unmap_start < end) {
+               int ret;
+               unsigned long next_unmap_start = ALIGN(one_unmap_start+1,
+                                                      bd_entry_virt_space(mm));
+               unsigned long one_unmap_end = end;
+               /*
+                * if the end is beyond the current bounds table,
+                * move it back so we only deal with a single one
+                * at a time
+                */
+               if (one_unmap_end > next_unmap_start)
+                       one_unmap_end = next_unmap_start;
+               ret = try_unmap_single_bt(mm, one_unmap_start, one_unmap_end);
                if (ret)
                        return ret;
-       }
 
+               one_unmap_start = next_unmap_start;
+       }
        return 0;
 }