Merge "Tools: Improve Stability."
[vswitchperf.git] / tools / collectors / collectd / collectd_bucky.py
1 # Copyright 2014-2018 TRBS, Spirent Communications
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 # use this file except in compliance with the License. You may obtain a copy of
5 # the License at
6 #
7 #   http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 # License for the specific language governing permissions and limitations under
13 # the License.
14
15 # This file is a modified version of scripts present in bucky software
16 # details of bucky can be found at https://github.com/trbs/bucky
17
18 """
19 This module receives the samples from collectd, processes it and
20 enqueues it in a format suitable for easy processing.
21 It also handles secure communication with collectd.
22 """
23 import copy
24 import hmac
25 import logging
26 import multiprocessing
27 import os
28 import socket
29 import struct
30 import sys
31 from hashlib import sha1, sha256
32
33 from Crypto.Cipher import AES
34 from conf import settings
35
36 logging.basicConfig()
37 LOG = logging.getLogger(__name__)
38
39
40 class CollectdError(Exception):
41     """
42     Custom error class.
43     """
44     def __init__(self, mesg):
45         super(CollectdError, self).__init__(mesg)
46         self.mesg = mesg
47
48     def __str__(self):
49         return self.mesg
50
51
52 class ConnectError(CollectdError):
53     """
54     Custom connect error
55     """
56     pass
57
58
59 class ConfigError(CollectdError):
60     """
61     Custom config error
62     """
63     pass
64
65
66 class ProtocolError(CollectdError):
67     """
68     Custom protocol error
69     """
70     pass
71
72
73 class UDPServer(multiprocessing.Process):
74     """
75     Actual UDP server receiving collectd samples over network
76     """
77     def __init__(self, ip, port):
78         super(UDPServer, self).__init__()
79         self.daemon = True
80         addrinfo = socket.getaddrinfo(ip, port,
81                                       socket.AF_UNSPEC, socket.SOCK_DGRAM)
82         afamily, _, _, _, addr = addrinfo[0]
83         ip, port = addr[:2]
84         self.ip_addr = ip
85         self.port = port
86         self.sock = socket.socket(afamily, socket.SOCK_DGRAM)
87         self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
88         try:
89             self.sock.bind((ip, port))
90             LOG.info("Bound socket socket %s:%s", ip, port)
91         except socket.error:
92             LOG.exception("Error binding socket %s:%s.", ip, port)
93             sys.exit(1)
94
95         self.sock_recvfrom = self.sock.recvfrom
96
97     def run(self):
98         """
99         Start receiving messages
100         """
101         recvfrom = self.sock_recvfrom
102         while True:
103             try:
104                 data, addr = recvfrom(65535)
105             except (IOError, KeyboardInterrupt):
106                 continue
107             addr = addr[:2]  # for compatibility with longer ipv6 tuples
108             if data == b'EXIT':
109                 break
110             if not self.handle(data, addr):
111                 break
112         try:
113             self.pre_shutdown()
114         except SystemExit:
115             LOG.exception("Failed pre_shutdown method for %s",
116                           self.__class__.__name__)
117
118     def handle(self, data, addr):
119         """
120         Handle the message.
121         """
122         raise NotImplementedError()
123
124     def pre_shutdown(self):
125         """ Pre shutdown hook """
126         pass
127
128     def close(self):
129         """
130         Close the communication
131         """
132         self.send('EXIT')
133
134     def send(self, data):
135         """
136         Send over the network
137         """
138         sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
139         if not isinstance(data, bytes):
140             data = data.encode()
141         sock.sendto(data, 0, (self.ip_addr, self.port))
142
143
144 class CPUConverter(object):
145     """
146     Converter for CPU samples fom collectd.
147     """
148     PRIORITY = -1
149
150     def __call__(self, sample):
151         return ["cpu", sample["plugin_instance"], sample["type_instance"]]
152
153
154 class InterfaceConverter(object):
155     """
156     Converter for Interface samples from collectd
157     """
158     PRIORITY = -1
159
160     def __call__(self, sample):
161         parts = []
162         parts.append("interface")
163         if sample.get("plugin_instance", ""):
164             parts.append(sample["plugin_instance"].strip())
165         stypei = sample.get("type_instance", "").strip()
166         if stypei:
167             parts.append(stypei)
168         stype = sample.get("type").strip()
169         if stype:
170             parts.append(stype)
171         vname = sample.get("value_name").strip()
172         if vname:
173             parts.append(vname)
174         return parts
175
176
177 class MemoryConverter(object):
178     """
179     Converter for Memory samples from collectd
180     """
181     PRIORITY = -1
182
183     def __call__(self, sample):
184         return ["memory", sample["type_instance"]]
185
186
187 class DefaultConverter(object):
188     """
189     Default converter for samples from collectd
190     """
191     PRIORITY = -1
192
193     def __call__(self, sample):
194         parts = []
195         parts.append(sample["plugin"].strip())
196         if sample.get("plugin_instance"):
197             parts.append(sample["plugin_instance"].strip())
198         stype = sample.get("type", "").strip()
199         if stype and stype != "value":
200             parts.append(stype)
201         stypei = sample.get("type_instance", "").strip()
202         if stypei:
203             parts.append(stypei)
204         vname = sample.get("value_name").strip()
205         if vname and vname != "value":
206             parts.append(vname)
207         return parts
208
209
210 DEFAULT_CONVERTERS = {
211     "cpu": CPUConverter(),
212     "interface": InterfaceConverter(),
213     "memory": MemoryConverter(),
214     "_default": DefaultConverter(),
215 }
216
217
218 class CollectDTypes(object):
219     """
220     Class to handle the sample types. The types.db that comes
221     with collectd, usually, defines the various types.
222     """
223     def __init__(self, types_dbs=None):
224         if types_dbs is None:
225             types_dbs = []
226         dirs = ["/opt/collectd/share/collectd/types.db",
227                 "/usr/local/share/collectd/types.db"]
228         self.types = {}
229         self.type_ranges = {}
230         if not types_dbs:
231             types_dbs = [tdb for tdb in dirs if os.path.exists(tdb)]
232             if not types_dbs:
233                 raise ConfigError("Unable to locate types.db")
234         self.types_dbs = types_dbs
235         self._load_types()
236
237     def get(self, name):
238         """
239         Get the name of the type
240         """
241         t_name = self.types.get(name)
242         if t_name is None:
243             raise ProtocolError("Invalid type name: %s" % name)
244         return t_name
245
246     def _load_types(self):
247         """
248         Load all the types from types_db
249         """
250         for types_db in self.types_dbs:
251             with open(types_db) as handle:
252                 for line in handle:
253                     if line.lstrip()[:1] == "#":
254                         continue
255                     if not line.strip():
256                         continue
257                     self._add_type_line(line)
258             LOG.info("Loaded collectd types from %s", types_db)
259
260     def _add_type_line(self, line):
261         """
262         Add types information
263         """
264         types = {
265             "COUNTER": 0,
266             "GAUGE": 1,
267             "DERIVE": 2,
268             "ABSOLUTE": 3
269         }
270         name, spec = line.split(None, 1)
271         self.types[name] = []
272         self.type_ranges[name] = {}
273         vals = spec.split(", ")
274         for val in vals:
275             vname, vtype, minv, maxv = val.strip().split(":")
276             vtype = types.get(vtype)
277             if vtype is None:
278                 raise ValueError("Invalid value type: %s" % vtype)
279             minv = None if minv == "U" else float(minv)
280             maxv = None if maxv == "U" else float(maxv)
281             self.types[name].append((vname, vtype))
282             self.type_ranges[name][vname] = (minv, maxv)
283
284
285 class CollectDParser(object):
286     """
287     Parser class: Implements the sample parsing operations.
288     The types definition defines the parsing process.
289     """
290     def __init__(self, types_dbs=None, counter_eq_derive=False):
291         if types_dbs is None:
292             types_dbs = []
293         self.types = CollectDTypes(types_dbs=types_dbs)
294         self.counter_eq_derive = counter_eq_derive
295
296     def parse(self, data):
297         """
298         Parse individual samples
299         """
300         for sample in self.parse_samples(data):
301             yield sample
302
303     def parse_samples(self, data):
304         """
305         Extract all the samples from the message.
306         """
307         types = {
308             0x0000: self._parse_string("host"),
309             0x0001: self._parse_time("time"),
310             0x0008: self._parse_time_hires("time"),
311             0x0002: self._parse_string("plugin"),
312             0x0003: self._parse_string("plugin_instance"),
313             0x0004: self._parse_string("type"),
314             0x0005: self._parse_string("type_instance"),
315             0x0006: None,  # handle specially
316             0x0007: self._parse_time("interval"),
317             0x0009: self._parse_time_hires("interval")
318         }
319         sample = {}
320         for (ptype, pdata) in self.parse_data(data):
321             if ptype not in types:
322                 LOG.debug("Ignoring part type: 0x%02x", ptype)
323                 continue
324             if ptype != 0x0006:
325                 types[ptype](sample, pdata)
326                 continue
327             for vname, vtype, val in self.parse_values(sample["type"], pdata):
328                 sample["value_name"] = vname
329                 sample["value_type"] = vtype
330                 sample["value"] = val
331                 yield copy.deepcopy(sample)
332
333     @staticmethod
334     def parse_data(data):
335         """
336         Parse the message
337         """
338         types = set([
339             0x0000, 0x0001, 0x0002, 0x0003, 0x0004,
340             0x0005, 0x0006, 0x0007, 0x0008, 0x0009,
341             0x0100, 0x0101, 0x0200, 0x0210
342         ])
343         while data:
344             if len(data) < 4:
345                 raise ProtocolError("Truncated header.")
346             (part_type, part_len) = struct.unpack("!HH", data[:4])
347             data = data[4:]
348             if part_type not in types:
349                 raise ProtocolError("Invalid part type: 0x%02x" % part_type)
350             part_len -= 4  # includes four header bytes we just parsed
351             if len(data) < part_len:
352                 raise ProtocolError("Truncated value.")
353             part_data, data = data[:part_len], data[part_len:]
354             yield (part_type, part_data)
355
356     def parse_values(self, stype, data):
357         """
358         Parse the value of a particular type
359         """
360         types = {0: "!Q", 1: "<d", 2: "!q", 3: "!Q"}
361         (nvals,) = struct.unpack("!H", data[:2])
362         data = data[2:]
363         if len(data) != 9 * nvals:
364             raise ProtocolError("Invalid value structure length.")
365         vtypes = self.types.get(stype)
366         if nvals != len(vtypes):
367             raise ProtocolError("Values different than types.db info.")
368         for i in range(nvals):
369             vtype = data[i]
370             if vtype != vtypes[i][1]:
371                 if self.counter_eq_derive and \
372                    (vtype, vtypes[i][1]) in ((0, 2), (2, 0)):
373                     # if counter vs derive don't break, assume server is right
374                     LOG.debug("Type mismatch (counter/derive) for %s/%s",
375                               stype, vtypes[i][0])
376                 else:
377                     raise ProtocolError("Type mismatch with types.db")
378         data = data[nvals:]
379         for i in range(nvals):
380             vdata, data = data[:8], data[8:]
381             (val,) = struct.unpack(types[vtypes[i][1]], vdata)
382             yield vtypes[i][0], vtypes[i][1], val
383
384     @staticmethod
385     def _parse_string(name):
386         """
387         Parse string value
388         """
389         def _parser(sample, data):
390             """
391             Actual string parser
392             """
393             data = data.decode()
394             if data[-1] != '\0':
395                 raise ProtocolError("Invalid string detected.")
396             sample[name] = data[:-1]
397         return _parser
398
399     @staticmethod
400     def _parse_time(name):
401         """
402         Parse time value
403         """
404         def _parser(sample, data):
405             """
406             Actual time parser
407             """
408             if len(data) != 8:
409                 raise ProtocolError("Invalid time data length.")
410             (val,) = struct.unpack("!Q", data)
411             sample[name] = float(val)
412         return _parser
413
414     @staticmethod
415     def _parse_time_hires(name):
416         """
417         Parse time hires value
418         """
419         def _parser(sample, data):
420             """
421             Actual time hires parser
422             """
423             if len(data) != 8:
424                 raise ProtocolError("Invalid hires time data length.")
425             (val,) = struct.unpack("!Q", data)
426             sample[name] = val * (2 ** -30)
427         return _parser
428
429
430 class CollectDCrypto(object):
431     """
432     Handle the sercured communications with collectd daemon
433     """
434     def __init__(self):
435         sec_level = settings.getValue('COLLECTD_SECURITY_LEVEL')
436         if sec_level in ("sign", "SIGN", "Sign", 1):
437             self.sec_level = 1
438         elif sec_level in ("encrypt", "ENCRYPT", "Encrypt", 2):
439             self.sec_level = 2
440         else:
441             self.sec_level = 0
442         if self.sec_level:
443             self.auth_file = settings.getValue('COLLECTD_AUTH_FILE')
444             self.auth_db = {}
445             if self.auth_file:
446                 self.load_auth_file()
447             if not self.auth_file:
448                 raise ConfigError("Collectd security level configured but no "
449                                   "auth file specified in configuration")
450             if not self.auth_db:
451                 LOG.warning("Collectd security level configured but no "
452                             "user/passwd entries loaded from auth file")
453
454     def load_auth_file(self):
455         """
456         Loading the authentication file.
457         """
458         try:
459             fil = open(self.auth_file)
460         except IOError as exc:
461             raise ConfigError("Unable to load collectd's auth file: %r" % exc)
462         self.auth_db.clear()
463         for line in fil:
464             line = line.strip()
465             if not line or line[0] == "#":
466                 continue
467             user, passwd = line.split(":", 1)
468             user = user.strip()
469             passwd = passwd.strip()
470             if not user or not passwd:
471                 LOG.warning("Found line with missing user or password")
472                 continue
473             if user in self.auth_db:
474                 LOG.warning("Found multiple entries for single user")
475             self.auth_db[user] = passwd
476         fil.close()
477         LOG.info("Loaded collectd's auth file from %s", self.auth_file)
478
479     def parse(self, data):
480         """
481         Parse the non-encrypted message
482         """
483         if len(data) < 4:
484             raise ProtocolError("Truncated header.")
485         part_type, part_len = struct.unpack("!HH", data[:4])
486         sec_level = {0x0200: 1, 0x0210: 2}.get(part_type, 0)
487         if sec_level < self.sec_level:
488             raise ProtocolError("Packet has lower security level than allowed")
489         if not sec_level:
490             return data
491         if sec_level == 1 and not self.sec_level:
492             return data[part_len:]
493         data = data[4:]
494         part_len -= 4
495         if len(data) < part_len:
496             raise ProtocolError("Truncated part payload.")
497         if sec_level == 1:
498             return self.parse_signed(part_len, data)
499         if sec_level == 2:
500             return self.parse_encrypted(part_len, data)
501         return None
502
503     def parse_signed(self, part_len, data):
504         """
505         Parse the signed message
506         """
507
508         if part_len <= 32:
509             raise ProtocolError("Truncated signed part.")
510         sig, data = data[:32], data[32:]
511         uname_len = part_len - 32
512         uname = data[:uname_len].decode()
513         if uname not in self.auth_db:
514             raise ProtocolError("Signed packet, unknown user '%s'" % uname)
515         password = self.auth_db[uname].encode()
516         sig2 = hmac.new(password, msg=data, digestmod=sha256).digest()
517         if not self._hashes_match(sig, sig2):
518             raise ProtocolError("Bad signature from user '%s'" % uname)
519         data = data[uname_len:]
520         return data
521
522     def parse_encrypted(self, part_len, data):
523         """
524         Parse the encrypted message
525         """
526         if part_len != len(data):
527             raise ProtocolError("Enc pkt size disaggrees with header.")
528         if len(data) <= 38:
529             raise ProtocolError("Truncated encrypted part.")
530         uname_len, data = struct.unpack("!H", data[:2])[0], data[2:]
531         if len(data) <= uname_len + 36:
532             raise ProtocolError("Truncated encrypted part.")
533         uname, data = data[:uname_len].decode(), data[uname_len:]
534         if uname not in self.auth_db:
535             raise ProtocolError("Couldn't decrypt, unknown user '%s'" % uname)
536         ival, data = data[:16], data[16:]
537         password = self.auth_db[uname].encode()
538         key = sha256(password).digest()
539         pad_bytes = 16 - (len(data) % 16)
540         data += b'\0' * pad_bytes
541         data = AES.new(key, IV=ival, mode=AES.MODE_OFB).decrypt(data)
542         data = data[:-pad_bytes]
543         tag, data = data[:20], data[20:]
544         tag2 = sha1(data).digest()
545         if not self._hashes_match(tag, tag2):
546             raise ProtocolError("Bad checksum on enc pkt for '%s'" % uname)
547         return data
548
549     @staticmethod
550     def _hashes_match(val_a, val_b):
551         """Constant time comparison of bytes """
552         if len(val_a) != len(val_b):
553             return False
554         diff = 0
555         for val_x, val_y in zip(val_a, val_b):
556             diff |= val_x ^ val_y
557         return not diff
558
559
560 class CollectDConverter(object):
561     """
562     Handle all conversions.
563     Coversion: Convert the sample received from collectd to an
564     appropriate format - for easy processing
565     """
566     def __init__(self):
567         self.converters = dict(DEFAULT_CONVERTERS)
568
569     def convert(self, sample):
570         """
571         Main conversion handling.
572         """
573         default = self.converters["_default"]
574         handler = self.converters.get(sample["plugin"], default)
575         try:
576             name_parts = handler(sample)
577             if name_parts is None:
578                 return None  # treat None as "ignore sample"
579             name = '.'.join(name_parts)
580         except (AttributeError, IndexError, MemoryError, RuntimeError):
581             LOG.exception("Exception in sample handler  %s (%s):",
582                           sample["plugin"], handler)
583             return None
584         host = sample.get("host", "")
585         return (
586             host,
587             name,
588             sample["value_type"],
589             sample["value"],
590             int(sample["time"])
591         )
592
593     def _add_converter(self, name, inst, source="unknown"):
594         """
595         Add new converter types
596         """
597         if name not in self.converters:
598             LOG.info("Converter: %s from %s", name, source)
599             self.converters[name] = inst
600             return
601         kpriority = getattr(inst, "PRIORITY", 0)
602         ipriority = getattr(self.converters[name], "PRIORITY", 0)
603         if kpriority > ipriority:
604             LOG.info("Replacing: %s", name)
605             LOG.info("Converter: %s from %s", name, source)
606             self.converters[name] = inst
607             return
608         LOG.info("Ignoring: %s (%s) from %s (priority: %s vs %s)",
609                  name, inst, source, kpriority, ipriority)
610
611
612 class CollectDHandler(object):
613     """Wraps all CollectD parsing functionality in a class"""
614
615     def __init__(self):
616         self.crypto = CollectDCrypto()
617         collectd_types = []
618         collectd_counter_eq_derive = False
619         self.parser = CollectDParser(collectd_types,
620                                      collectd_counter_eq_derive)
621         self.converter = CollectDConverter()
622         self.prev_samples = {}
623         self.last_sample = None
624
625     def parse(self, data):
626         """
627         Parse the samples from collectd
628         """
629         try:
630             data = self.crypto.parse(data)
631         except ProtocolError as error:
632             LOG.error("Protocol error in CollectDCrypto: %s", error)
633             return
634         try:
635             for sample in self.parser.parse(data):
636                 self.last_sample = sample
637                 stype = sample["type"]
638                 vname = sample["value_name"]
639                 sample = self.converter.convert(sample)
640                 if sample is None:
641                     continue
642                 host, name, vtype, val, time = sample
643                 if not name.strip():
644                     continue
645                 val = self.calculate(host, name, vtype, val, time)
646                 val = self.check_range(stype, vname, val)
647                 if val is not None:
648                     yield host, name, val, time
649         except ProtocolError as error:
650             LOG.error("Protocol error: %s", error)
651             if self.last_sample is not None:
652                 LOG.info("Last sample: %s", self.last_sample)
653
654     def check_range(self, stype, vname, val):
655         """
656         Check the value range
657         """
658         if val is None:
659             return None
660         try:
661             vmin, vmax = self.parser.types.type_ranges[stype][vname]
662         except KeyError:
663             LOG.error("Couldn't find vmin, vmax in CollectDTypes")
664             return val
665         if vmin is not None and val < vmin:
666             LOG.debug("Invalid value %s (<%s) for %s", val, vmin, vname)
667             LOG.debug("Last sample: %s", self.last_sample)
668             return None
669         if vmax is not None and val > vmax:
670             LOG.debug("Invalid value %s (>%s) for %s", val, vmax, vname)
671             LOG.debug("Last sample: %s", self.last_sample)
672             return None
673         return val
674
675     def calculate(self, host, name, vtype, val, time):
676         """
677         Perform calculations for handlers
678         """
679         handlers = {
680             0: self._calc_counter,  # counter
681             1: lambda _host, _name, v, _time: v,  # gauge
682             2: self._calc_derive,  # derive
683             3: self._calc_absolute  # absolute
684         }
685         if vtype not in handlers:
686             LOG.error("Invalid value type %s for %s", vtype, name)
687             LOG.info("Last sample: %s", self.last_sample)
688             return None
689         return handlers[vtype](host, name, val, time)
690
691     def _calc_counter(self, host, name, val, time):
692         """
693         Calculating counter values
694         """
695         key = (host, name)
696         if key not in self.prev_samples:
697             self.prev_samples[key] = (val, time)
698             return None
699         pval, ptime = self.prev_samples[key]
700         self.prev_samples[key] = (val, time)
701         if time <= ptime:
702             LOG.error("Invalid COUNTER update for: %s:%s", key[0], key[1])
703             LOG.info("Last sample: %s", self.last_sample)
704             return None
705         if val < pval:
706             # this is supposed to handle counter wrap around
707             # see https://collectd.org/wiki/index.php/Data_source
708             LOG.debug("COUNTER wrap-around for: %s:%s (%s -> %s)",
709                       host, name, pval, val)
710             if pval < 0x100000000:
711                 val += 0x100000000  # 2**32
712             else:
713                 val += 0x10000000000000000  # 2**64
714         return float(val - pval) / (time - ptime)
715
716     def _calc_derive(self, host, name, val, time):
717         """
718         Calculating derived values
719         """
720         key = (host, name)
721         if key not in self.prev_samples:
722             self.prev_samples[key] = (val, time)
723             return None
724         pval, ptime = self.prev_samples[key]
725         self.prev_samples[key] = (val, time)
726         if time <= ptime:
727             LOG.debug("Invalid DERIVE update for: %s:%s", key[0], key[1])
728             LOG.debug("Last sample: %s", self.last_sample)
729             return None
730         return float(abs(val - pval)) / (time - ptime)
731
732     def _calc_absolute(self, host, name, val, time):
733         """
734         Calculating absolute values
735         """
736         key = (host, name)
737         if key not in self.prev_samples:
738             self.prev_samples[key] = (val, time)
739             return None
740         _, ptime = self.prev_samples[key]
741         self.prev_samples[key] = (val, time)
742         if time <= ptime:
743             LOG.error("Invalid ABSOLUTE update for: %s:%s", key[0], key[1])
744             LOG.info("Last sample: %s", self.last_sample)
745             return None
746         return float(val) / (time - ptime)
747
748
749 class CollectDServer(UDPServer):
750     """Single processes CollectDServer"""
751
752     def __init__(self, queue):
753         super(CollectDServer, self).__init__(settings.getValue('COLLECTD_IP'),
754                                              settings.getValue('COLLECTD_PORT'))
755         self.handler = CollectDHandler()
756         self.queue = queue
757
758     def handle(self, data, addr):
759         for sample in self.handler.parse(data):
760             self.queue.put(sample)
761         return True
762
763     def pre_shutdown(self):
764         LOG.info("Sutting down CollectDServer")
765
766
767 def get_collectd_server(queue):
768     """Get the collectd server """
769     server = CollectDServer
770     return server(queue)