Merge "loadgen: Support for Stressor-VMs as a Loadgen"
[vswitchperf.git] / tools / collectors / collectd / collectd_bucky.py
1 # Copyright 2014-2018 TRBS, Spirent Communications
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 # use this file except in compliance with the License. You may obtain a copy of
5 # the License at
6 #
7 #   http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 # License for the specific language governing permissions and limitations under
13 # the License.
14
15 # This file is a modified version of scripts present in bucky software
16 # details of bucky can be found at https://github.com/trbs/bucky
17
18 """
19 This module receives the samples from collectd, processes it and
20 enqueues it in a format suitable for easy processing.
21 It also handles secure communication with collectd.
22 """
23 import copy
24 import hmac
25 import logging
26 import multiprocessing
27 import os
28 import socket
29 import struct
30 import sys
31 from hashlib import sha1, sha256
32
33 from Crypto.Cipher import AES
34 from conf import settings
35
36 logging.basicConfig()
37 LOG = logging.getLogger(__name__)
38
39
40 class CollectdError(Exception):
41     """
42     Custom error class.
43     """
44     def __init__(self, mesg):
45         super(CollectdError, self).__init__(mesg)
46         self.mesg = mesg
47
48     def __str__(self):
49         return self.mesg
50
51
52 class ConnectError(CollectdError):
53     """
54     Custom connect error
55     """
56     pass
57
58
59 class ConfigError(CollectdError):
60     """
61     Custom config error
62     """
63     pass
64
65
66 class ProtocolError(CollectdError):
67     """
68     Custom protocol error
69     """
70     pass
71
72
73 class UDPServer(multiprocessing.Process):
74     """
75     Actual UDP server receiving collectd samples over network
76     """
77     def __init__(self, ip, port):
78         super(UDPServer, self).__init__()
79         self.daemon = True
80         addrinfo = socket.getaddrinfo(ip, port,
81                                       socket.AF_UNSPEC, socket.SOCK_DGRAM)
82         afamily, _, _, _, addr = addrinfo[0]
83         ip, port = addr[:2]
84         self.ip_addr = ip
85         self.port = port
86         self.sock = socket.socket(afamily, socket.SOCK_DGRAM)
87         self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
88         try:
89             self.sock.bind((ip, port))
90             LOG.info("Bound socket socket %s:%s", ip, port)
91         except socket.error:
92             LOG.exception("Error binding socket %s:%s.", ip, port)
93             sys.exit(1)
94
95         self.sock_recvfrom = self.sock.recvfrom
96
97     def run(self):
98         """
99         Start receiving messages
100         """
101         recvfrom = self.sock_recvfrom
102         while True:
103             try:
104                 data, addr = recvfrom(65535)
105             except (IOError, KeyboardInterrupt):
106                 continue
107             addr = addr[:2]  # for compatibility with longer ipv6 tuples
108             if data == b'EXIT':
109                 break
110             if not self.handle(data, addr):
111                 break
112         try:
113             self.pre_shutdown()
114         except SystemExit:
115             LOG.exception("Failed pre_shutdown method for %s",
116                           self.__class__.__name__)
117
118     def handle(self, data, addr):
119         """
120         Handle the message.
121         """
122         raise NotImplementedError()
123
124     def pre_shutdown(self):
125         """ Pre shutdown hook """
126         pass
127
128     def close(self):
129         """
130         Close the communication
131         """
132         self.send('EXIT')
133
134     def send(self, data):
135         """
136         Send over the network
137         """
138         sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
139         if not isinstance(data, bytes):
140             data = data.encode()
141         sock.sendto(data, 0, (self.ip_addr, self.port))
142
143
144 class CPUConverter(object):
145     """
146     Converter for CPU samples fom collectd.
147     """
148     PRIORITY = -1
149
150     def __call__(self, sample):
151         return ["cpu", sample["plugin_instance"], sample["type_instance"]]
152
153
154 class InterfaceConverter(object):
155     """
156     Converter for Interface samples from collectd
157     """
158     PRIORITY = -1
159
160     def __call__(self, sample):
161         parts = []
162         parts.append("interface")
163         if sample.get("plugin_instance", ""):
164             parts.append(sample["plugin_instance"].strip())
165         stypei = sample.get("type_instance", "").strip()
166         if stypei:
167             parts.append(stypei)
168         stype = sample.get("type").strip()
169         if stype:
170             parts.append(stype)
171         vname = sample.get("value_name").strip()
172         if vname:
173             parts.append(vname)
174         return parts
175
176
177 class MemoryConverter(object):
178     """
179     Converter for Memory samples from collectd
180     """
181     PRIORITY = -1
182
183     def __call__(self, sample):
184         return ["memory", sample["type_instance"]]
185
186
187 class DefaultConverter(object):
188     """
189     Default converter for samples from collectd
190     """
191     PRIORITY = -1
192
193     def __call__(self, sample):
194         parts = []
195         parts.append(sample["plugin"].strip())
196         if sample.get("plugin_instance"):
197             parts.append(sample["plugin_instance"].strip())
198         stype = sample.get("type", "").strip()
199         if stype and stype != "value":
200             parts.append(stype)
201         stypei = sample.get("type_instance", "").strip()
202         if stypei:
203             parts.append(stypei)
204         vname = sample.get("value_name").strip()
205         if vname and vname != "value":
206             parts.append(vname)
207         return parts
208
209
210 DEFAULT_CONVERTERS = {
211     "cpu": CPUConverter(),
212     "interface": InterfaceConverter(),
213     "memory": MemoryConverter(),
214     "_default": DefaultConverter(),
215 }
216
217
218 class CollectDTypes(object):
219     """
220     Class to handle the sample types. The types.db that comes
221     with collectd, usually, defines the various types.
222     """
223     def __init__(self, types_dbs=None):
224         if types_dbs is None:
225             types_dbs = []
226         dirs = ["/opt/collectd/share/collectd/types.db",
227                 "/usr/local/share/collectd/types.db"]
228         self.types = {}
229         self.type_ranges = {}
230         if not types_dbs:
231             types_dbs = [tdb for tdb in dirs if os.path.exists(tdb)]
232             if not types_dbs:
233                 raise ConfigError("Unable to locate types.db")
234         self.types_dbs = types_dbs
235         self._load_types()
236
237     def get(self, name):
238         """
239         Get the name of the type
240         """
241         t_name = self.types.get(name)
242         if t_name is None:
243             raise ProtocolError("Invalid type name: %s" % name)
244         return t_name
245
246     def _load_types(self):
247         """
248         Load all the types from types_db
249         """
250         for types_db in self.types_dbs:
251             with open(types_db) as handle:
252                 for line in handle:
253                     if line.lstrip()[:1] == "#":
254                         continue
255                     if not line.strip():
256                         continue
257                     self._add_type_line(line)
258             LOG.info("Loaded collectd types from %s", types_db)
259
260     def _add_type_line(self, line):
261         """
262         Add types information
263         """
264         types = {
265             "COUNTER": 0,
266             "GAUGE": 1,
267             "DERIVE": 2,
268             "ABSOLUTE": 3
269         }
270         name, spec = line.split(None, 1)
271         self.types[name] = []
272         self.type_ranges[name] = {}
273         vals = spec.split(", ")
274         for val in vals:
275             vname, vtype, minv, maxv = val.strip().split(":")
276             vtype = types.get(vtype)
277             if vtype is None:
278                 raise ValueError("Invalid value type: %s" % vtype)
279             minv = None if minv == "U" else float(minv)
280             maxv = None if maxv == "U" else float(maxv)
281             self.types[name].append((vname, vtype))
282             self.type_ranges[name][vname] = (minv, maxv)
283
284
285 class CollectDParser(object):
286     """
287     Parser class: Implements the sample parsing operations.
288     The types definition defines the parsing process.
289     """
290     def __init__(self, types_dbs=None, counter_eq_derive=False):
291         if types_dbs is None:
292             types_dbs = []
293         self.types = CollectDTypes(types_dbs=types_dbs)
294         self.counter_eq_derive = counter_eq_derive
295
296     def parse(self, data):
297         """
298         Parse individual samples
299         """
300         for sample in self.parse_samples(data):
301             yield sample
302
303     def parse_samples(self, data):
304         """
305         Extract all the samples from the message.
306         """
307         types = {
308             0x0000: self._parse_string("host"),
309             0x0001: self._parse_time("time"),
310             0x0008: self._parse_time_hires("time"),
311             0x0002: self._parse_string("plugin"),
312             0x0003: self._parse_string("plugin_instance"),
313             0x0004: self._parse_string("type"),
314             0x0005: self._parse_string("type_instance"),
315             0x0006: None,  # handle specially
316             0x0007: self._parse_time("interval"),
317             0x0009: self._parse_time_hires("interval")
318         }
319         sample = {}
320         for (ptype, pdata) in self.parse_data(data):
321             if ptype not in types:
322                 LOG.debug("Ignoring part type: 0x%02x", ptype)
323                 continue
324             if ptype != 0x0006:
325                 types[ptype](sample, pdata)
326                 continue
327             for vname, vtype, val in self.parse_values(sample["type"], pdata):
328                 sample["value_name"] = vname
329                 sample["value_type"] = vtype
330                 sample["value"] = val
331                 yield copy.deepcopy(sample)
332
333     @staticmethod
334     def parse_data(data):
335         """
336         Parse the message
337         """
338         types = set([
339             0x0000, 0x0001, 0x0002, 0x0003, 0x0004,
340             0x0005, 0x0006, 0x0007, 0x0008, 0x0009,
341             0x0100, 0x0101, 0x0200, 0x0210
342         ])
343         while data:
344             if len(data) < 4:
345                 raise ProtocolError("Truncated header.")
346             (part_type, part_len) = struct.unpack("!HH", data[:4])
347             data = data[4:]
348             if part_type not in types:
349                 raise ProtocolError("Invalid part type: 0x%02x" % part_type)
350             part_len -= 4  # includes four header bytes we just parsed
351             if len(data) < part_len:
352                 raise ProtocolError("Truncated value.")
353             part_data, data = data[:part_len], data[part_len:]
354             yield (part_type, part_data)
355
356     def parse_values(self, stype, data):
357         """
358         Parse the value of a particular type
359         """
360         types = {0: "!Q", 1: "<d", 2: "!q", 3: "!Q"}
361         (nvals,) = struct.unpack("!H", data[:2])
362         data = data[2:]
363         if len(data) != 9 * nvals:
364             raise ProtocolError("Invalid value structure length.")
365         vtypes = self.types.get(stype)
366         if nvals != len(vtypes):
367             raise ProtocolError("Values different than types.db info.")
368         for i in range(nvals):
369             vtype = data[i]
370             if vtype != vtypes[i][1]:
371                 if self.counter_eq_derive and \
372                    (vtype, vtypes[i][1]) in ((0, 2), (2, 0)):
373                     # if counter vs derive don't break, assume server is right
374                     LOG.debug("Type mismatch (counter/derive) for %s/%s",
375                               stype, vtypes[i][0])
376                 else:
377                     raise ProtocolError("Type mismatch with types.db")
378         data = data[nvals:]
379         for i in range(nvals):
380             vdata, data = data[:8], data[8:]
381             (val,) = struct.unpack(types[vtypes[i][1]], vdata)
382             yield vtypes[i][0], vtypes[i][1], val
383
384     @staticmethod
385     def _parse_string(name):
386         """
387         Parse string value
388         """
389         def _parser(sample, data):
390             """
391             Actual string parser
392             """
393             data = data.decode()
394             if data[-1] != '\0':
395                 raise ProtocolError("Invalid string detected.")
396             sample[name] = data[:-1]
397         return _parser
398
399     @staticmethod
400     def _parse_time(name):
401         """
402         Parse time value
403         """
404         def _parser(sample, data):
405             """
406             Actual time parser
407             """
408             if len(data) != 8:
409                 raise ProtocolError("Invalid time data length.")
410             (val,) = struct.unpack("!Q", data)
411             sample[name] = float(val)
412         return _parser
413
414     @staticmethod
415     def _parse_time_hires(name):
416         """
417         Parse time hires value
418         """
419         def _parser(sample, data):
420             """
421             Actual time hires parser
422             """
423             if len(data) != 8:
424                 raise ProtocolError("Invalid hires time data length.")
425             (val,) = struct.unpack("!Q", data)
426             sample[name] = val * (2 ** -30)
427         return _parser
428
429
430 class CollectDCrypto(object):
431     """
432     Handle the sercured communications with collectd daemon
433     """
434     def __init__(self):
435         sec_level = settings.getValue('COLLECTD_SECURITY_LEVEL')
436         if sec_level in ("sign", "SIGN", "Sign", 1):
437             self.sec_level = 1
438         elif sec_level in ("encrypt", "ENCRYPT", "Encrypt", 2):
439             self.sec_level = 2
440         else:
441             self.sec_level = 0
442         if self.sec_level:
443             self.auth_file = settings.getValue('COLLECTD_AUTH_FILE')
444             self.auth_db = {}
445             if self.auth_file:
446                 self.load_auth_file()
447             if not self.auth_file:
448                 raise ConfigError("Collectd security level configured but no "
449                                   "auth file specified in configuration")
450             if not self.auth_db:
451                 LOG.warning("Collectd security level configured but no "
452                             "user/passwd entries loaded from auth file")
453
454     def load_auth_file(self):
455         """
456         Loading the authentication file.
457         """
458         try:
459             fil = open(self.auth_file)
460         except IOError as exc:
461             raise ConfigError("Unable to load collectd's auth file: %r" % exc)
462         self.auth_db.clear()
463         for line in fil:
464             line = line.strip()
465             if not line or line[0] == "#":
466                 continue
467             user, passwd = line.split(":", 1)
468             user = user.strip()
469             passwd = passwd.strip()
470             if not user or not passwd:
471                 LOG.warning("Found line with missing user or password")
472                 continue
473             if user in self.auth_db:
474                 LOG.warning("Found multiple entries for single user")
475             self.auth_db[user] = passwd
476         fil.close()
477         LOG.info("Loaded collectd's auth file from %s", self.auth_file)
478
479     def parse(self, data):
480         """
481         Parse the non-encrypted message
482         """
483         if len(data) < 4:
484             raise ProtocolError("Truncated header.")
485         part_type, part_len = struct.unpack("!HH", data[:4])
486         sec_level = {0x0200: 1, 0x0210: 2}.get(part_type, 0)
487         if sec_level < self.sec_level:
488             raise ProtocolError("Packet has lower security level than allowed")
489         if not sec_level:
490             return data
491         if sec_level == 1 and not self.sec_level:
492             return data[part_len:]
493         data = data[4:]
494         part_len -= 4
495         if len(data) < part_len:
496             raise ProtocolError("Truncated part payload.")
497         if sec_level == 1:
498             return self.parse_signed(part_len, data)
499         if sec_level == 2:
500             return self.parse_encrypted(part_len, data)
501
502     def parse_signed(self, part_len, data):
503         """
504         Parse the signed message
505         """
506
507         if part_len <= 32:
508             raise ProtocolError("Truncated signed part.")
509         sig, data = data[:32], data[32:]
510         uname_len = part_len - 32
511         uname = data[:uname_len].decode()
512         if uname not in self.auth_db:
513             raise ProtocolError("Signed packet, unknown user '%s'" % uname)
514         password = self.auth_db[uname].encode()
515         sig2 = hmac.new(password, msg=data, digestmod=sha256).digest()
516         if not self._hashes_match(sig, sig2):
517             raise ProtocolError("Bad signature from user '%s'" % uname)
518         data = data[uname_len:]
519         return data
520
521     def parse_encrypted(self, part_len, data):
522         """
523         Parse the encrypted message
524         """
525         if part_len != len(data):
526             raise ProtocolError("Enc pkt size disaggrees with header.")
527         if len(data) <= 38:
528             raise ProtocolError("Truncated encrypted part.")
529         uname_len, data = struct.unpack("!H", data[:2])[0], data[2:]
530         if len(data) <= uname_len + 36:
531             raise ProtocolError("Truncated encrypted part.")
532         uname, data = data[:uname_len].decode(), data[uname_len:]
533         if uname not in self.auth_db:
534             raise ProtocolError("Couldn't decrypt, unknown user '%s'" % uname)
535         ival, data = data[:16], data[16:]
536         password = self.auth_db[uname].encode()
537         key = sha256(password).digest()
538         pad_bytes = 16 - (len(data) % 16)
539         data += b'\0' * pad_bytes
540         data = AES.new(key, IV=ival, mode=AES.MODE_OFB).decrypt(data)
541         data = data[:-pad_bytes]
542         tag, data = data[:20], data[20:]
543         tag2 = sha1(data).digest()
544         if not self._hashes_match(tag, tag2):
545             raise ProtocolError("Bad checksum on enc pkt for '%s'" % uname)
546         return data
547
548     @staticmethod
549     def _hashes_match(val_a, val_b):
550         """Constant time comparison of bytes """
551         if len(val_a) != len(val_b):
552             return False
553         diff = 0
554         for val_x, val_y in zip(val_a, val_b):
555             diff |= val_x ^ val_y
556         return not diff
557
558
559 class CollectDConverter(object):
560     """
561     Handle all conversions.
562     Coversion: Convert the sample received from collectd to an
563     appropriate format - for easy processing
564     """
565     def __init__(self):
566         self.converters = dict(DEFAULT_CONVERTERS)
567
568     def convert(self, sample):
569         """
570         Main conversion handling.
571         """
572         default = self.converters["_default"]
573         handler = self.converters.get(sample["plugin"], default)
574         try:
575             name_parts = handler(sample)
576             if name_parts is None:
577                 return  # treat None as "ignore sample"
578             name = '.'.join(name_parts)
579         except (AttributeError, IndexError, MemoryError, RuntimeError):
580             LOG.exception("Exception in sample handler  %s (%s):",
581                           sample["plugin"], handler)
582             return
583         host = sample.get("host", "")
584         return (
585             host,
586             name,
587             sample["value_type"],
588             sample["value"],
589             int(sample["time"])
590         )
591
592     def _add_converter(self, name, inst, source="unknown"):
593         """
594         Add new converter types
595         """
596         if name not in self.converters:
597             LOG.info("Converter: %s from %s", name, source)
598             self.converters[name] = inst
599             return
600         kpriority = getattr(inst, "PRIORITY", 0)
601         ipriority = getattr(self.converters[name], "PRIORITY", 0)
602         if kpriority > ipriority:
603             LOG.info("Replacing: %s", name)
604             LOG.info("Converter: %s from %s", name, source)
605             self.converters[name] = inst
606             return
607         LOG.info("Ignoring: %s (%s) from %s (priority: %s vs %s)",
608                  name, inst, source, kpriority, ipriority)
609
610
611 class CollectDHandler(object):
612     """Wraps all CollectD parsing functionality in a class"""
613
614     def __init__(self):
615         self.crypto = CollectDCrypto()
616         collectd_types = []
617         collectd_counter_eq_derive = False
618         self.parser = CollectDParser(collectd_types,
619                                      collectd_counter_eq_derive)
620         self.converter = CollectDConverter()
621         self.prev_samples = {}
622         self.last_sample = None
623
624     def parse(self, data):
625         """
626         Parse the samples from collectd
627         """
628         try:
629             data = self.crypto.parse(data)
630         except ProtocolError as error:
631             LOG.error("Protocol error in CollectDCrypto: %s", error)
632             return
633         try:
634             for sample in self.parser.parse(data):
635                 self.last_sample = sample
636                 stype = sample["type"]
637                 vname = sample["value_name"]
638                 sample = self.converter.convert(sample)
639                 if sample is None:
640                     continue
641                 host, name, vtype, val, time = sample
642                 if not name.strip():
643                     continue
644                 val = self.calculate(host, name, vtype, val, time)
645                 val = self.check_range(stype, vname, val)
646                 if val is not None:
647                     yield host, name, val, time
648         except ProtocolError as error:
649             LOG.error("Protocol error: %s", error)
650             if self.last_sample is not None:
651                 LOG.info("Last sample: %s", self.last_sample)
652
653     def check_range(self, stype, vname, val):
654         """
655         Check the value range
656         """
657         if val is None:
658             return
659         try:
660             vmin, vmax = self.parser.types.type_ranges[stype][vname]
661         except KeyError:
662             LOG.error("Couldn't find vmin, vmax in CollectDTypes")
663             return val
664         if vmin is not None and val < vmin:
665             LOG.debug("Invalid value %s (<%s) for %s", val, vmin, vname)
666             LOG.debug("Last sample: %s", self.last_sample)
667             return
668         if vmax is not None and val > vmax:
669             LOG.debug("Invalid value %s (>%s) for %s", val, vmax, vname)
670             LOG.debug("Last sample: %s", self.last_sample)
671             return
672         return val
673
674     def calculate(self, host, name, vtype, val, time):
675         """
676         Perform calculations for handlers
677         """
678         handlers = {
679             0: self._calc_counter,  # counter
680             1: lambda _host, _name, v, _time: v,  # gauge
681             2: self._calc_derive,  # derive
682             3: self._calc_absolute  # absolute
683         }
684         if vtype not in handlers:
685             LOG.error("Invalid value type %s for %s", vtype, name)
686             LOG.info("Last sample: %s", self.last_sample)
687             return
688         return handlers[vtype](host, name, val, time)
689
690     def _calc_counter(self, host, name, val, time):
691         """
692         Calculating counter values
693         """
694         key = (host, name)
695         if key not in self.prev_samples:
696             self.prev_samples[key] = (val, time)
697             return
698         pval, ptime = self.prev_samples[key]
699         self.prev_samples[key] = (val, time)
700         if time <= ptime:
701             LOG.error("Invalid COUNTER update for: %s:%s", key[0], key[1])
702             LOG.info("Last sample: %s", self.last_sample)
703             return
704         if val < pval:
705             # this is supposed to handle counter wrap around
706             # see https://collectd.org/wiki/index.php/Data_source
707             LOG.debug("COUNTER wrap-around for: %s:%s (%s -> %s)",
708                       host, name, pval, val)
709             if pval < 0x100000000:
710                 val += 0x100000000  # 2**32
711             else:
712                 val += 0x10000000000000000  # 2**64
713         return float(val - pval) / (time - ptime)
714
715     def _calc_derive(self, host, name, val, time):
716         """
717         Calculating derived values
718         """
719         key = (host, name)
720         if key not in self.prev_samples:
721             self.prev_samples[key] = (val, time)
722             return
723         pval, ptime = self.prev_samples[key]
724         self.prev_samples[key] = (val, time)
725         if time <= ptime:
726             LOG.debug("Invalid DERIVE update for: %s:%s", key[0], key[1])
727             LOG.debug("Last sample: %s", self.last_sample)
728             return
729         return float(abs(val - pval)) / (time - ptime)
730
731     def _calc_absolute(self, host, name, val, time):
732         """
733         Calculating absolute values
734         """
735         key = (host, name)
736         if key not in self.prev_samples:
737             self.prev_samples[key] = (val, time)
738             return
739         _, ptime = self.prev_samples[key]
740         self.prev_samples[key] = (val, time)
741         if time <= ptime:
742             LOG.error("Invalid ABSOLUTE update for: %s:%s", key[0], key[1])
743             LOG.info("Last sample: %s", self.last_sample)
744             return
745         return float(val) / (time - ptime)
746
747
748 class CollectDServer(UDPServer):
749     """Single processes CollectDServer"""
750
751     def __init__(self, queue):
752         super(CollectDServer, self).__init__(settings.getValue('COLLECTD_IP'),
753                                              settings.getValue('COLLECTD_PORT'))
754         self.handler = CollectDHandler()
755         self.queue = queue
756
757     def handle(self, data, addr):
758         for sample in self.handler.parse(data):
759             self.queue.put(sample)
760         return True
761
762     def pre_shutdown(self):
763         LOG.info("Sutting down CollectDServer")
764
765
766 def get_collectd_server(queue):
767     """Get the collectd server """
768     server = CollectDServer
769     return server(queue)