These changes are the raw update to linux-4.4.6-rt14. Kernel sources
[kvmfornfv.git] / kernel / lib / test_rhashtable.c
index b295754..8c1ad1c 100644 (file)
@@ -1,14 +1,9 @@
 /*
  * Resizable, Scalable, Concurrent Hash Table
  *
- * Copyright (c) 2014 Thomas Graf <tgraf@suug.ch>
+ * Copyright (c) 2014-2015 Thomas Graf <tgraf@suug.ch>
  * Copyright (c) 2008-2014 Patrick McHardy <kaber@trash.net>
  *
- * Based on the following paper:
- * https://www.usenix.org/legacy/event/atc11/tech/final_files/Triplett.pdf
- *
- * Code partially derived from nft_hash
- *
  * This program is free software; you can redistribute it and/or modify
  * it under the terms of the GNU General Public License version 2 as
  * published by the Free Software Foundation.
 #include <linux/init.h>
 #include <linux/jhash.h>
 #include <linux/kernel.h>
+#include <linux/kthread.h>
 #include <linux/module.h>
 #include <linux/rcupdate.h>
 #include <linux/rhashtable.h>
+#include <linux/semaphore.h>
 #include <linux/slab.h>
+#include <linux/sched.h>
+#include <linux/vmalloc.h>
+
+#define MAX_ENTRIES    1000000
+#define TEST_INSERT_FAIL INT_MAX
+
+static int entries = 50000;
+module_param(entries, int, 0);
+MODULE_PARM_DESC(entries, "Number of entries to add (default: 50000)");
+
+static int runs = 4;
+module_param(runs, int, 0);
+MODULE_PARM_DESC(runs, "Number of test runs per variant (default: 4)");
 
+static int max_size = 65536;
+module_param(max_size, int, 0);
+MODULE_PARM_DESC(runs, "Maximum table size (default: 65536)");
 
-#define TEST_HT_SIZE   8
-#define TEST_ENTRIES   2048
-#define TEST_PTR       ((void *) 0xdeadbeef)
-#define TEST_NEXPANDS  4
+static bool shrinking = false;
+module_param(shrinking, bool, 0);
+MODULE_PARM_DESC(shrinking, "Enable automatic shrinking (default: off)");
+
+static int size = 8;
+module_param(size, int, 0);
+MODULE_PARM_DESC(size, "Initial size hint of table (default: 8)");
+
+static int tcount = 10;
+module_param(tcount, int, 0);
+MODULE_PARM_DESC(tcount, "Number of threads to spawn (default: 10)");
 
 struct test_obj {
-       void                    *ptr;
        int                     value;
        struct rhash_head       node;
 };
 
-static const struct rhashtable_params test_rht_params = {
-       .nelem_hint = TEST_HT_SIZE,
+struct thread_data {
+       int id;
+       struct task_struct *task;
+       struct test_obj *objs;
+};
+
+static struct test_obj array[MAX_ENTRIES];
+
+static struct rhashtable_params test_rht_params = {
        .head_offset = offsetof(struct test_obj, node),
        .key_offset = offsetof(struct test_obj, value),
        .key_len = sizeof(int),
@@ -47,15 +73,21 @@ static const struct rhashtable_params test_rht_params = {
        .nulls_base = (3U << RHT_BASE_SHIFT),
 };
 
+static struct semaphore prestart_sem;
+static struct semaphore startup_sem = __SEMAPHORE_INITIALIZER(startup_sem, 0);
+
 static int __init test_rht_lookup(struct rhashtable *ht)
 {
        unsigned int i;
 
-       for (i = 0; i < TEST_ENTRIES * 2; i++) {
+       for (i = 0; i < entries * 2; i++) {
                struct test_obj *obj;
                bool expected = !(i % 2);
                u32 key = i;
 
+               if (array[i / 2].value == TEST_INSERT_FAIL)
+                       expected = false;
+
                obj = rhashtable_lookup_fast(ht, &key, test_rht_params);
 
                if (expected && !obj) {
@@ -66,140 +98,302 @@ static int __init test_rht_lookup(struct rhashtable *ht)
                                key);
                        return -EEXIST;
                } else if (expected && obj) {
-                       if (obj->ptr != TEST_PTR || obj->value != i) {
-                               pr_warn("Test failed: Lookup value mismatch %p!=%p, %u!=%u\n",
-                                       obj->ptr, TEST_PTR, obj->value, i);
+                       if (obj->value != i) {
+                               pr_warn("Test failed: Lookup value mismatch %u!=%u\n",
+                                       obj->value, i);
                                return -EINVAL;
                        }
                }
+
+               cond_resched_rcu();
        }
 
        return 0;
 }
 
-static void test_bucket_stats(struct rhashtable *ht, bool quiet)
+static void test_bucket_stats(struct rhashtable *ht)
 {
-       unsigned int cnt, rcu_cnt, i, total = 0;
+       unsigned int err, total = 0, chain_len = 0;
+       struct rhashtable_iter hti;
        struct rhash_head *pos;
-       struct test_obj *obj;
-       struct bucket_table *tbl;
 
-       tbl = rht_dereference_rcu(ht->tbl, ht);
-       for (i = 0; i < tbl->size; i++) {
-               rcu_cnt = cnt = 0;
+       err = rhashtable_walk_init(ht, &hti);
+       if (err) {
+               pr_warn("Test failed: allocation error");
+               return;
+       }
 
-               if (!quiet)
-                       pr_info(" [%#4x/%u]", i, tbl->size);
+       err = rhashtable_walk_start(&hti);
+       if (err && err != -EAGAIN) {
+               pr_warn("Test failed: iterator failed: %d\n", err);
+               return;
+       }
 
-               rht_for_each_entry_rcu(obj, pos, tbl, i, node) {
-                       cnt++;
-                       total++;
-                       if (!quiet)
-                               pr_cont(" [%p],", obj);
+       while ((pos = rhashtable_walk_next(&hti))) {
+               if (PTR_ERR(pos) == -EAGAIN) {
+                       pr_info("Info: encountered resize\n");
+                       chain_len++;
+                       continue;
+               } else if (IS_ERR(pos)) {
+                       pr_warn("Test failed: rhashtable_walk_next() error: %ld\n",
+                               PTR_ERR(pos));
+                       break;
                }
 
-               rht_for_each_entry_rcu(obj, pos, tbl, i, node)
-                       rcu_cnt++;
-
-               if (rcu_cnt != cnt)
-                       pr_warn("Test failed: Chain count mismach %d != %d",
-                               cnt, rcu_cnt);
-
-               if (!quiet)
-                       pr_cont("\n  [%#x] first element: %p, chain length: %u\n",
-                               i, tbl->buckets[i], cnt);
+               total++;
        }
 
-       pr_info("  Traversal complete: counted=%u, nelems=%u, entries=%d\n",
-               total, atomic_read(&ht->nelems), TEST_ENTRIES);
+       rhashtable_walk_stop(&hti);
+       rhashtable_walk_exit(&hti);
 
-       if (total != atomic_read(&ht->nelems) || total != TEST_ENTRIES)
+       pr_info("  Traversal complete: counted=%u, nelems=%u, entries=%d, table-jumps=%u\n",
+               total, atomic_read(&ht->nelems), entries, chain_len);
+
+       if (total != atomic_read(&ht->nelems) || total != entries)
                pr_warn("Test failed: Total count mismatch ^^^");
 }
 
-static int __init test_rhashtable(struct rhashtable *ht)
+static s64 __init test_rhashtable(struct rhashtable *ht)
 {
-       struct bucket_table *tbl;
        struct test_obj *obj;
-       struct rhash_head *pos, *next;
        int err;
-       unsigned int i;
+       unsigned int i, insert_fails = 0;
+       s64 start, end;
 
        /*
         * Insertion Test:
-        * Insert TEST_ENTRIES into table with all keys even numbers
+        * Insert entries into table with all keys even numbers
         */
