Merge "Add active wait function"
[yardstick.git] / yardstick / common / utils.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 import collections
17 from contextlib import closing
18 import datetime
19 import errno
20 import importlib
21 import ipaddress
22 import logging
23 import os
24 import random
25 import re
26 import signal
27 import socket
28 import subprocess
29 import sys
30 import time
31
32 import six
33 from flask import jsonify
34 from six.moves import configparser
35 from oslo_serialization import jsonutils
36 from oslo_utils import encodeutils
37
38 import yardstick
39 from yardstick.common import exceptions
40
41
42 logger = logging.getLogger(__name__)
43 logger.setLevel(logging.DEBUG)
44
45
46 # Decorator for cli-args
47 def cliargs(*args, **kwargs):
48     def _decorator(func):
49         func.__dict__.setdefault('arguments', []).insert(0, (args, kwargs))
50         return func
51     return _decorator
52
53
54 def itersubclasses(cls, _seen=None):
55     """Generator over all subclasses of a given class in depth first order."""
56
57     if not isinstance(cls, type):
58         raise TypeError("itersubclasses must be called with "
59                         "new-style classes, not %.100r" % cls)
60     _seen = _seen or set()
61     try:
62         subs = cls.__subclasses__()
63     except TypeError:   # fails only when cls is type
64         subs = cls.__subclasses__(cls)
65     for sub in subs:
66         if sub not in _seen:
67             _seen.add(sub)
68             yield sub
69             for sub in itersubclasses(sub, _seen):
70                 yield sub
71
72
73 def import_modules_from_package(package, raise_exception=False):
74     """Import modules given a package name
75
76     :param: package - Full package name. For example: rally.deploy.engines
77     """
78     yardstick_root = os.path.dirname(os.path.dirname(yardstick.__file__))
79     path = os.path.join(yardstick_root, *package.split('.'))
80     for root, _, files in os.walk(path):
81         matches = (filename for filename in files if filename.endswith('.py')
82                    and not filename.startswith('__'))
83         new_package = os.path.relpath(root, yardstick_root).replace(os.sep,
84                                                                     '.')
85         module_names = set(
86             '{}.{}'.format(new_package, filename.rsplit('.py', 1)[0])
87             for filename in matches)
88         # Find modules which haven't already been imported
89         missing_modules = module_names.difference(sys.modules)
90         logger.debug('Importing modules: %s', missing_modules)
91         for module_name in missing_modules:
92             try:
93                 importlib.import_module(module_name)
94             except (ImportError, SyntaxError) as exc:
95                 if raise_exception:
96                     raise exc
97                 logger.exception('Unable to import module %s', module_name)
98
99
100 NON_NONE_DEFAULT = object()
101
102
103 def get_key_with_default(data, key, default=NON_NONE_DEFAULT):
104     value = data.get(key, default)
105     if value is NON_NONE_DEFAULT:
106         raise KeyError(key)
107     return value
108
109
110 def make_dict_from_map(data, key_map):
111     return {dest_key: get_key_with_default(data, src_key, default)
112             for dest_key, (src_key, default) in key_map.items()}
113
114
115 def makedirs(d):
116     try:
117         os.makedirs(d)
118     except OSError as e:
119         if e.errno != errno.EEXIST:
120             raise
121
122
123 def remove_file(path):
124     try:
125         os.remove(path)
126     except OSError as e:
127         if e.errno != errno.ENOENT:
128             raise
129
130
131 def execute_command(cmd, **kwargs):
132     exec_msg = "Executing command: '%s'" % cmd
133     logger.debug(exec_msg)
134
135     output = subprocess.check_output(cmd.split(), **kwargs)
136     return encodeutils.safe_decode(output, incoming='utf-8').split(os.linesep)
137
138
139 def source_env(env_file):
140     p = subprocess.Popen(". %s; env" % env_file, stdout=subprocess.PIPE,
141                          shell=True)
142     output = p.communicate()[0]
143
144     # sometimes output type would be binary_type, and it don't have splitlines
145     # method, so we need to decode
146     if isinstance(output, six.binary_type):
147         output = encodeutils.safe_decode(output)
148     env = dict(line.split('=', 1) for line in output.splitlines() if '=' in line)
149     os.environ.update(env)
150     return env
151
152
153 def read_json_from_file(path):
154     with open(path, 'r') as f:
155         j = f.read()
156     # don't use jsonutils.load() it conflicts with already decoded input
157     return jsonutils.loads(j)
158
159
160 def write_json_to_file(path, data, mode='w'):
161     with open(path, mode) as f:
162         jsonutils.dump(data, f)
163
164
165 def write_file(path, data, mode='w'):
166     with open(path, mode) as f:
167         f.write(data)
168
169
170 def parse_ini_file(path):
171     parser = configparser.ConfigParser()
172
173     try:
174         files = parser.read(path)
175     except configparser.MissingSectionHeaderError:
176         logger.exception('invalid file type')
177         raise
178     else:
179         if not files:
180             raise RuntimeError('file not exist')
181
182     try:
183         default = {k: v for k, v in parser.items('DEFAULT')}
184     except configparser.NoSectionError:
185         default = {}
186
187     config = dict(DEFAULT=default,
188                   **{s: {k: v for k, v in parser.items(
189                       s)} for s in parser.sections()})
190
191     return config
192
193
194 def get_port_mac(sshclient, port):
195     cmd = "ifconfig |grep HWaddr |grep %s |awk '{print $5}' " % port
196     status, stdout, stderr = sshclient.execute(cmd)
197
198     if status:
199         raise RuntimeError(stderr)
200     return stdout.rstrip()
201
202
203 def get_port_ip(sshclient, port):
204     cmd = "ifconfig %s |grep 'inet addr' |awk '{print $2}' " \
205         "|cut -d ':' -f2 " % port
206     status, stdout, stderr = sshclient.execute(cmd)
207
208     if status:
209         raise RuntimeError(stderr)
210     return stdout.rstrip()
211
212
213 def flatten_dict_key(data):
214     next_data = {}
215
216     # use list, because iterable is too generic
217     if not any(isinstance(v, (collections.Mapping, list))
218                for v in data.values()):
219         return data
220
221     for k, v in data.items():
222         if isinstance(v, collections.Mapping):
223             for n_k, n_v in v.items():
224                 next_data["%s.%s" % (k, n_k)] = n_v
225         # use list because iterable is too generic
226         elif isinstance(v, collections.Iterable) and not isinstance(v, six.string_types):
227             for index, item in enumerate(v):
228                 next_data["%s%d" % (k, index)] = item
229         else:
230             next_data[k] = v
231
232     return flatten_dict_key(next_data)
233
234
235 def translate_to_str(obj):
236     if isinstance(obj, collections.Mapping):
237         return {str(k): translate_to_str(v) for k, v in obj.items()}
238     elif isinstance(obj, list):
239         return [translate_to_str(ele) for ele in obj]
240     elif isinstance(obj, six.text_type):
241         return str(obj)
242     return obj
243
244
245 def result_handler(status, data):
246     result = {
247         'status': status,
248         'result': data
249     }
250     return jsonify(result)
251
252
253 def change_obj_to_dict(obj):
254     dic = {}
255     for k, v in vars(obj).items():
256         try:
257             vars(v)
258         except TypeError:
259             dic.update({k: v})
260     return dic
261
262
263 def set_dict_value(dic, keys, value):
264     return_dic = dic
265
266     for key in keys.split('.'):
267         return_dic.setdefault(key, {})
268         if key == keys.split('.')[-1]:
269             return_dic[key] = value
270         else:
271             return_dic = return_dic[key]
272     return dic
273
274
275 def get_free_port(ip):
276     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
277         port = random.randint(5000, 10000)
278         while s.connect_ex((ip, port)) == 0:
279             port = random.randint(5000, 10000)
280         return port
281
282
283 def mac_address_to_hex_list(mac):
284     octets = ["0x{:02x}".format(int(elem, 16)) for elem in mac.split(':')]
285     assert len(octets) == 6 and all(len(octet) == 4 for octet in octets)
286     return octets
287
288
289 def safe_ip_address(ip_addr):
290     """ get ip address version v6 or v4 """
291     try:
292         return ipaddress.ip_address(six.text_type(ip_addr))
293     except ValueError:
294         logging.error("%s is not valid", ip_addr)
295         return None
296
297
298 def get_ip_version(ip_addr):
299     """ get ip address version v6 or v4 """
300     try:
301         address = ipaddress.ip_address(six.text_type(ip_addr))
302     except ValueError:
303         logging.error("%s is not valid", ip_addr)
304         return None
305     else:
306         return address.version
307
308
309 def ip_to_hex(ip_addr, separator=''):
310     try:
311         address = ipaddress.ip_address(six.text_type(ip_addr))
312     except ValueError:
313         logging.error("%s is not valid", ip_addr)
314         return ip_addr
315
316     if address.version != 4:
317         return ip_addr
318
319     if not separator:
320         return '{:08x}'.format(int(address))
321
322     return separator.join('{:02x}'.format(octet) for octet in address.packed)
323
324
325 def try_int(s, *args):
326     """Convert to integer if possible."""
327     try:
328         return int(s)
329     except (TypeError, ValueError):
330         return args[0] if args else s
331
332
333 class SocketTopology(dict):
334
335     @classmethod
336     def parse_cpuinfo(cls, cpuinfo):
337         socket_map = {}
338
339         lines = cpuinfo.splitlines()
340
341         core_details = []
342         core_lines = {}
343         for line in lines:
344             if line.strip():
345                 name, value = line.split(":", 1)
346                 core_lines[name.strip()] = try_int(value.strip())
347             else:
348                 core_details.append(core_lines)
349                 core_lines = {}
350
351         for core in core_details:
352             socket_map.setdefault(core["physical id"], {}).setdefault(
353                 core["core id"], {})[core["processor"]] = (
354                 core["processor"], core["core id"], core["physical id"])
355
356         return cls(socket_map)
357
358     def sockets(self):
359         return sorted(self.keys())
360
361     def cores(self):
362         return sorted(core for cores in self.values() for core in cores)
363
364     def processors(self):
365         return sorted(
366             proc for cores in self.values() for procs in cores.values() for
367             proc in procs)
368
369
370 def config_to_dict(config):
371     return {section: dict(config.items(section)) for section in
372             config.sections()}
373
374
375 def validate_non_string_sequence(value, default=None, raise_exc=None):
376     # NOTE(ralonsoh): refactor this function to check if raise_exc is an
377     # Exception. Remove duplicate code, this function is duplicated in this
378     # repository.
379     if isinstance(value, collections.Sequence) and not isinstance(value, six.string_types):
380         return value
381     if raise_exc:
382         raise raise_exc  # pylint: disable=raising-bad-type
383     return default
384
385
386 def join_non_strings(separator, *non_strings):
387     try:
388         non_strings = validate_non_string_sequence(non_strings[0], raise_exc=RuntimeError)
389     except (IndexError, RuntimeError):
390         pass
391     return str(separator).join(str(non_string) for non_string in non_strings)
392
393
394 def safe_decode_utf8(s):
395     """Safe decode a str from UTF"""
396     if six.PY3 and isinstance(s, bytes):
397         return s.decode('utf-8', 'surrogateescape')
398     return s
399
400
401 class ErrorClass(object):
402
403     def __init__(self, *args, **kwargs):
404         if 'test' not in kwargs:
405             raise RuntimeError
406
407     def __getattr__(self, item):
408         raise AttributeError
409
410
411 class Timer(object):
412     def __init__(self, timeout=None):
413         super(Timer, self).__init__()
414         self.start = self.delta = None
415         self._timeout = int(timeout) if timeout else None
416
417     def _timeout_handler(self, *args):
418         raise exceptions.TimerTimeout(timeout=self._timeout)
419
420     def __enter__(self):
421         self.start = datetime.datetime.now()
422         if self._timeout:
423             signal.signal(signal.SIGALRM, self._timeout_handler)
424             signal.alarm(self._timeout)
425         return self
426
427     def __exit__(self, *_):
428         if self._timeout:
429             signal.alarm(0)
430         self.delta = datetime.datetime.now() - self.start
431
432     def __getattr__(self, item):
433         return getattr(self.delta, item)
434
435
436 def read_meminfo(ssh_client):
437     """Read "/proc/meminfo" file and parse all keys and values"""
438
439     cpuinfo = six.BytesIO()
440     ssh_client.get_file_obj('/proc/meminfo', cpuinfo)
441     lines = cpuinfo.getvalue().decode('utf-8')
442     matches = re.findall(r"([\w\(\)]+):\s+(\d+)( kB)*", lines)
443     output = {}
444     for match in matches:
445         output[match[0]] = match[1]
446
447     return output
448
449
450 def find_relative_file(path, task_path):
451     """
452     Find file in one of places: in abs of path or relative to a directory path,
453     in this order.
454
455     :param path:
456     :param task_path:
457     :return str: full path to file
458     """
459     # fixme: create schema to validate all fields have been provided
460     for lookup in [os.path.abspath(path), os.path.join(task_path, path)]:
461         try:
462             with open(lookup):
463                 return lookup
464         except IOError:
465             pass
466     raise IOError(errno.ENOENT, 'Unable to find {} file'.format(path))
467
468
469 def open_relative_file(path, task_path):
470     try:
471         return open(path)
472     except IOError as e:
473         if e.errno == errno.ENOENT:
474             return open(os.path.join(task_path, path))
475         raise
476
477
478 def wait_until_true(predicate, timeout=60, sleep=1, exception=None):
479     """Wait until callable predicate is evaluated as True
480
481     :param predicate: (func) callable deciding whether waiting should continue
482     :param timeout: (int) timeout in seconds how long should function wait
483     :param sleep: (int) polling interval for results in seconds
484     :param exception: exception instance to raise on timeout. If None is passed
485                       (default) then WaitTimeout exception is raised.
486     """
487     try:
488         with Timer(timeout=timeout):
489             while not predicate():
490                 time.sleep(sleep)
491     except exceptions.TimerTimeout:
492         if exception and issubclass(exception, Exception):
493             raise exception  # pylint: disable=raising-bad-type
494         raise exceptions.WaitTimeout