Merge "Enable "wait_until_true" when used ouf the main thread"
[yardstick.git] / yardstick / common / utils.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 import collections
17 from contextlib import closing
18 import datetime
19 import errno
20 import importlib
21 import ipaddress
22 import logging
23 import os
24 import random
25 import re
26 import signal
27 import socket
28 import subprocess
29 import sys
30 import time
31 import threading
32
33 import six
34 from flask import jsonify
35 from six.moves import configparser
36 from oslo_serialization import jsonutils
37 from oslo_utils import encodeutils
38
39 import yardstick
40 from yardstick.common import exceptions
41
42
43 logger = logging.getLogger(__name__)
44 logger.setLevel(logging.DEBUG)
45
46
47 # Decorator for cli-args
48 def cliargs(*args, **kwargs):
49     def _decorator(func):
50         func.__dict__.setdefault('arguments', []).insert(0, (args, kwargs))
51         return func
52     return _decorator
53
54
55 def itersubclasses(cls, _seen=None):
56     """Generator over all subclasses of a given class in depth first order."""
57
58     if not isinstance(cls, type):
59         raise TypeError("itersubclasses must be called with "
60                         "new-style classes, not %.100r" % cls)
61     _seen = _seen or set()
62     try:
63         subs = cls.__subclasses__()
64     except TypeError:   # fails only when cls is type
65         subs = cls.__subclasses__(cls)
66     for sub in subs:
67         if sub not in _seen:
68             _seen.add(sub)
69             yield sub
70             for sub in itersubclasses(sub, _seen):
71                 yield sub
72
73
74 def import_modules_from_package(package, raise_exception=False):
75     """Import modules given a package name
76
77     :param: package - Full package name. For example: rally.deploy.engines
78     """
79     yardstick_root = os.path.dirname(os.path.dirname(yardstick.__file__))
80     path = os.path.join(yardstick_root, *package.split('.'))
81     for root, _, files in os.walk(path):
82         matches = (filename for filename in files if filename.endswith('.py')
83                    and not filename.startswith('__'))
84         new_package = os.path.relpath(root, yardstick_root).replace(os.sep,
85                                                                     '.')
86         module_names = set(
87             '{}.{}'.format(new_package, filename.rsplit('.py', 1)[0])
88             for filename in matches)
89         # Find modules which haven't already been imported
90         missing_modules = module_names.difference(sys.modules)
91         logger.debug('Importing modules: %s', missing_modules)
92         for module_name in missing_modules:
93             try:
94                 importlib.import_module(module_name)
95             except (ImportError, SyntaxError) as exc:
96                 if raise_exception:
97                     raise exc
98                 logger.exception('Unable to import module %s', module_name)
99
100
101 NON_NONE_DEFAULT = object()
102
103
104 def get_key_with_default(data, key, default=NON_NONE_DEFAULT):
105     value = data.get(key, default)
106     if value is NON_NONE_DEFAULT:
107         raise KeyError(key)
108     return value
109
110
111 def make_dict_from_map(data, key_map):
112     return {dest_key: get_key_with_default(data, src_key, default)
113             for dest_key, (src_key, default) in key_map.items()}
114
115
116 def makedirs(d):
117     try:
118         os.makedirs(d)
119     except OSError as e:
120         if e.errno != errno.EEXIST:
121             raise
122
123
124 def remove_file(path):
125     try:
126         os.remove(path)
127     except OSError as e:
128         if e.errno != errno.ENOENT:
129             raise
130
131
132 def execute_command(cmd, **kwargs):
133     exec_msg = "Executing command: '%s'" % cmd
134     logger.debug(exec_msg)
135
136     output = subprocess.check_output(cmd.split(), **kwargs)
137     return encodeutils.safe_decode(output, incoming='utf-8').split(os.linesep)
138
139
140 def source_env(env_file):
141     p = subprocess.Popen(". %s; env" % env_file, stdout=subprocess.PIPE,
142                          shell=True)
143     output = p.communicate()[0]
144
145     # sometimes output type would be binary_type, and it don't have splitlines
146     # method, so we need to decode
147     if isinstance(output, six.binary_type):
148         output = encodeutils.safe_decode(output)
149     env = dict(line.split('=', 1) for line in output.splitlines() if '=' in line)
150     os.environ.update(env)
151     return env
152
153
154 def read_json_from_file(path):
155     with open(path, 'r') as f:
156         j = f.read()
157     # don't use jsonutils.load() it conflicts with already decoded input
158     return jsonutils.loads(j)
159
160
161 def write_json_to_file(path, data, mode='w'):
162     with open(path, mode) as f:
163         jsonutils.dump(data, f)
164
165
166 def write_file(path, data, mode='w'):
167     with open(path, mode) as f:
168         f.write(data)
169
170
171 def parse_ini_file(path):
172     parser = configparser.ConfigParser()
173
174     try:
175         files = parser.read(path)
176     except configparser.MissingSectionHeaderError:
177         logger.exception('invalid file type')
178         raise
179     else:
180         if not files:
181             raise RuntimeError('file not exist')
182
183     try:
184         default = {k: v for k, v in parser.items('DEFAULT')}
185     except configparser.NoSectionError:
186         default = {}
187
188     config = dict(DEFAULT=default,
189                   **{s: {k: v for k, v in parser.items(
190                       s)} for s in parser.sections()})
191
192     return config
193
194
195 def get_port_mac(sshclient, port):
196     cmd = "ifconfig |grep HWaddr |grep %s |awk '{print $5}' " % port
197     status, stdout, stderr = sshclient.execute(cmd)
198
199     if status:
200         raise RuntimeError(stderr)
201     return stdout.rstrip()
202
203
204 def get_port_ip(sshclient, port):
205     cmd = "ifconfig %s |grep 'inet addr' |awk '{print $2}' " \
206         "|cut -d ':' -f2 " % port
207     status, stdout, stderr = sshclient.execute(cmd)
208
209     if status:
210         raise RuntimeError(stderr)
211     return stdout.rstrip()
212
213
214 def flatten_dict_key(data):
215     next_data = {}
216
217     # use list, because iterable is too generic
218     if not any(isinstance(v, (collections.Mapping, list))
219                for v in data.values()):
220         return data
221
222     for k, v in data.items():
223         if isinstance(v, collections.Mapping):
224             for n_k, n_v in v.items():
225                 next_data["%s.%s" % (k, n_k)] = n_v
226         # use list because iterable is too generic
227         elif isinstance(v, collections.Iterable) and not isinstance(v, six.string_types):
228             for index, item in enumerate(v):
229                 next_data["%s%d" % (k, index)] = item
230         else:
231             next_data[k] = v
232
233     return flatten_dict_key(next_data)
234
235
236 def translate_to_str(obj):
237     if isinstance(obj, collections.Mapping):
238         return {str(k): translate_to_str(v) for k, v in obj.items()}
239     elif isinstance(obj, list):
240         return [translate_to_str(ele) for ele in obj]
241     elif isinstance(obj, six.text_type):
242         return str(obj)
243     return obj
244
245
246 def result_handler(status, data):
247     result = {
248         'status': status,
249         'result': data
250     }
251     return jsonify(result)
252
253
254 def change_obj_to_dict(obj):
255     dic = {}
256     for k, v in vars(obj).items():
257         try:
258             vars(v)
259         except TypeError:
260             dic.update({k: v})
261     return dic
262
263
264 def set_dict_value(dic, keys, value):
265     return_dic = dic
266
267     for key in keys.split('.'):
268         return_dic.setdefault(key, {})
269         if key == keys.split('.')[-1]:
270             return_dic[key] = value
271         else:
272             return_dic = return_dic[key]
273     return dic
274
275
276 def get_free_port(ip):
277     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
278         port = random.randint(5000, 10000)
279         while s.connect_ex((ip, port)) == 0:
280             port = random.randint(5000, 10000)
281         return port
282
283
284 def mac_address_to_hex_list(mac):
285     octets = ["0x{:02x}".format(int(elem, 16)) for elem in mac.split(':')]
286     assert len(octets) == 6 and all(len(octet) == 4 for octet in octets)
287     return octets
288
289
290 def safe_ip_address(ip_addr):
291     """ get ip address version v6 or v4 """
292     try:
293         return ipaddress.ip_address(six.text_type(ip_addr))
294     except ValueError:
295         logging.error("%s is not valid", ip_addr)
296         return None
297
298
299 def get_ip_version(ip_addr):
300     """ get ip address version v6 or v4 """
301     try:
302         address = ipaddress.ip_address(six.text_type(ip_addr))
303     except ValueError:
304         logging.error("%s is not valid", ip_addr)
305         return None
306     else:
307         return address.version
308
309
310 def make_ip_addr(ip, mask):
311     """
312     :param ip[str]: ip adddress
313     :param mask[str]: /24 prefix of 255.255.255.0 netmask
314     :return: IPv4Interface object
315     """
316     try:
317         return ipaddress.ip_interface(six.text_type('/'.join([ip, mask])))
318     except (TypeError, ValueError):
319         # None so we can skip later
320         return None
321
322
323 def ip_to_hex(ip_addr, separator=''):
324     try:
325         address = ipaddress.ip_address(six.text_type(ip_addr))
326     except ValueError:
327         logging.error("%s is not valid", ip_addr)
328         return ip_addr
329
330     if address.version != 4:
331         return ip_addr
332
333     if not separator:
334         return '{:08x}'.format(int(address))
335
336     return separator.join('{:02x}'.format(octet) for octet in address.packed)
337
338
339 def get_mask_from_ip_range(ip_low, ip_high):
340     _ip_low = ipaddress.ip_address(ip_low)
341     _ip_high = ipaddress.ip_address(ip_high)
342     _ip_low_int = int(_ip_low)
343     _ip_high_int = int(_ip_high)
344     return _ip_high.max_prefixlen - (_ip_high_int ^ _ip_low_int).bit_length()
345
346
347 def try_int(s, *args):
348     """Convert to integer if possible."""
349     try:
350         return int(s)
351     except (TypeError, ValueError):
352         return args[0] if args else s
353
354
355 class SocketTopology(dict):
356
357     @classmethod
358     def parse_cpuinfo(cls, cpuinfo):
359         socket_map = {}
360
361         lines = cpuinfo.splitlines()
362
363         core_details = []
364         core_lines = {}
365         for line in lines:
366             if line.strip():
367                 name, value = line.split(":", 1)
368                 core_lines[name.strip()] = try_int(value.strip())
369             else:
370                 core_details.append(core_lines)
371                 core_lines = {}
372
373         for core in core_details:
374             socket_map.setdefault(core["physical id"], {}).setdefault(
375                 core["core id"], {})[core["processor"]] = (
376                 core["processor"], core["core id"], core["physical id"])
377
378         return cls(socket_map)
379
380     def sockets(self):
381         return sorted(self.keys())
382
383     def cores(self):
384         return sorted(core for cores in self.values() for core in cores)
385
386     def processors(self):
387         return sorted(
388             proc for cores in self.values() for procs in cores.values() for
389             proc in procs)
390
391
392 def config_to_dict(config):
393     return {section: dict(config.items(section)) for section in
394             config.sections()}
395
396
397 def validate_non_string_sequence(value, default=None, raise_exc=None):
398     # NOTE(ralonsoh): refactor this function to check if raise_exc is an
399     # Exception. Remove duplicate code, this function is duplicated in this
400     # repository.
401     if isinstance(value, collections.Sequence) and not isinstance(value, six.string_types):
402         return value
403     if raise_exc:
404         raise raise_exc  # pylint: disable=raising-bad-type
405     return default
406
407
408 def join_non_strings(separator, *non_strings):
409     try:
410         non_strings = validate_non_string_sequence(non_strings[0], raise_exc=RuntimeError)
411     except (IndexError, RuntimeError):
412         pass
413     return str(separator).join(str(non_string) for non_string in non_strings)
414
415
416 def safe_decode_utf8(s):
417     """Safe decode a str from UTF"""
418     if six.PY3 and isinstance(s, bytes):
419         return s.decode('utf-8', 'surrogateescape')
420     return s
421
422
423 class ErrorClass(object):
424
425     def __init__(self, *args, **kwargs):
426         if 'test' not in kwargs:
427             raise RuntimeError
428
429     def __getattr__(self, item):
430         raise AttributeError
431
432
433 class Timer(object):
434     def __init__(self, timeout=None, raise_exception=True):
435         super(Timer, self).__init__()
436         self.start = self.delta = None
437         self._timeout = int(timeout) if timeout else None
438         self._timeout_flag = False
439         self._raise_exception = raise_exception
440
441     def _timeout_handler(self, *args):
442         self._timeout_flag = True
443         if self._raise_exception:
444             raise exceptions.TimerTimeout(timeout=self._timeout)
445         self.__exit__()
446
447     def __enter__(self):
448         self.start = datetime.datetime.now()
449         if self._timeout:
450             signal.signal(signal.SIGALRM, self._timeout_handler)
451             signal.alarm(self._timeout)
452         return self
453
454     def __exit__(self, *_):
455         if self._timeout:
456             signal.alarm(0)
457         self.delta = datetime.datetime.now() - self.start
458
459     def __getattr__(self, item):
460         return getattr(self.delta, item)
461
462     def __iter__(self):
463         self._raise_exception = False
464         return self.__enter__()
465
466     def next(self):  # pragma: no cover
467         # NOTE(ralonsoh): Python 2 support.
468         if not self._timeout_flag:
469             return datetime.datetime.now()
470         raise StopIteration()
471
472     def __next__(self):  # pragma: no cover
473         # NOTE(ralonsoh): Python 3 support.
474         return self.next()
475
476     def __del__(self):  # pragma: no cover
477         signal.alarm(0)
478
479     def delta_time_sec(self):
480         return (datetime.datetime.now() - self.start).total_seconds()
481
482
483 def read_meminfo(ssh_client):
484     """Read "/proc/meminfo" file and parse all keys and values"""
485
486     cpuinfo = six.BytesIO()
487     ssh_client.get_file_obj('/proc/meminfo', cpuinfo)
488     lines = cpuinfo.getvalue().decode('utf-8')
489     matches = re.findall(r"([\w\(\)]+):\s+(\d+)( kB)*", lines)
490     output = {}
491     for match in matches:
492         output[match[0]] = match[1]
493
494     return output
495
496
497 def find_relative_file(path, task_path):
498     """
499     Find file in one of places: in abs of path or relative to a directory path,
500     in this order.
501
502     :param path:
503     :param task_path:
504     :return str: full path to file
505     """
506     # fixme: create schema to validate all fields have been provided
507     for lookup in [os.path.abspath(path), os.path.join(task_path, path)]:
508         try:
509             with open(lookup):
510                 return lookup
511         except IOError:
512             pass
513     raise IOError(errno.ENOENT, 'Unable to find {} file'.format(path))
514
515
516 def open_relative_file(path, task_path):
517     try:
518         return open(path)
519     except IOError as e:
520         if e.errno == errno.ENOENT:
521             return open(os.path.join(task_path, path))
522         raise
523
524
525 def wait_until_true(predicate, timeout=60, sleep=1, exception=None):
526     """Wait until callable predicate is evaluated as True
527
528     When in a thread different from the main one, Timer(timeout) will fail
529     because signal is not handled. In this case
530
531     :param predicate: (func) callable deciding whether waiting should continue
532     :param timeout: (int) timeout in seconds how long should function wait
533     :param sleep: (int) polling interval for results in seconds
534     :param exception: exception instance to raise on timeout. If None is passed
535                       (default) then WaitTimeout exception is raised.
536     """
537     if isinstance(threading.current_thread(), threading._MainThread):
538         try:
539             with Timer(timeout=timeout):
540                 while not predicate():
541                     time.sleep(sleep)
542         except exceptions.TimerTimeout:
543             if exception and issubclass(exception, Exception):
544                 raise exception  # pylint: disable=raising-bad-type
545             raise exceptions.WaitTimeout
546     else:
547         with Timer() as timer:
548             while timer.delta_time_sec() < timeout:
549                 if predicate():
550                     return
551                 time.sleep(sleep)
552         if exception and issubclass(exception, Exception):
553             raise exception  # pylint: disable=raising-bad-type
554         raise exceptions.WaitTimeout
555
556
557 def send_socket_command(host, port, command):
558     """Send a string command to a specific port in a host
559
560     :param host: (str) ip or hostname of the host
561     :param port: (int) port number
562     :param command: (str) command to send
563     :return: 0 if success, error number if error
564     """
565     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
566     ret = 0
567     try:
568         err_number = sock.connect_ex((host, int(port)))
569         if err_number != 0:
570             return err_number
571         sock.sendall(six.b(command))
572     except Exception:  # pylint: disable=broad-except
573         ret = 1
574     finally:
575         sock.close()
576     return ret