-       pr_info("  Adding %d keys\n", TEST_ENTRIES);
-       for (i = 0; i < TEST_ENTRIES; i++) {
-               struct test_obj *obj;
-
-               obj = kzalloc(sizeof(*obj), GFP_KERNEL);
-               if (!obj) {
-                       err = -ENOMEM;
-                       goto error;
-               }
+       pr_info("  Adding %d keys\n", entries);
+       start = ktime_get_ns();
+       for (i = 0; i < entries; i++) {
+               struct test_obj *obj = &array[i];
 
-               obj->ptr = TEST_PTR;
                obj->value = i * 2;
 
                err = rhashtable_insert_fast(ht, &obj->node, test_rht_params);
-               if (err) {
-                       kfree(obj);
-                       goto error;
+               if (err == -ENOMEM || err == -EBUSY) {
+                       /* Mark failed inserts but continue */
+                       obj->value = TEST_INSERT_FAIL;
+                       insert_fails++;
+               } else if (err) {
+                       return err;
                }
+
+               cond_resched();
        }
 
+       if (insert_fails)
+               pr_info("  %u insertions failed due to memory pressure\n",
+                       insert_fails);
+
+       test_bucket_stats(ht);
        rcu_read_lock();
-       test_bucket_stats(ht, true);
        test_rht_lookup(ht);
        rcu_read_unlock();
 
-       rcu_read_lock();
-       test_bucket_stats(ht, true);
-       rcu_read_unlock();
+       test_bucket_stats(ht);
 
-       pr_info("  Deleting %d keys\n", TEST_ENTRIES);
-       for (i = 0; i < TEST_ENTRIES; i++) {
+       pr_info("  Deleting %d keys\n", entries);
+       for (i = 0; i < entries; i++) {
                u32 key = i * 2;
 
-               obj = rhashtable_lookup_fast(ht, &key, test_rht_params);
-               BUG_ON(!obj);
+               if (array[i].value != TEST_INSERT_FAIL) {
+                       obj = rhashtable_lookup_fast(ht, &key, test_rht_params);
+                       BUG_ON(!obj);
+
+                       rhashtable_remove_fast(ht, &obj->node, test_rht_params);
+               }
 
-               rhashtable_remove_fast(ht, &obj->node, test_rht_params);
-               kfree(obj);
+               cond_resched();
        }
 
-       return 0;
+       end = ktime_get_ns();
+       pr_info("  Duration of test: %lld ns\n", end - start);
+
+       return end - start;
+}
 
-error:
-       tbl = rht_dereference_rcu(ht->tbl, ht);
-       for (i = 0; i < tbl->size; i++)
-               rht_for_each_entry_safe(obj, pos, next, tbl, i, node)
-                       kfree(obj);
+static struct rhashtable ht;
 
+static int thread_lookup_test(struct thread_data *tdata)
+{
+       int i, err = 0;
+
+       for (i = 0; i < entries; i++) {
+               struct test_obj *obj;
+               int key = (tdata->id << 16) | i;
+
+               obj = rhashtable_lookup_fast(&ht, &key, test_rht_params);
+               if (obj && (tdata->objs[i].value == TEST_INSERT_FAIL)) {
+                       pr_err("  found unexpected object %d\n", key);
+                       err++;
+               } else if (!obj && (tdata->objs[i].value != TEST_INSERT_FAIL)) {
+                       pr_err("  object %d not found!\n", key);
+                       err++;
+               } else if (obj && (obj->value != key)) {
+                       pr_err("  wrong object returned (got %d, expected %d)\n",
+                              obj->value, key);
+                       err++;
+               }
+       }
        return err;
 }
 
-static struct rhashtable ht;
+static int threadfunc(void *data)
+{
+       int i, step, err = 0, insert_fails = 0;
+       struct thread_data *tdata = data;
+
+       up(&prestart_sem);
+       if (down_interruptible(&startup_sem))
+               pr_err("  thread[%d]: down_interruptible failed\n", tdata->id);
+
+       for (i = 0; i < entries; i++) {
+               tdata->objs[i].value = (tdata->id << 16) | i;
+               err = rhashtable_insert_fast(&ht, &tdata->objs[i].node,
+                                            test_rht_params);
+               if (err == -ENOMEM || err == -EBUSY) {
+                       tdata->objs[i].value = TEST_INSERT_FAIL;
+                       insert_fails++;
+               } else if (err) {
+                       pr_err("  thread[%d]: rhashtable_insert_fast failed\n",
+                              tdata->id);
+                       goto out;
+               }
+       }
+       if (insert_fails)
+               pr_info("  thread[%d]: %d insert failures\n",
+                       tdata->id, insert_fails);
+
+       err = thread_lookup_test(tdata);
+       if (err) {
+               pr_err("  thread[%d]: rhashtable_lookup_test failed\n",
+                      tdata->id);
+               goto out;
+       }
+
+       for (step = 10; step > 0; step--) {
+               for (i = 0; i < entries; i += step) {
+                       if (tdata->objs[i].value == TEST_INSERT_FAIL)
+                               continue;
+                       err = rhashtable_remove_fast(&ht, &tdata->objs[i].node,
+                                                    test_rht_params);
+                       if (err) {
+                               pr_err("  thread[%d]: rhashtable_remove_fast failed\n",
+                                      tdata->id);
+                               goto out;
+                       }
+                       tdata->objs[i].value = TEST_INSERT_FAIL;
+               }
+               err = thread_lookup_test(tdata);
+               if (err) {
+                       pr_err("  thread[%d]: rhashtable_lookup_test (2) failed\n",
+                              tdata->id);
+                       goto out;
+               }
+       }
+out:
+       while (!kthread_should_stop()) {
+               set_current_state(TASK_INTERRUPTIBLE);
+               schedule();
+       }
+       return err;
+}
 
 static int __init test_rht_init(void)
 {
-       int err;
+       int i, err, started_threads = 0, failed_threads = 0;
+       u64 total_time = 0;
+       struct thread_data *tdata;
+       struct test_obj *objs;
+
+       entries = min(entries, MAX_ENTRIES);
+
+       test_rht_params.automatic_shrinking = shrinking;
+       test_rht_params.max_size = max_size;
+       test_rht_params.nelem_hint = size;
+
+       pr_info("Running rhashtable test nelem=%d, max_size=%d, shrinking=%d\n",
+               size, max_size, shrinking);
+
+       for (i = 0; i < runs; i++) {
+               s64 time;
+
+               pr_info("Test %02d:\n", i);
+               memset(&array, 0, sizeof(array));
+               err = rhashtable_init(&ht, &test_rht_params);
+               if (err < 0) {
+                       pr_warn("Test failed: Unable to initialize hashtable: %d\n",
+                               err);
+                       continue;
+               }
 
-       pr_info("Running resizable hashtable tests...\n");
+               time = test_rhashtable(&ht);
+               rhashtable_destroy(&ht);
+               if (time < 0) {
+                       pr_warn("Test failed: return code %lld\n", time);
+                       return -EINVAL;
+               }
+
+               total_time += time;
+       }
+
+       do_div(total_time, runs);
+       pr_info("Average test time: %llu\n", total_time);
+
+       if (!tcount)
+               return 0;
+
+       pr_info("Testing concurrent rhashtable access from %d threads\n",
+               tcount);
+       sema_init(&prestart_sem, 1 - tcount);
+       tdata = vzalloc(tcount * sizeof(struct thread_data));
+       if (!tdata)
+               return -ENOMEM;
+       objs  = vzalloc(tcount * entries * sizeof(struct test_obj));
+       if (!objs) {
+               vfree(tdata);
+               return -ENOMEM;
+       }
 
        err = rhashtable_init(&ht, &test_rht_params);
        if (err < 0) {
                pr_warn("Test failed: Unable to initialize hashtable: %d\n",
                        err);
-               return err;
+               vfree(tdata);
+               vfree(objs);
+               return -EINVAL;
        }
-
-       err = test_rhashtable(&ht);
-
+       for (i = 0; i < tcount; i++) {
+               tdata[i].id = i;
+               tdata[i].objs = objs + i * entries;
+               tdata[i].task = kthread_run(threadfunc, &tdata[i],
+                                           "rhashtable_thrad[%d]", i);
+               if (IS_ERR(tdata[i].task))
+                       pr_err(" kthread_run failed for thread %d\n", i);
+               else
+                       started_threads++;
+       }
+       if (down_interruptible(&prestart_sem))
+               pr_err("  down interruptible failed\n");
+       for (i = 0; i < tcount; i++)
+               up(&startup_sem);
+       for (i = 0; i < tcount; i++) {
+               if (IS_ERR(tdata[i].task))
+                       continue;
+               if ((err = kthread_stop(tdata[i].task))) {
+                       pr_warn("Test failed: thread %d returned: %d\n",
+                               i, err);
+                       failed_threads++;
+               }
+       }
+       pr_info("Started %d threads, %d failed\n",
+               started_threads, failed_threads);
        rhashtable_destroy(&ht);
-
-       return err;
+       vfree(tdata);
+       vfree(objs);
+       return 0;
 }
 
 static void __exit test_rht_exit(void)