Merge "Use assertIn(x, y) instead of other variations"
[yardstick.git] / yardstick / common / utils.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 import collections
17 from contextlib import closing
18 import datetime
19 import errno
20 import importlib
21 import ipaddress
22 import logging
23 import os
24 import random
25 import socket
26 import subprocess
27 import sys
28
29 import six
30 from flask import jsonify
31 from six.moves import configparser
32 from oslo_serialization import jsonutils
33
34 import yardstick
35
36 logger = logging.getLogger(__name__)
37 logger.setLevel(logging.DEBUG)
38
39
40 # Decorator for cli-args
41 def cliargs(*args, **kwargs):
42     def _decorator(func):
43         func.__dict__.setdefault('arguments', []).insert(0, (args, kwargs))
44         return func
45     return _decorator
46
47
48 def itersubclasses(cls, _seen=None):
49     """Generator over all subclasses of a given class in depth first order."""
50
51     if not isinstance(cls, type):
52         raise TypeError("itersubclasses must be called with "
53                         "new-style classes, not %.100r" % cls)
54     _seen = _seen or set()
55     try:
56         subs = cls.__subclasses__()
57     except TypeError:   # fails only when cls is type
58         subs = cls.__subclasses__(cls)
59     for sub in subs:
60         if sub not in _seen:
61             _seen.add(sub)
62             yield sub
63             for sub in itersubclasses(sub, _seen):
64                 yield sub
65
66
67 def import_modules_from_package(package):
68     """Import modules given a package name
69
70     :param: package - Full package name. For example: rally.deploy.engines
71     """
72     yardstick_root = os.path.dirname(os.path.dirname(yardstick.__file__))
73     path = os.path.join(yardstick_root, *package.split('.'))
74     for root, _, files in os.walk(path):
75         matches = (filename for filename in files if filename.endswith('.py')
76                    and not filename.startswith('__'))
77         new_package = os.path.relpath(root, yardstick_root).replace(os.sep,
78                                                                     '.')
79         module_names = set(
80             '{}.{}'.format(new_package, filename.rsplit('.py', 1)[0])
81             for filename in matches)
82         # Find modules which haven't already been imported
83         missing_modules = module_names.difference(sys.modules)
84         logger.debug('Importing modules: %s', missing_modules)
85         for module_name in missing_modules:
86             try:
87                 importlib.import_module(module_name)
88             except (ImportError, SyntaxError):
89                 logger.exception('Unable to import module %s', module_name)
90
91
92 def makedirs(d):
93     try:
94         os.makedirs(d)
95     except OSError as e:
96         if e.errno != errno.EEXIST:
97             raise
98
99
100 def remove_file(path):
101     try:
102         os.remove(path)
103     except OSError as e:
104         if e.errno != errno.ENOENT:
105             raise
106
107
108 def execute_command(cmd):
109     exec_msg = "Executing command: '%s'" % cmd
110     logger.debug(exec_msg)
111
112     output = subprocess.check_output(cmd.split()).split(os.linesep)
113
114     return output
115
116
117 def source_env(env_file):
118     p = subprocess.Popen(". %s; env" % env_file, stdout=subprocess.PIPE,
119                          shell=True)
120     output = p.communicate()[0]
121     env = dict(line.split('=', 1) for line in output.splitlines() if '=' in line)
122     os.environ.update(env)
123     return env
124
125
126 def read_json_from_file(path):
127     with open(path, 'r') as f:
128         j = f.read()
129     # don't use jsonutils.load() it conflicts with already decoded input
130     return jsonutils.loads(j)
131
132
133 def write_json_to_file(path, data, mode='w'):
134     with open(path, mode) as f:
135         jsonutils.dump(data, f)
136
137
138 def write_file(path, data, mode='w'):
139     with open(path, mode) as f:
140         f.write(data)
141
142
143 def parse_ini_file(path):
144     parser = configparser.ConfigParser()
145
146     try:
147         files = parser.read(path)
148     except configparser.MissingSectionHeaderError:
149         logger.exception('invalid file type')
150         raise
151     else:
152         if not files:
153             raise RuntimeError('file not exist')
154
155     try:
156         default = {k: v for k, v in parser.items('DEFAULT')}
157     except configparser.NoSectionError:
158         default = {}
159
160     config = dict(DEFAULT=default,
161                   **{s: {k: v for k, v in parser.items(
162                       s)} for s in parser.sections()})
163
164     return config
165
166
167 def get_port_mac(sshclient, port):
168     cmd = "ifconfig |grep HWaddr |grep %s |awk '{print $5}' " % port
169     status, stdout, stderr = sshclient.execute(cmd)
170
171     if status:
172         raise RuntimeError(stderr)
173     return stdout.rstrip()
174
175
176 def get_port_ip(sshclient, port):
177     cmd = "ifconfig %s |grep 'inet addr' |awk '{print $2}' " \
178         "|cut -d ':' -f2 " % port
179     status, stdout, stderr = sshclient.execute(cmd)
180
181     if status:
182         raise RuntimeError(stderr)
183     return stdout.rstrip()
184
185
186 def flatten_dict_key(data):
187     next_data = {}
188
189     # use list, because iterable is too generic
190     if not any(isinstance(v, (collections.Mapping, list))
191                for v in data.values()):
192         return data
193
194     for k, v in data.items():
195         if isinstance(v, collections.Mapping):
196             for n_k, n_v in v.items():
197                 next_data["%s.%s" % (k, n_k)] = n_v
198         # use list because iterable is too generic
199         elif isinstance(v, collections.Iterable) and not isinstance(v, six.string_types):
200             for index, item in enumerate(v):
201                 next_data["%s%d" % (k, index)] = item
202         else:
203             next_data[k] = v
204
205     return flatten_dict_key(next_data)
206
207
208 def translate_to_str(obj):
209     if isinstance(obj, collections.Mapping):
210         return {str(k): translate_to_str(v) for k, v in obj.items()}
211     elif isinstance(obj, list):
212         return [translate_to_str(ele) for ele in obj]
213     elif isinstance(obj, six.text_type):
214         return str(obj)
215     return obj
216
217
218 def result_handler(status, data):
219     result = {
220         'status': status,
221         'result': data
222     }
223     return jsonify(result)
224
225
226 def change_obj_to_dict(obj):
227     dic = {}
228     for k, v in vars(obj).items():
229         try:
230             vars(v)
231         except TypeError:
232             dic.update({k: v})
233     return dic
234
235
236 def set_dict_value(dic, keys, value):
237     return_dic = dic
238
239     for key in keys.split('.'):
240         return_dic.setdefault(key, {})
241         if key == keys.split('.')[-1]:
242             return_dic[key] = value
243         else:
244             return_dic = return_dic[key]
245     return dic
246
247
248 def get_free_port(ip):
249     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
250         port = random.randint(5000, 10000)
251         while s.connect_ex((ip, port)) == 0:
252             port = random.randint(5000, 10000)
253         return port
254
255
256 def mac_address_to_hex_list(mac):
257     octets = ["0x{:02x}".format(int(elem, 16)) for elem in mac.split(':')]
258     assert len(octets) == 6 and all(len(octet) == 4 for octet in octets)
259     return octets
260
261
262 def safe_ip_address(ip_addr):
263     """ get ip address version v6 or v4 """
264     try:
265         return ipaddress.ip_address(six.text_type(ip_addr))
266     except ValueError:
267         logging.error("%s is not valid", ip_addr)
268         return None
269
270
271 def get_ip_version(ip_addr):
272     """ get ip address version v6 or v4 """
273     try:
274         address = ipaddress.ip_address(six.text_type(ip_addr))
275     except ValueError:
276         logging.error("%s is not valid", ip_addr)
277         return None
278     else:
279         return address.version
280
281
282 def ip_to_hex(ip_addr, separator=''):
283     try:
284         address = ipaddress.ip_address(six.text_type(ip_addr))
285     except ValueError:
286         logging.error("%s is not valid", ip_addr)
287         return ip_addr
288
289     if address.version != 4:
290         return ip_addr
291
292     if not separator:
293         return '{:08x}'.format(int(address))
294
295     return separator.join('{:02x}'.format(octet) for octet in address.packed)
296
297
298 def try_int(s, *args):
299     """Convert to integer if possible."""
300     try:
301         return int(s)
302     except (TypeError, ValueError):
303         return args[0] if args else s
304
305
306 class SocketTopology(dict):
307
308     @classmethod
309     def parse_cpuinfo(cls, cpuinfo):
310         socket_map = {}
311
312         lines = cpuinfo.splitlines()
313
314         core_details = []
315         core_lines = {}
316         for line in lines:
317             if line.strip():
318                 name, value = line.split(":", 1)
319                 core_lines[name.strip()] = try_int(value.strip())
320             else:
321                 core_details.append(core_lines)
322                 core_lines = {}
323
324         for core in core_details:
325             socket_map.setdefault(core["physical id"], {}).setdefault(
326                 core["core id"], {})[core["processor"]] = (
327                 core["processor"], core["core id"], core["physical id"])
328
329         return cls(socket_map)
330
331     def sockets(self):
332         return sorted(self.keys())
333
334     def cores(self):
335         return sorted(core for cores in self.values() for core in cores)
336
337     def processors(self):
338         return sorted(
339             proc for cores in self.values() for procs in cores.values() for
340             proc in procs)
341
342
343 def config_to_dict(config):
344     return {section: dict(config.items(section)) for section in
345             config.sections()}
346
347
348 def validate_non_string_sequence(value, default=None, raise_exc=None):
349     # NOTE(ralonsoh): refactor this function to check if raise_exc is an
350     # Exception. Remove duplicate code, this function is duplicated in this
351     # repository.
352     if isinstance(value, collections.Sequence) and not isinstance(value, six.string_types):
353         return value
354     if raise_exc:
355         raise raise_exc  # pylint: disable=raising-bad-type
356     return default
357
358
359 def join_non_strings(separator, *non_strings):
360     try:
361         non_strings = validate_non_string_sequence(non_strings[0], raise_exc=RuntimeError)
362     except (IndexError, RuntimeError):
363         pass
364     return str(separator).join(str(non_string) for non_string in non_strings)
365
366
367 def safe_decode_utf8(s):
368     """Safe decode a str from UTF"""
369     if six.PY3 and isinstance(s, bytes):
370         return s.decode('utf-8', 'surrogateescape')
371     return s
372
373
374 class ErrorClass(object):
375
376     def __init__(self, *args, **kwargs):
377         if 'test' not in kwargs:
378             raise RuntimeError
379
380     def __getattr__(self, item):
381         raise AttributeError
382
383
384 class Timer(object):
385     def __init__(self):
386         super(Timer, self).__init__()
387         self.start = self.delta = None
388
389     def __enter__(self):
390         self.start = datetime.datetime.now()
391         return self
392
393     def __exit__(self, *_):
394         self.delta = datetime.datetime.now() - self.start
395
396     def __getattr__(self, item):
397         return getattr(self.delta, item)