Merge "Rename test/unit/cmd directory"
[yardstick.git] / yardstick / common / utils.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 # yardstick comment: this is a modified copy of rally/rally/common/utils.py
17
18 from __future__ import absolute_import
19 from __future__ import print_function
20
21 import datetime
22 import errno
23 import logging
24 import os
25 import subprocess
26 import sys
27 import collections
28 import socket
29 import random
30 import ipaddress
31 from contextlib import closing
32
33 import six
34 from flask import jsonify
35 from six.moves import configparser
36 from oslo_utils import importutils
37 from oslo_serialization import jsonutils
38
39 import yardstick
40
41 logger = logging.getLogger(__name__)
42 logger.setLevel(logging.DEBUG)
43
44
45 # Decorator for cli-args
46 def cliargs(*args, **kwargs):
47     def _decorator(func):
48         func.__dict__.setdefault('arguments', []).insert(0, (args, kwargs))
49         return func
50     return _decorator
51
52
53 def itersubclasses(cls, _seen=None):
54     """Generator over all subclasses of a given class in depth first order."""
55
56     if not isinstance(cls, type):
57         raise TypeError("itersubclasses must be called with "
58                         "new-style classes, not %.100r" % cls)
59     _seen = _seen or set()
60     try:
61         subs = cls.__subclasses__()
62     except TypeError:   # fails only when cls is type
63         subs = cls.__subclasses__(cls)
64     for sub in subs:
65         if sub not in _seen:
66             _seen.add(sub)
67             yield sub
68             for sub in itersubclasses(sub, _seen):
69                 yield sub
70
71
72 def import_modules_from_package(package):
73     """Import modules from package and append into sys.modules
74
75     :param: package - Full package name. For example: rally.deploy.engines
76     """
77     yardstick_root = os.path.dirname(os.path.dirname(yardstick.__file__))
78     path = os.path.join(yardstick_root, *package.split("."))
79     for root, _, files in os.walk(path):
80         matches = (filename for filename in files if filename.endswith(".py") and
81                    not filename.startswith("__"))
82         new_package = os.path.relpath(root, yardstick_root).replace(os.sep, ".")
83         module_names = set(
84             ("{}.{}".format(new_package, filename.rsplit(".py", 1)[0]) for filename in matches))
85         # find modules which haven't already been imported
86         missing_modules = module_names.difference(sys.modules)
87         logger.debug("importing %s", missing_modules)
88         # we have already checked for already imported modules, so we don't need to check again
89         for module_name in missing_modules:
90             try:
91                 sys.modules[module_name] = importutils.import_module(module_name)
92             except (ImportError, SyntaxError):
93                 logger.exception("unable to import %s", module_name)
94
95
96 def makedirs(d):
97     try:
98         os.makedirs(d)
99     except OSError as e:
100         if e.errno != errno.EEXIST:
101             raise
102
103
104 def remove_file(path):
105     try:
106         os.remove(path)
107     except OSError as e:
108         if e.errno != errno.ENOENT:
109             raise
110
111
112 def execute_command(cmd):
113     exec_msg = "Executing command: '%s'" % cmd
114     logger.debug(exec_msg)
115
116     output = subprocess.check_output(cmd.split()).split(os.linesep)
117
118     return output
119
120
121 def source_env(env_file):
122     p = subprocess.Popen(". %s; env" % env_file, stdout=subprocess.PIPE,
123                          shell=True)
124     output = p.communicate()[0]
125     env = dict(line.split('=', 1) for line in output.splitlines() if '=' in line)
126     os.environ.update(env)
127     return env
128
129
130 def read_json_from_file(path):
131     with open(path, 'r') as f:
132         j = f.read()
133     # don't use jsonutils.load() it conflicts with already decoded input
134     return jsonutils.loads(j)
135
136
137 def write_json_to_file(path, data, mode='w'):
138     with open(path, mode) as f:
139         jsonutils.dump(data, f)
140
141
142 def write_file(path, data, mode='w'):
143     with open(path, mode) as f:
144         f.write(data)
145
146
147 def parse_ini_file(path):
148     parser = configparser.ConfigParser()
149
150     try:
151         files = parser.read(path)
152     except configparser.MissingSectionHeaderError:
153         logger.exception('invalid file type')
154         raise
155     else:
156         if not files:
157             raise RuntimeError('file not exist')
158
159     try:
160         default = {k: v for k, v in parser.items('DEFAULT')}
161     except configparser.NoSectionError:
162         default = {}
163
164     config = dict(DEFAULT=default,
165                   **{s: {k: v for k, v in parser.items(
166                       s)} for s in parser.sections()})
167
168     return config
169
170
171 def get_port_mac(sshclient, port):
172     cmd = "ifconfig |grep HWaddr |grep %s |awk '{print $5}' " % port
173     status, stdout, stderr = sshclient.execute(cmd)
174
175     if status:
176         raise RuntimeError(stderr)
177     return stdout.rstrip()
178
179
180 def get_port_ip(sshclient, port):
181     cmd = "ifconfig %s |grep 'inet addr' |awk '{print $2}' " \
182         "|cut -d ':' -f2 " % port
183     status, stdout, stderr = sshclient.execute(cmd)
184
185     if status:
186         raise RuntimeError(stderr)
187     return stdout.rstrip()
188
189
190 def flatten_dict_key(data):
191     next_data = {}
192
193     # use list, because iterable is too generic
194     if not any(isinstance(v, (collections.Mapping, list))
195                for v in data.values()):
196         return data
197
198     for k, v in data.items():
199         if isinstance(v, collections.Mapping):
200             for n_k, n_v in v.items():
201                 next_data["%s.%s" % (k, n_k)] = n_v
202         # use list because iterable is too generic
203         elif isinstance(v, collections.Iterable) and not isinstance(v, six.string_types):
204             for index, item in enumerate(v):
205                 next_data["%s%d" % (k, index)] = item
206         else:
207             next_data[k] = v
208
209     return flatten_dict_key(next_data)
210
211
212 def translate_to_str(obj):
213     if isinstance(obj, collections.Mapping):
214         return {str(k): translate_to_str(v) for k, v in obj.items()}
215     elif isinstance(obj, list):
216         return [translate_to_str(ele) for ele in obj]
217     elif isinstance(obj, six.text_type):
218         return str(obj)
219     return obj
220
221
222 def result_handler(status, data):
223     result = {
224         'status': status,
225         'result': data
226     }
227     return jsonify(result)
228
229
230 def change_obj_to_dict(obj):
231     dic = {}
232     for k, v in vars(obj).items():
233         try:
234             vars(v)
235         except TypeError:
236             dic.update({k: v})
237     return dic
238
239
240 def set_dict_value(dic, keys, value):
241     return_dic = dic
242
243     for key in keys.split('.'):
244         return_dic.setdefault(key, {})
245         if key == keys.split('.')[-1]:
246             return_dic[key] = value
247         else:
248             return_dic = return_dic[key]
249     return dic
250
251
252 def get_free_port(ip):
253     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
254         port = random.randint(5000, 10000)
255         while s.connect_ex((ip, port)) == 0:
256             port = random.randint(5000, 10000)
257         return port
258
259
260 def mac_address_to_hex_list(mac):
261     octets = ["0x{:02x}".format(int(elem, 16)) for elem in mac.split(':')]
262     assert len(octets) == 6 and all(len(octet) == 4 for octet in octets)
263     return octets
264
265
266 def safe_ip_address(ip_addr):
267     """ get ip address version v6 or v4 """
268     try:
269         return ipaddress.ip_address(six.text_type(ip_addr))
270     except ValueError:
271         logging.error("%s is not valid", ip_addr)
272         return None
273
274
275 def get_ip_version(ip_addr):
276     """ get ip address version v6 or v4 """
277     try:
278         address = ipaddress.ip_address(six.text_type(ip_addr))
279     except ValueError:
280         logging.error("%s is not valid", ip_addr)
281         return None
282     else:
283         return address.version
284
285
286 def ip_to_hex(ip_addr, separator=''):
287     try:
288         address = ipaddress.ip_address(six.text_type(ip_addr))
289     except ValueError:
290         logging.error("%s is not valid", ip_addr)
291         return ip_addr
292
293     if address.version != 4:
294         return ip_addr
295
296     if not separator:
297         return '{:08x}'.format(int(address))
298
299     return separator.join('{:02x}'.format(octet) for octet in address.packed)
300
301
302 def try_int(s, *args):
303     """Convert to integer if possible."""
304     try:
305         return int(s)
306     except (TypeError, ValueError):
307         return args[0] if args else s
308
309
310 class SocketTopology(dict):
311
312     @classmethod
313     def parse_cpuinfo(cls, cpuinfo):
314         socket_map = {}
315
316         lines = cpuinfo.splitlines()
317
318         core_details = []
319         core_lines = {}
320         for line in lines:
321             if line.strip():
322                 name, value = line.split(":", 1)
323                 core_lines[name.strip()] = try_int(value.strip())
324             else:
325                 core_details.append(core_lines)
326                 core_lines = {}
327
328         for core in core_details:
329             socket_map.setdefault(core["physical id"], {}).setdefault(
330                 core["core id"], {})[core["processor"]] = (
331                 core["processor"], core["core id"], core["physical id"])
332
333         return cls(socket_map)
334
335     def sockets(self):
336         return sorted(self.keys())
337
338     def cores(self):
339         return sorted(core for cores in self.values() for core in cores)
340
341     def processors(self):
342         return sorted(
343             proc for cores in self.values() for procs in cores.values() for
344             proc in procs)
345
346
347 def config_to_dict(config):
348     return {section: dict(config.items(section)) for section in
349             config.sections()}
350
351
352 def validate_non_string_sequence(value, default=None, raise_exc=None):
353     # NOTE(ralonsoh): refactor this function to check if raise_exc is an
354     # Exception. Remove duplicate code, this function is duplicated in this
355     # repository.
356     if isinstance(value, collections.Sequence) and not isinstance(value, six.string_types):
357         return value
358     if raise_exc:
359         raise raise_exc  # pylint: disable=raising-bad-type
360     return default
361
362
363 def join_non_strings(separator, *non_strings):
364     try:
365         non_strings = validate_non_string_sequence(non_strings[0], raise_exc=RuntimeError)
366     except (IndexError, RuntimeError):
367         pass
368     return str(separator).join(str(non_string) for non_string in non_strings)
369
370
371 def safe_decode_utf8(s):
372     """Safe decode a str from UTF"""
373     if six.PY3 and isinstance(s, bytes):
374         return s.decode('utf-8', 'surrogateescape')
375     return s
376
377
378 class ErrorClass(object):
379
380     def __init__(self, *args, **kwargs):
381         if 'test' not in kwargs:
382             raise RuntimeError
383
384     def __getattr__(self, item):
385         raise AttributeError
386
387
388 class Timer(object):
389     def __init__(self):
390         super(Timer, self).__init__()
391         self.start = self.delta = None
392
393     def __enter__(self):
394         self.start = datetime.datetime.now()
395         return self
396
397     def __exit__(self, *_):
398         self.delta = datetime.datetime.now() - self.start
399
400     def __getattr__(self, item):
401         return getattr(self.delta, item)