Upgrade to 4.4.50-rt62
[kvmfornfv.git] / kernel / drivers / misc / cxl / fault.c
index 25a5418..81c3f75 100644 (file)
@@ -166,13 +166,92 @@ static void cxl_handle_page_fault(struct cxl_context *ctx,
        cxl_ack_irq(ctx, CXL_PSL_TFC_An_R, 0);
 }
 
+/*
+ * Returns the mm_struct corresponding to the context ctx via ctx->pid
+ * In case the task has exited we use the task group leader accessible
+ * via ctx->glpid to find the next task in the thread group that has a
+ * valid  mm_struct associated with it. If a task with valid mm_struct
+ * is found the ctx->pid is updated to use the task struct for subsequent
+ * translations. In case no valid mm_struct is found in the task group to
+ * service the fault a NULL is returned.
+ */
+static struct mm_struct *get_mem_context(struct cxl_context *ctx)
+{
+       struct task_struct *task = NULL;
+       struct mm_struct *mm = NULL;
+       struct pid *old_pid = ctx->pid;
+
+       if (old_pid == NULL) {
+               pr_warn("%s: Invalid context for pe=%d\n",
+                        __func__, ctx->pe);
+               return NULL;
+       }
+
+       task = get_pid_task(old_pid, PIDTYPE_PID);
+
+       /*
+        * pid_alive may look racy but this saves us from costly
+        * get_task_mm when the task is a zombie. In worst case
+        * we may think a task is alive, which is about to die
+        * but get_task_mm will return NULL.
+        */
+       if (task != NULL && pid_alive(task))
+               mm = get_task_mm(task);
+
+       /* release the task struct that was taken earlier */
+       if (task)
+               put_task_struct(task);
+       else
+               pr_devel("%s: Context owning pid=%i for pe=%i dead\n",
+                       __func__, pid_nr(old_pid), ctx->pe);
+
+       /*
+        * If we couldn't find the mm context then use the group
+        * leader to iterate over the task group and find a task
+        * that gives us mm_struct.
+        */
+       if (unlikely(mm == NULL && ctx->glpid != NULL)) {
+
+               rcu_read_lock();
+               task = pid_task(ctx->glpid, PIDTYPE_PID);
+               if (task)
+                       do {
+                               mm = get_task_mm(task);
+                               if (mm) {
+                                       ctx->pid = get_task_pid(task,
+                                                               PIDTYPE_PID);
+                                       break;
+                               }
+                               task = next_thread(task);
+                       } while (task && !thread_group_leader(task));
+               rcu_read_unlock();
+
+               /* check if we switched pid */
+               if (ctx->pid != old_pid) {
+                       if (mm)
+                               pr_devel("%s:pe=%i switch pid %i->%i\n",
+                                        __func__, ctx->pe, pid_nr(old_pid),
+                                        pid_nr(ctx->pid));
+                       else
+                               pr_devel("%s:Cannot find mm for pid=%i\n",
+                                        __func__, pid_nr(old_pid));
+
+                       /* drop the reference to older pid */
+                       put_pid(old_pid);
+               }
+       }
+
+       return mm;
+}
+
+
+
 void cxl_handle_fault(struct work_struct *fault_work)
 {
        struct cxl_context *ctx =
                container_of(fault_work, struct cxl_context, fault_work);
        u64 dsisr = ctx->dsisr;
        u64 dar = ctx->dar;
-       struct task_struct *task = NULL;
        struct mm_struct *mm = NULL;
 
        if (cxl_p2n_read(ctx->afu, CXL_PSL_DSISR_An) != dsisr ||
@@ -195,17 +274,17 @@ void cxl_handle_fault(struct work_struct *fault_work)
                "DSISR: %#llx DAR: %#llx\n", ctx->pe, dsisr, dar);
 
        if (!ctx->kernel) {
-               if (!(task = get_pid_task(ctx->pid, PIDTYPE_PID))) {
-                       pr_devel("cxl_handle_fault unable to get task %i\n",
-                                pid_nr(ctx->pid));
+
+               mm = get_mem_context(ctx);
+               /* indicates all the thread in task group have exited */
+               if (mm == NULL) {
+                       pr_devel("%s: unable to get mm for pe=%d pid=%i\n",
+                                __func__, ctx->pe, pid_nr(ctx->pid));
                        cxl_ack_ae(ctx);
                        return;
-               }
-               if (!(mm = get_task_mm(task))) {
-                       pr_devel("cxl_handle_fault unable to get mm %i\n",
-                                pid_nr(ctx->pid));
-                       cxl_ack_ae(ctx);
-                       goto out;
+               } else {
+                       pr_devel("Handling page fault for pe=%d pid=%i\n",
+                                ctx->pe, pid_nr(ctx->pid));
                }
        }
 
@@ -218,33 +297,22 @@ void cxl_handle_fault(struct work_struct *fault_work)
 
        if (mm)
                mmput(mm);
-out:
-       if (task)
-               put_task_struct(task);
 }
 
 static void cxl_prefault_one(struct cxl_context *ctx, u64 ea)
 {
-       int rc;
-       struct task_struct *task;
        struct mm_struct *mm;
 
-       if (!(task = get_pid_task(ctx->pid, PIDTYPE_PID))) {
-               pr_devel("cxl_prefault_one unable to get task %i\n",
-                        pid_nr(ctx->pid));
-               return;
-       }
-       if (!(mm = get_task_mm(task))) {
+       mm = get_mem_context(ctx);
+       if (mm == NULL) {
                pr_devel("cxl_prefault_one unable to get mm %i\n",
                         pid_nr(ctx->pid));
-               put_task_struct(task);
                return;
        }
 
-       rc = cxl_fault_segment(ctx, mm, ea);
+       cxl_fault_segment(ctx, mm, ea);
 
        mmput(mm);
-       put_task_struct(task);
 }
 
 static u64 next_segment(u64 ea, u64 vsid)
@@ -263,18 +331,13 @@ static void cxl_prefault_vma(struct cxl_context *ctx)
        struct copro_slb slb;
        struct vm_area_struct *vma;
        int rc;
-       struct task_struct *task;
        struct mm_struct *mm;
 
-       if (!(task = get_pid_task(ctx->pid, PIDTYPE_PID))) {
-               pr_devel("cxl_prefault_vma unable to get task %i\n",
-                        pid_nr(ctx->pid));
-               return;
-       }
-       if (!(mm = get_task_mm(task))) {
+       mm = get_mem_context(ctx);
+       if (mm == NULL) {
                pr_devel("cxl_prefault_vm unable to get mm %i\n",
                         pid_nr(ctx->pid));
-               goto out1;
+               return;
        }
 
        down_read(&mm->mmap_sem);
@@ -295,8 +358,6 @@ static void cxl_prefault_vma(struct cxl_context *ctx)
        up_read(&mm->mmap_sem);
 
        mmput(mm);
-out1:
-       put_task_struct(task);
 }
 
 void cxl_prefault(struct cxl_context *ctx, u64 wed)