869db469b2675cd124f0f2f0f7d3ecd46a1b999d
[yardstick.git] / yardstick / common / utils.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 import collections
17 from contextlib import closing
18 import datetime
19 import errno
20 import importlib
21 import ipaddress
22 import logging
23 import os
24 import random
25 import re
26 import signal
27 import socket
28 import subprocess
29 import sys
30 import time
31
32 import six
33 from flask import jsonify
34 from six.moves import configparser
35 from oslo_serialization import jsonutils
36 from oslo_utils import encodeutils
37
38 import yardstick
39 from yardstick.common import exceptions
40
41
42 logger = logging.getLogger(__name__)
43 logger.setLevel(logging.DEBUG)
44
45
46 # Decorator for cli-args
47 def cliargs(*args, **kwargs):
48     def _decorator(func):
49         func.__dict__.setdefault('arguments', []).insert(0, (args, kwargs))
50         return func
51     return _decorator
52
53
54 def itersubclasses(cls, _seen=None):
55     """Generator over all subclasses of a given class in depth first order."""
56
57     if not isinstance(cls, type):
58         raise TypeError("itersubclasses must be called with "
59                         "new-style classes, not %.100r" % cls)
60     _seen = _seen or set()
61     try:
62         subs = cls.__subclasses__()
63     except TypeError:   # fails only when cls is type
64         subs = cls.__subclasses__(cls)
65     for sub in subs:
66         if sub not in _seen:
67             _seen.add(sub)
68             yield sub
69             for sub in itersubclasses(sub, _seen):
70                 yield sub
71
72
73 def import_modules_from_package(package, raise_exception=False):
74     """Import modules given a package name
75
76     :param: package - Full package name. For example: rally.deploy.engines
77     """
78     yardstick_root = os.path.dirname(os.path.dirname(yardstick.__file__))
79     path = os.path.join(yardstick_root, *package.split('.'))
80     for root, _, files in os.walk(path):
81         matches = (filename for filename in files if filename.endswith('.py')
82                    and not filename.startswith('__'))
83         new_package = os.path.relpath(root, yardstick_root).replace(os.sep,
84                                                                     '.')
85         module_names = set(
86             '{}.{}'.format(new_package, filename.rsplit('.py', 1)[0])
87             for filename in matches)
88         # Find modules which haven't already been imported
89         missing_modules = module_names.difference(sys.modules)
90         logger.debug('Importing modules: %s', missing_modules)
91         for module_name in missing_modules:
92             try:
93                 importlib.import_module(module_name)
94             except (ImportError, SyntaxError) as exc:
95                 if raise_exception:
96                     raise exc
97                 logger.exception('Unable to import module %s', module_name)
98
99
100 NON_NONE_DEFAULT = object()
101
102
103 def get_key_with_default(data, key, default=NON_NONE_DEFAULT):
104     value = data.get(key, default)
105     if value is NON_NONE_DEFAULT:
106         raise KeyError(key)
107     return value
108
109
110 def make_dict_from_map(data, key_map):
111     return {dest_key: get_key_with_default(data, src_key, default)
112             for dest_key, (src_key, default) in key_map.items()}
113
114
115 def makedirs(d):
116     try:
117         os.makedirs(d)
118     except OSError as e:
119         if e.errno != errno.EEXIST:
120             raise
121
122
123 def remove_file(path):
124     try:
125         os.remove(path)
126     except OSError as e:
127         if e.errno != errno.ENOENT:
128             raise
129
130
131 def execute_command(cmd, **kwargs):
132     exec_msg = "Executing command: '%s'" % cmd
133     logger.debug(exec_msg)
134
135     output = subprocess.check_output(cmd.split(), **kwargs)
136     return encodeutils.safe_decode(output, incoming='utf-8').split(os.linesep)
137
138
139 def source_env(env_file):
140     p = subprocess.Popen(". %s; env" % env_file, stdout=subprocess.PIPE,
141                          shell=True)
142     output = p.communicate()[0]
143
144     # sometimes output type would be binary_type, and it don't have splitlines
145     # method, so we need to decode
146     if isinstance(output, six.binary_type):
147         output = encodeutils.safe_decode(output)
148     env = dict(line.split('=', 1) for line in output.splitlines() if '=' in line)
149     os.environ.update(env)
150     return env
151
152
153 def read_json_from_file(path):
154     with open(path, 'r') as f:
155         j = f.read()
156     # don't use jsonutils.load() it conflicts with already decoded input
157     return jsonutils.loads(j)
158
159
160 def write_json_to_file(path, data, mode='w'):
161     with open(path, mode) as f:
162         jsonutils.dump(data, f)
163
164
165 def write_file(path, data, mode='w'):
166     with open(path, mode) as f:
167         f.write(data)
168
169
170 def parse_ini_file(path):
171     parser = configparser.ConfigParser()
172
173     try:
174         files = parser.read(path)
175     except configparser.MissingSectionHeaderError:
176         logger.exception('invalid file type')
177         raise
178     else:
179         if not files:
180             raise RuntimeError('file not exist')
181
182     try:
183         default = {k: v for k, v in parser.items('DEFAULT')}
184     except configparser.NoSectionError:
185         default = {}
186
187     config = dict(DEFAULT=default,
188                   **{s: {k: v for k, v in parser.items(
189                       s)} for s in parser.sections()})
190
191     return config
192
193
194 def get_port_mac(sshclient, port):
195     cmd = "ifconfig |grep HWaddr |grep %s |awk '{print $5}' " % port
196     status, stdout, stderr = sshclient.execute(cmd)
197
198     if status:
199         raise RuntimeError(stderr)
200     return stdout.rstrip()
201
202
203 def get_port_ip(sshclient, port):
204     cmd = "ifconfig %s |grep 'inet addr' |awk '{print $2}' " \
205         "|cut -d ':' -f2 " % port
206     status, stdout, stderr = sshclient.execute(cmd)
207
208     if status:
209         raise RuntimeError(stderr)
210     return stdout.rstrip()
211
212
213 def flatten_dict_key(data):
214     next_data = {}
215
216     # use list, because iterable is too generic
217     if not any(isinstance(v, (collections.Mapping, list))
218                for v in data.values()):
219         return data
220
221     for k, v in data.items():
222         if isinstance(v, collections.Mapping):
223             for n_k, n_v in v.items():
224                 next_data["%s.%s" % (k, n_k)] = n_v
225         # use list because iterable is too generic
226         elif isinstance(v, collections.Iterable) and not isinstance(v, six.string_types):
227             for index, item in enumerate(v):
228                 next_data["%s%d" % (k, index)] = item
229         else:
230             next_data[k] = v
231
232     return flatten_dict_key(next_data)
233
234
235 def translate_to_str(obj):
236     if isinstance(obj, collections.Mapping):
237         return {str(k): translate_to_str(v) for k, v in obj.items()}
238     elif isinstance(obj, list):
239         return [translate_to_str(ele) for ele in obj]
240     elif isinstance(obj, six.text_type):
241         return str(obj)
242     return obj
243
244
245 def result_handler(status, data):
246     result = {
247         'status': status,
248         'result': data
249     }
250     return jsonify(result)
251
252
253 def change_obj_to_dict(obj):
254     dic = {}
255     for k, v in vars(obj).items():
256         try:
257             vars(v)
258         except TypeError:
259             dic.update({k: v})
260     return dic
261
262
263 def set_dict_value(dic, keys, value):
264     return_dic = dic
265
266     for key in keys.split('.'):
267         return_dic.setdefault(key, {})
268         if key == keys.split('.')[-1]:
269             return_dic[key] = value
270         else:
271             return_dic = return_dic[key]
272     return dic
273
274
275 def get_free_port(ip):
276     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
277         port = random.randint(5000, 10000)
278         while s.connect_ex((ip, port)) == 0:
279             port = random.randint(5000, 10000)
280         return port
281
282
283 def mac_address_to_hex_list(mac):
284     octets = ["0x{:02x}".format(int(elem, 16)) for elem in mac.split(':')]
285     assert len(octets) == 6 and all(len(octet) == 4 for octet in octets)
286     return octets
287
288
289 def safe_ip_address(ip_addr):
290     """ get ip address version v6 or v4 """
291     try:
292         return ipaddress.ip_address(six.text_type(ip_addr))
293     except ValueError:
294         logging.error("%s is not valid", ip_addr)
295         return None
296
297
298 def get_ip_version(ip_addr):
299     """ get ip address version v6 or v4 """
300     try:
301         address = ipaddress.ip_address(six.text_type(ip_addr))
302     except ValueError:
303         logging.error("%s is not valid", ip_addr)
304         return None
305     else:
306         return address.version
307
308
309 def make_ip_addr(ip, mask):
310     """
311     :param ip[str]: ip adddress
312     :param mask[str]: /24 prefix of 255.255.255.0 netmask
313     :return: IPv4Interface object
314     """
315     try:
316         return ipaddress.ip_interface(six.text_type('/'.join([ip, mask])))
317     except (TypeError, ValueError):
318         # None so we can skip later
319         return None
320
321
322 def ip_to_hex(ip_addr, separator=''):
323     try:
324         address = ipaddress.ip_address(six.text_type(ip_addr))
325     except ValueError:
326         logging.error("%s is not valid", ip_addr)
327         return ip_addr
328
329     if address.version != 4:
330         return ip_addr
331
332     if not separator:
333         return '{:08x}'.format(int(address))
334
335     return separator.join('{:02x}'.format(octet) for octet in address.packed)
336
337
338 def try_int(s, *args):
339     """Convert to integer if possible."""
340     try:
341         return int(s)
342     except (TypeError, ValueError):
343         return args[0] if args else s
344
345
346 class SocketTopology(dict):
347
348     @classmethod
349     def parse_cpuinfo(cls, cpuinfo):
350         socket_map = {}
351
352         lines = cpuinfo.splitlines()
353
354         core_details = []
355         core_lines = {}
356         for line in lines:
357             if line.strip():
358                 name, value = line.split(":", 1)
359                 core_lines[name.strip()] = try_int(value.strip())
360             else:
361                 core_details.append(core_lines)
362                 core_lines = {}
363
364         for core in core_details:
365             socket_map.setdefault(core["physical id"], {}).setdefault(
366                 core["core id"], {})[core["processor"]] = (
367                 core["processor"], core["core id"], core["physical id"])
368
369         return cls(socket_map)
370
371     def sockets(self):
372         return sorted(self.keys())
373
374     def cores(self):
375         return sorted(core for cores in self.values() for core in cores)
376
377     def processors(self):
378         return sorted(
379             proc for cores in self.values() for procs in cores.values() for
380             proc in procs)
381
382
383 def config_to_dict(config):
384     return {section: dict(config.items(section)) for section in
385             config.sections()}
386
387
388 def validate_non_string_sequence(value, default=None, raise_exc=None):
389     # NOTE(ralonsoh): refactor this function to check if raise_exc is an
390     # Exception. Remove duplicate code, this function is duplicated in this
391     # repository.
392     if isinstance(value, collections.Sequence) and not isinstance(value, six.string_types):
393         return value
394     if raise_exc:
395         raise raise_exc  # pylint: disable=raising-bad-type
396     return default
397
398
399 def join_non_strings(separator, *non_strings):
400     try:
401         non_strings = validate_non_string_sequence(non_strings[0], raise_exc=RuntimeError)
402     except (IndexError, RuntimeError):
403         pass
404     return str(separator).join(str(non_string) for non_string in non_strings)
405
406
407 def safe_decode_utf8(s):
408     """Safe decode a str from UTF"""
409     if six.PY3 and isinstance(s, bytes):
410         return s.decode('utf-8', 'surrogateescape')
411     return s
412
413
414 class ErrorClass(object):
415
416     def __init__(self, *args, **kwargs):
417         if 'test' not in kwargs:
418             raise RuntimeError
419
420     def __getattr__(self, item):
421         raise AttributeError
422
423
424 class Timer(object):
425     def __init__(self, timeout=None):
426         super(Timer, self).__init__()
427         self.start = self.delta = None
428         self._timeout = int(timeout) if timeout else None
429
430     def _timeout_handler(self, *args):
431         raise exceptions.TimerTimeout(timeout=self._timeout)
432
433     def __enter__(self):
434         self.start = datetime.datetime.now()
435         if self._timeout:
436             signal.signal(signal.SIGALRM, self._timeout_handler)
437             signal.alarm(self._timeout)
438         return self
439
440     def __exit__(self, *_):
441         if self._timeout:
442             signal.alarm(0)
443         self.delta = datetime.datetime.now() - self.start
444
445     def __getattr__(self, item):
446         return getattr(self.delta, item)
447
448
449 def read_meminfo(ssh_client):
450     """Read "/proc/meminfo" file and parse all keys and values"""
451
452     cpuinfo = six.BytesIO()
453     ssh_client.get_file_obj('/proc/meminfo', cpuinfo)
454     lines = cpuinfo.getvalue().decode('utf-8')
455     matches = re.findall(r"([\w\(\)]+):\s+(\d+)( kB)*", lines)
456     output = {}
457     for match in matches:
458         output[match[0]] = match[1]
459
460     return output
461
462
463 def find_relative_file(path, task_path):
464     """
465     Find file in one of places: in abs of path or relative to a directory path,
466     in this order.
467
468     :param path:
469     :param task_path:
470     :return str: full path to file
471     """
472     # fixme: create schema to validate all fields have been provided
473     for lookup in [os.path.abspath(path), os.path.join(task_path, path)]:
474         try:
475             with open(lookup):
476                 return lookup
477         except IOError:
478             pass
479     raise IOError(errno.ENOENT, 'Unable to find {} file'.format(path))
480
481
482 def open_relative_file(path, task_path):
483     try:
484         return open(path)
485     except IOError as e:
486         if e.errno == errno.ENOENT:
487             return open(os.path.join(task_path, path))
488         raise
489
490
491 def wait_until_true(predicate, timeout=60, sleep=1, exception=None):
492     """Wait until callable predicate is evaluated as True
493
494     :param predicate: (func) callable deciding whether waiting should continue
495     :param timeout: (int) timeout in seconds how long should function wait
496     :param sleep: (int) polling interval for results in seconds
497     :param exception: exception instance to raise on timeout. If None is passed
498                       (default) then WaitTimeout exception is raised.
499     """
500     try:
501         with Timer(timeout=timeout):
502             while not predicate():
503                 time.sleep(sleep)
504     except exceptions.TimerTimeout:
505         if exception and issubclass(exception, Exception):
506             raise exception  # pylint: disable=raising-bad-type
507         raise exceptions.WaitTimeout