Cleanup NodeContextTestCase unit tests
[yardstick.git] / yardstick / common / utils.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 import collections
17 from contextlib import closing
18 import datetime
19 import errno
20 import importlib
21 import ipaddress
22 import logging
23 import os
24 import random
25 import re
26 import signal
27 import socket
28 import subprocess
29 import sys
30 import time
31
32 import six
33 from flask import jsonify
34 from six.moves import configparser
35 from oslo_serialization import jsonutils
36 from oslo_utils import encodeutils
37
38 import yardstick
39 from yardstick.common import exceptions
40 from yardstick.common.yaml_loader import yaml_load
41
42
43 logger = logging.getLogger(__name__)
44 logger.setLevel(logging.DEBUG)
45
46
47 # Decorator for cli-args
48 def cliargs(*args, **kwargs):
49     def _decorator(func):
50         func.__dict__.setdefault('arguments', []).insert(0, (args, kwargs))
51         return func
52     return _decorator
53
54
55 def itersubclasses(cls, _seen=None):
56     """Generator over all subclasses of a given class in depth first order."""
57
58     if not isinstance(cls, type):
59         raise TypeError("itersubclasses must be called with "
60                         "new-style classes, not %.100r" % cls)
61     _seen = _seen or set()
62     try:
63         subs = cls.__subclasses__()
64     except TypeError:   # fails only when cls is type
65         subs = cls.__subclasses__(cls)
66     for sub in subs:
67         if sub not in _seen:
68             _seen.add(sub)
69             yield sub
70             for sub in itersubclasses(sub, _seen):
71                 yield sub
72
73
74 def import_modules_from_package(package, raise_exception=False):
75     """Import modules given a package name
76
77     :param: package - Full package name. For example: rally.deploy.engines
78     """
79     yardstick_root = os.path.dirname(os.path.dirname(yardstick.__file__))
80     path = os.path.join(yardstick_root, *package.split('.'))
81     for root, _, files in os.walk(path):
82         matches = (filename for filename in files if filename.endswith('.py')
83                    and not filename.startswith('__'))
84         new_package = os.path.relpath(root, yardstick_root).replace(os.sep,
85                                                                     '.')
86         module_names = set(
87             '{}.{}'.format(new_package, filename.rsplit('.py', 1)[0])
88             for filename in matches)
89         # Find modules which haven't already been imported
90         missing_modules = module_names.difference(sys.modules)
91         logger.debug('Importing modules: %s', missing_modules)
92         for module_name in missing_modules:
93             try:
94                 importlib.import_module(module_name)
95             except (ImportError, SyntaxError) as exc:
96                 if raise_exception:
97                     raise exc
98                 logger.exception('Unable to import module %s', module_name)
99
100
101 NON_NONE_DEFAULT = object()
102
103
104 def get_key_with_default(data, key, default=NON_NONE_DEFAULT):
105     value = data.get(key, default)
106     if value is NON_NONE_DEFAULT:
107         raise KeyError(key)
108     return value
109
110
111 def make_dict_from_map(data, key_map):
112     return {dest_key: get_key_with_default(data, src_key, default)
113             for dest_key, (src_key, default) in key_map.items()}
114
115
116 def makedirs(d):
117     try:
118         os.makedirs(d)
119     except OSError as e:
120         if e.errno != errno.EEXIST:
121             raise
122
123
124 def remove_file(path):
125     try:
126         os.remove(path)
127     except OSError as e:
128         if e.errno != errno.ENOENT:
129             raise
130
131
132 def execute_command(cmd, **kwargs):
133     exec_msg = "Executing command: '%s'" % cmd
134     logger.debug(exec_msg)
135
136     output = subprocess.check_output(cmd.split(), **kwargs)
137     return encodeutils.safe_decode(output, incoming='utf-8').split(os.linesep)
138
139
140 def source_env(env_file):
141     p = subprocess.Popen(". %s; env" % env_file, stdout=subprocess.PIPE,
142                          shell=True)
143     output = p.communicate()[0]
144
145     # sometimes output type would be binary_type, and it don't have splitlines
146     # method, so we need to decode
147     if isinstance(output, six.binary_type):
148         output = encodeutils.safe_decode(output)
149     env = dict(line.split('=', 1) for line in output.splitlines() if '=' in line)
150     os.environ.update(env)
151     return env
152
153
154 def read_json_from_file(path):
155     with open(path, 'r') as f:
156         j = f.read()
157     # don't use jsonutils.load() it conflicts with already decoded input
158     return jsonutils.loads(j)
159
160
161 def write_json_to_file(path, data, mode='w'):
162     with open(path, mode) as f:
163         jsonutils.dump(data, f)
164
165
166 def write_file(path, data, mode='w'):
167     with open(path, mode) as f:
168         f.write(data)
169
170
171 def parse_ini_file(path):
172     parser = configparser.ConfigParser()
173
174     try:
175         files = parser.read(path)
176     except configparser.MissingSectionHeaderError:
177         logger.exception('invalid file type')
178         raise
179     else:
180         if not files:
181             raise RuntimeError('file not exist')
182
183     try:
184         default = {k: v for k, v in parser.items('DEFAULT')}
185     except configparser.NoSectionError:
186         default = {}
187
188     config = dict(DEFAULT=default,
189                   **{s: {k: v for k, v in parser.items(
190                       s)} for s in parser.sections()})
191
192     return config
193
194
195 def get_port_mac(sshclient, port):
196     cmd = "ifconfig |grep HWaddr |grep %s |awk '{print $5}' " % port
197     status, stdout, stderr = sshclient.execute(cmd)
198
199     if status:
200         raise RuntimeError(stderr)
201     return stdout.rstrip()
202
203
204 def get_port_ip(sshclient, port):
205     cmd = "ifconfig %s |grep 'inet addr' |awk '{print $2}' " \
206         "|cut -d ':' -f2 " % port
207     status, stdout, stderr = sshclient.execute(cmd)
208
209     if status:
210         raise RuntimeError(stderr)
211     return stdout.rstrip()
212
213
214 def flatten_dict_key(data):
215     next_data = {}
216
217     # use list, because iterable is too generic
218     if not any(isinstance(v, (collections.Mapping, list))
219                for v in data.values()):
220         return data
221
222     for k, v in data.items():
223         if isinstance(v, collections.Mapping):
224             for n_k, n_v in v.items():
225                 next_data["%s.%s" % (k, n_k)] = n_v
226         # use list because iterable is too generic
227         elif isinstance(v, collections.Iterable) and not isinstance(v, six.string_types):
228             for index, item in enumerate(v):
229                 next_data["%s%d" % (k, index)] = item
230         else:
231             next_data[k] = v
232
233     return flatten_dict_key(next_data)
234
235
236 def translate_to_str(obj):
237     if isinstance(obj, collections.Mapping):
238         return {str(k): translate_to_str(v) for k, v in obj.items()}
239     elif isinstance(obj, list):
240         return [translate_to_str(ele) for ele in obj]
241     elif isinstance(obj, six.text_type):
242         return str(obj)
243     return obj
244
245
246 def result_handler(status, data):
247     result = {
248         'status': status,
249         'result': data
250     }
251     return jsonify(result)
252
253
254 def change_obj_to_dict(obj):
255     dic = {}
256     for k, v in vars(obj).items():
257         try:
258             vars(v)
259         except TypeError:
260             dic.update({k: v})
261     return dic
262
263
264 def set_dict_value(dic, keys, value):
265     return_dic = dic
266
267     for key in keys.split('.'):
268         return_dic.setdefault(key, {})
269         if key == keys.split('.')[-1]:
270             return_dic[key] = value
271         else:
272             return_dic = return_dic[key]
273     return dic
274
275
276 def get_free_port(ip):
277     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
278         port = random.randint(5000, 10000)
279         while s.connect_ex((ip, port)) == 0:
280             port = random.randint(5000, 10000)
281         return port
282
283
284 def mac_address_to_hex_list(mac):
285     octets = ["0x{:02x}".format(int(elem, 16)) for elem in mac.split(':')]
286     assert len(octets) == 6 and all(len(octet) == 4 for octet in octets)
287     return octets
288
289
290 def safe_ip_address(ip_addr):
291     """ get ip address version v6 or v4 """
292     try:
293         return ipaddress.ip_address(six.text_type(ip_addr))
294     except ValueError:
295         logging.error("%s is not valid", ip_addr)
296         return None
297
298
299 def get_ip_version(ip_addr):
300     """ get ip address version v6 or v4 """
301     try:
302         address = ipaddress.ip_address(six.text_type(ip_addr))
303     except ValueError:
304         logging.error("%s is not valid", ip_addr)
305         return None
306     else:
307         return address.version
308
309
310 def make_ip_addr(ip, mask):
311     """
312     :param ip[str]: ip adddress
313     :param mask[str]: /24 prefix of 255.255.255.0 netmask
314     :return: IPv4Interface object
315     """
316     try:
317         return ipaddress.ip_interface(six.text_type('/'.join([ip, mask])))
318     except (TypeError, ValueError):
319         # None so we can skip later
320         return None
321
322
323 def ip_to_hex(ip_addr, separator=''):
324     try:
325         address = ipaddress.ip_address(six.text_type(ip_addr))
326     except ValueError:
327         logging.error("%s is not valid", ip_addr)
328         return ip_addr
329
330     if address.version != 4:
331         return ip_addr
332
333     if not separator:
334         return '{:08x}'.format(int(address))
335
336     return separator.join('{:02x}'.format(octet) for octet in address.packed)
337
338
339 def try_int(s, *args):
340     """Convert to integer if possible."""
341     try:
342         return int(s)
343     except (TypeError, ValueError):
344         return args[0] if args else s
345
346
347 class SocketTopology(dict):
348
349     @classmethod
350     def parse_cpuinfo(cls, cpuinfo):
351         socket_map = {}
352
353         lines = cpuinfo.splitlines()
354
355         core_details = []
356         core_lines = {}
357         for line in lines:
358             if line.strip():
359                 name, value = line.split(":", 1)
360                 core_lines[name.strip()] = try_int(value.strip())
361             else:
362                 core_details.append(core_lines)
363                 core_lines = {}
364
365         for core in core_details:
366             socket_map.setdefault(core["physical id"], {}).setdefault(
367                 core["core id"], {})[core["processor"]] = (
368                 core["processor"], core["core id"], core["physical id"])
369
370         return cls(socket_map)
371
372     def sockets(self):
373         return sorted(self.keys())
374
375     def cores(self):
376         return sorted(core for cores in self.values() for core in cores)
377
378     def processors(self):
379         return sorted(
380             proc for cores in self.values() for procs in cores.values() for
381             proc in procs)
382
383
384 def config_to_dict(config):
385     return {section: dict(config.items(section)) for section in
386             config.sections()}
387
388
389 def validate_non_string_sequence(value, default=None, raise_exc=None):
390     # NOTE(ralonsoh): refactor this function to check if raise_exc is an
391     # Exception. Remove duplicate code, this function is duplicated in this
392     # repository.
393     if isinstance(value, collections.Sequence) and not isinstance(value, six.string_types):
394         return value
395     if raise_exc:
396         raise raise_exc  # pylint: disable=raising-bad-type
397     return default
398
399
400 def join_non_strings(separator, *non_strings):
401     try:
402         non_strings = validate_non_string_sequence(non_strings[0], raise_exc=RuntimeError)
403     except (IndexError, RuntimeError):
404         pass
405     return str(separator).join(str(non_string) for non_string in non_strings)
406
407
408 def safe_decode_utf8(s):
409     """Safe decode a str from UTF"""
410     if six.PY3 and isinstance(s, bytes):
411         return s.decode('utf-8', 'surrogateescape')
412     return s
413
414
415 class ErrorClass(object):
416
417     def __init__(self, *args, **kwargs):
418         if 'test' not in kwargs:
419             raise RuntimeError
420
421     def __getattr__(self, item):
422         raise AttributeError
423
424
425 class Timer(object):
426     def __init__(self, timeout=None, raise_exception=True):
427         super(Timer, self).__init__()
428         self.start = self.delta = None
429         self._timeout = int(timeout) if timeout else None
430         self._timeout_flag = False
431         self._raise_exception = raise_exception
432
433     def _timeout_handler(self, *args):
434         self._timeout_flag = True
435         if self._raise_exception:
436             raise exceptions.TimerTimeout(timeout=self._timeout)
437         self.__exit__()
438
439     def __enter__(self):
440         self.start = datetime.datetime.now()
441         if self._timeout:
442             signal.signal(signal.SIGALRM, self._timeout_handler)
443             signal.alarm(self._timeout)
444         return self
445
446     def __exit__(self, *_):
447         if self._timeout:
448             signal.alarm(0)
449         self.delta = datetime.datetime.now() - self.start
450
451     def __getattr__(self, item):
452         return getattr(self.delta, item)
453
454     def __iter__(self):
455         self._raise_exception = False
456         return self.__enter__()
457
458     def next(self):  # pragma: no cover
459         # NOTE(ralonsoh): Python 2 support.
460         if not self._timeout_flag:
461             return datetime.datetime.now()
462         raise StopIteration()
463
464     def __next__(self):  # pragma: no cover
465         # NOTE(ralonsoh): Python 3 support.
466         return self.next()
467
468     def __del__(self):  # pragma: no cover
469         signal.alarm(0)
470
471
472 def read_meminfo(ssh_client):
473     """Read "/proc/meminfo" file and parse all keys and values"""
474
475     cpuinfo = six.BytesIO()
476     ssh_client.get_file_obj('/proc/meminfo', cpuinfo)
477     lines = cpuinfo.getvalue().decode('utf-8')
478     matches = re.findall(r"([\w\(\)]+):\s+(\d+)( kB)*", lines)
479     output = {}
480     for match in matches:
481         output[match[0]] = match[1]
482
483     return output
484
485
486 def find_relative_file(path, task_path):
487     """
488     Find file in one of places: in abs of path or relative to a directory path,
489     in this order.
490
491     :param path:
492     :param task_path:
493     :return str: full path to file
494     """
495     # fixme: create schema to validate all fields have been provided
496     for lookup in [os.path.abspath(path), os.path.join(task_path, path)]:
497         try:
498             with open(lookup):
499                 return lookup
500         except IOError:
501             pass
502     raise IOError(errno.ENOENT, 'Unable to find {} file'.format(path))
503
504
505 def open_relative_file(path, task_path):
506     try:
507         return open(path)
508     except IOError as e:
509         if e.errno == errno.ENOENT:
510             return open(os.path.join(task_path, path))
511         raise
512
513
514 def wait_until_true(predicate, timeout=60, sleep=1, exception=None):
515     """Wait until callable predicate is evaluated as True
516
517     :param predicate: (func) callable deciding whether waiting should continue
518     :param timeout: (int) timeout in seconds how long should function wait
519     :param sleep: (int) polling interval for results in seconds
520     :param exception: exception instance to raise on timeout. If None is passed
521                       (default) then WaitTimeout exception is raised.
522     """
523     try:
524         with Timer(timeout=timeout):
525             while not predicate():
526                 time.sleep(sleep)
527     except exceptions.TimerTimeout:
528         if exception and issubclass(exception, Exception):
529             raise exception  # pylint: disable=raising-bad-type
530         raise exceptions.WaitTimeout
531
532
533 def read_yaml_file(path):
534     """Read yaml file"""
535
536     with open(path) as stream:
537         data = yaml_load(stream)
538     return data