1 # Copyright 2014-2018 TRBS, Spirent Communications
3 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 # use this file except in compliance with the License. You may obtain a copy of
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 # License for the specific language governing permissions and limitations under
15 # This file is a modified version of scripts present in bucky software
16 # details of bucky can be found at https://github.com/trbs/bucky
19 This module receives the samples from collectd, processes it and
20 enqueues it in a format suitable for easy processing.
21 It also handles secure communication with collectd.
import copy
import hmac
import logging
import multiprocessing
import os
import socket
import struct
import sys

from hashlib import sha1, sha256

from Crypto.Cipher import AES

from conf import settings
37 LOG = logging.getLogger(__name__)
class CollectdError(Exception):
    """
    Base class for all collectd-related errors.

    Stores the message so ``str(err)`` yields exactly the text passed in.
    """
    def __init__(self, mesg):
        super(CollectdError, self).__init__(mesg)
        self.mesg = mesg

    def __str__(self):
        return self.mesg
52 class ConnectError(CollectdError):
59 class ConfigError(CollectdError):
66 class ProtocolError(CollectdError):
class UDPServer(multiprocessing.Process):
    """
    Actual UDP server receiving collectd samples over network.

    Runs as a daemon subprocess; subclasses implement handle() to consume
    each received datagram.
    """
    def __init__(self, ip, port):
        """
        Resolve the endpoint, create and bind the UDP socket.

        :param ip: address (IPv4 or IPv6) to listen on
        :param port: UDP port to listen on
        """
        super(UDPServer, self).__init__()
        self.daemon = True
        # resolve to whichever address family the host supports
        addrinfo = socket.getaddrinfo(ip, port,
                                      socket.AF_UNSPEC, socket.SOCK_DGRAM)
        afamily, _, _, _, addr = addrinfo[0]
        ip, port = addr[:2]
        # remember the resolved endpoint so close()/send() can reach run()
        self.ip_addr = ip
        self.port = port
        self.sock = socket.socket(afamily, socket.SOCK_DGRAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            self.sock.bind((ip, port))
            LOG.info("Bound socket %s:%s", ip, port)
        except socket.error:
            LOG.exception("Error binding socket %s:%s.", ip, port)
            sys.exit(1)

        self.sock_recvfrom = self.sock.recvfrom

    def run(self):
        """
        Start receiving messages
        """
        recvfrom = self.sock_recvfrom
        while True:
            try:
                data, addr = recvfrom(65535)
            except (IOError, KeyboardInterrupt):
                continue
            addr = addr[:2]  # for compatibility with longer ipv6 tuples
            # close() sends this sentinel to stop the loop
            if data == b'EXIT':
                break
            if not self.handle(data, addr):
                break
        try:
            self.pre_shutdown()
        except Exception:  # best-effort hook; never let it mask shutdown
            LOG.exception("Failed pre_shutdown method for %s",
                          self.__class__.__name__)

    def handle(self, data, addr):
        """
        Handle a received datagram; subclasses must override.

        :return: truthy to keep receiving, falsy to stop the server
        """
        raise NotImplementedError()

    def pre_shutdown(self):
        """ Pre shutdown hook """
        pass

    def close(self):
        """
        Close the communication
        """
        self.send('EXIT')

    def send(self, data):
        """
        Send over the network
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        if not isinstance(data, bytes):
            data = data.encode()
        sock.sendto(data, 0, (self.ip_addr, self.port))
class CPUConverter(object):
    """
    Converter for CPU samples from collectd.
    """
    PRIORITY = -1

    def __call__(self, sample):
        """Return the metric name parts for a CPU sample."""
        return ["cpu", sample["plugin_instance"], sample["type_instance"]]
class InterfaceConverter(object):
    """
    Converter for Interface samples from collectd.

    Builds the metric name from plugin instance, type instance, type
    and value name, skipping empty components.
    """
    PRIORITY = -1

    def __call__(self, sample):
        """Return the metric name parts for an interface sample."""
        parts = []
        parts.append("interface")
        if sample.get("plugin_instance", ""):
            parts.append(sample["plugin_instance"].strip())
        stypei = sample.get("type_instance", "").strip()
        if stypei:
            parts.append(stypei)
        stype = sample.get("type").strip()
        if stype:
            parts.append(stype)
        vname = sample.get("value_name").strip()
        if vname:
            parts.append(vname)
        return parts
class MemoryConverter(object):
    """
    Converter for Memory samples from collectd.
    """
    PRIORITY = -1

    def __call__(self, sample):
        """Return the metric name parts for a memory sample."""
        return ["memory", sample["type_instance"]]
class DefaultConverter(object):
    """
    Default converter for samples from collectd.

    Joins plugin, plugin instance, type, type instance and value name,
    dropping empty parts and the redundant literal "value".
    """
    PRIORITY = -1

    def __call__(self, sample):
        """Return the metric name parts for a generic sample."""
        parts = []
        parts.append(sample["plugin"].strip())
        if sample.get("plugin_instance"):
            parts.append(sample["plugin_instance"].strip())
        stype = sample.get("type", "").strip()
        if stype and stype != "value":
            parts.append(stype)
        stypei = sample.get("type_instance", "").strip()
        if stypei:
            parts.append(stypei)
        vname = sample.get("value_name").strip()
        if vname and vname != "value":
            parts.append(vname)
        return parts
# Registry of per-plugin converters; "_default" is the fallback handler.
DEFAULT_CONVERTERS = {
    "cpu": CPUConverter(),
    "interface": InterfaceConverter(),
    "memory": MemoryConverter(),
    "_default": DefaultConverter(),
}
class CollectDTypes(object):
    """
    Class to handle the sample types. The types.db that comes
    with collectd, usually, defines the various types.
    """
    def __init__(self, types_dbs=None):
        """
        :param types_dbs: optional list of types.db paths; when empty,
                          well-known install locations are searched.
        :raises ConfigError: if no types.db can be located
        """
        if types_dbs is None:
            types_dbs = []
        dirs = ["/opt/collectd/share/collectd/types.db",
                "/usr/local/share/collectd/types.db"]
        self.types = {}        # name -> [(value_name, value_type), ...]
        self.type_ranges = {}  # name -> {value_name: (min, max)}
        if not types_dbs:
            types_dbs = [tdb for tdb in dirs if os.path.exists(tdb)]
            if not types_dbs:
                raise ConfigError("Unable to locate types.db")
        self.types_dbs = types_dbs
        self._load_types()

    def get(self, name):
        """
        Get the value specs registered for the given type name.

        :raises ProtocolError: for a type name absent from types.db
        """
        t_name = self.types.get(name)
        if t_name is None:
            raise ProtocolError("Invalid type name: %s" % name)
        return t_name

    def _load_types(self):
        """
        Load all the types from types_db
        """
        for types_db in self.types_dbs:
            with open(types_db) as handle:
                for line in handle:
                    # skip comments and blank lines
                    if line.lstrip()[:1] == "#":
                        continue
                    if not line.strip():
                        continue
                    self._add_type_line(line)
            LOG.info("Loaded collectd types from %s", types_db)

    def _add_type_line(self, line):
        """
        Parse one types.db line ("name vname:TYPE:min:max, ...") and
        register its value specs and ranges.
        """
        types = {
            "COUNTER": 0,
            "GAUGE": 1,
            "DERIVE": 2,
            "ABSOLUTE": 3
        }
        name, spec = line.split(None, 1)
        self.types[name] = []
        self.type_ranges[name] = {}
        vals = spec.split(", ")
        for val in vals:
            vname, vtype_name, minv, maxv = val.strip().split(":")
            vtype = types.get(vtype_name)
            if vtype is None:
                # report the offending token, not the None lookup result
                raise ValueError("Invalid value type: %s" % vtype_name)
            # "U" means unbounded in types.db
            minv = None if minv == "U" else float(minv)
            maxv = None if maxv == "U" else float(maxv)
            self.types[name].append((vname, vtype))
            self.type_ranges[name][vname] = (minv, maxv)
class CollectDParser(object):
    """
    Parser class: Implements the sample parsing operations.
    The types definition defines the parsing process.
    """
    def __init__(self, types_dbs=None, counter_eq_derive=False):
        """
        :param types_dbs: optional list of types.db paths
        :param counter_eq_derive: treat COUNTER and DERIVE as equivalent
                                  when they disagree with types.db
        """
        if types_dbs is None:
            types_dbs = []
        self.types = CollectDTypes(types_dbs=types_dbs)
        self.counter_eq_derive = counter_eq_derive

    def parse(self, data):
        """
        Parse individual samples
        """
        for sample in self.parse_samples(data):
            yield sample

    def parse_samples(self, data):
        """
        Extract all the samples from the message.
        """
        types = {
            0x0000: self._parse_string("host"),
            0x0001: self._parse_time("time"),
            0x0008: self._parse_time_hires("time"),
            0x0002: self._parse_string("plugin"),
            0x0003: self._parse_string("plugin_instance"),
            0x0004: self._parse_string("type"),
            0x0005: self._parse_string("type_instance"),
            0x0006: None,  # handle specially
            0x0007: self._parse_time("interval"),
            0x0009: self._parse_time_hires("interval")
        }
        sample = {}
        for (ptype, pdata) in self.parse_data(data):
            if ptype not in types:
                LOG.debug("Ignoring part type: 0x%02x", ptype)
                continue
            if ptype != 0x0006:
                types[ptype](sample, pdata)
                continue
            # 0x0006 carries the values; emit one sample per value
            for vname, vtype, val in self.parse_values(sample["type"], pdata):
                sample["value_name"] = vname
                sample["value_type"] = vtype
                sample["value"] = val
                yield copy.deepcopy(sample)

    @staticmethod
    def parse_data(data):
        """
        Split the packet into (part_type, part_data) chunks.

        :raises ProtocolError: on truncated or unknown parts
        """
        types = set([
            0x0000, 0x0001, 0x0002, 0x0003, 0x0004,
            0x0005, 0x0006, 0x0007, 0x0008, 0x0009,
            0x0100, 0x0101, 0x0200, 0x0210
        ])
        while len(data) > 0:
            if len(data) < 4:
                raise ProtocolError("Truncated header.")
            (part_type, part_len) = struct.unpack("!HH", data[:4])
            data = data[4:]
            if part_type not in types:
                raise ProtocolError("Invalid part type: 0x%02x" % part_type)
            part_len -= 4  # includes four header bytes we just parsed
            if len(data) < part_len:
                raise ProtocolError("Truncated value.")
            part_data, data = data[:part_len], data[part_len:]
            yield (part_type, part_data)

    def parse_values(self, stype, data):
        """
        Parse the value of a particular type
        """
        types = {0: "!Q", 1: "<d", 2: "!q", 3: "!Q"}
        (nvals,) = struct.unpack("!H", data[:2])
        data = data[2:]
        # each value = 1 type byte + 8 data bytes
        if len(data) != 9 * nvals:
            raise ProtocolError("Invalid value structure length.")
        vtypes = self.types.get(stype)
        if nvals != len(vtypes):
            raise ProtocolError("Values different than types.db info.")
        for i in range(nvals):
            (vtype,) = struct.unpack("B", data[i:i + 1])
            if vtype != vtypes[i][1]:
                if self.counter_eq_derive and \
                   (vtype, vtypes[i][1]) in ((0, 2), (2, 0)):
                    # if counter vs derive don't break, assume server is right
                    LOG.debug("Type mismatch (counter/derive) for %s/%s",
                              stype, vtypes[i][0])
                else:
                    raise ProtocolError("Type mismatch with types.db")
        data = data[nvals:]
        for i in range(nvals):
            vdata, data = data[:8], data[8:]
            (val,) = struct.unpack(types[vtypes[i][1]], vdata)
            yield vtypes[i][0], vtypes[i][1], val

    @staticmethod
    def _parse_string(name):
        """
        Return a parser storing a NUL-terminated string under ``name``.
        """
        def _parser(sample, data):
            """ Actual string parser """
            data = data.decode()
            if data[-1] != '\0':
                raise ProtocolError("Invalid string detected.")
            sample[name] = data[:-1]
        return _parser

    @staticmethod
    def _parse_time(name):
        """
        Return a parser storing a seconds timestamp under ``name``.
        """
        def _parser(sample, data):
            """ Actual time parser """
            if len(data) != 8:
                raise ProtocolError("Invalid time data length.")
            (val,) = struct.unpack("!Q", data)
            sample[name] = float(val)
        return _parser

    @staticmethod
    def _parse_time_hires(name):
        """
        Parse time hires value
        """
        def _parser(sample, data):
            """
            Actual time hires parser
            """
            if len(data) != 8:
                raise ProtocolError("Invalid hires time data length.")
            (val,) = struct.unpack("!Q", data)
            # collectd hi-res time is in 2**-30 second units
            sample[name] = val * (2 ** -30)
        return _parser
class CollectDCrypto(object):
    """
    Handle the secured communications with collectd daemon
    """
    def __init__(self):
        """
        Read the configured security level and, when signing/encryption
        is required, load the user/password auth database.

        :raises ConfigError: security level set but no auth file given
        """
        sec_level = settings.getValue('COLLECTD_SECURITY_LEVEL')
        if sec_level in ("sign", "SIGN", "Sign", 1):
            self.sec_level = 1
        elif sec_level in ("encrypt", "ENCRYPT", "Encrypt", 2):
            self.sec_level = 2
        else:
            self.sec_level = 0
        if self.sec_level:
            self.auth_file = settings.getValue('COLLECTD_AUTH_FILE')
            self.auth_db = {}
            if self.auth_file:
                self.load_auth_file()
            if not self.auth_file:
                raise ConfigError("Collectd security level configured but no "
                                  "auth file specified in configuration")
            if not self.auth_db:
                LOG.warning("Collectd security level configured but no "
                            "user/passwd entries loaded from auth file")

    def load_auth_file(self):
        """
        Loading the authentication file.

        Lines are "user: password"; blank lines and '#' comments skipped.

        :raises ConfigError: if the file cannot be opened
        """
        try:
            fil = open(self.auth_file)
        except IOError as exc:
            raise ConfigError("Unable to load collectd's auth file: %r" % exc)
        with fil:  # guarantees the handle is closed
            self.auth_db.clear()
            for line in fil:
                line = line.strip()
                if not line or line[0] == "#":
                    continue
                user, passwd = line.split(":", 1)
                user = user.strip()
                passwd = passwd.strip()
                if not user or not passwd:
                    LOG.warning("Found line with missing user or password")
                    continue
                if user in self.auth_db:
                    LOG.warning("Found multiple entries for single user")
                self.auth_db[user] = passwd
        LOG.info("Loaded collectd's auth file from %s", self.auth_file)

    def parse(self, data):
        """
        Inspect the packet's security wrapper and dispatch to the
        signed/encrypted parser; plain packets pass through.

        :raises ProtocolError: level below configured minimum, or truncation
        """
        if len(data) < 4:
            raise ProtocolError("Truncated header.")
        part_type, part_len = struct.unpack("!HH", data[:4])
        sec_level = {0x0200: 1, 0x0210: 2}.get(part_type, 0)
        if sec_level < self.sec_level:
            raise ProtocolError("Packet has lower security level than allowed")
        if not sec_level:
            return data
        if sec_level == 1 and not self.sec_level:
            # signed but signatures not enforced: skip the signature part
            return data[part_len:]
        data = data[4:]
        part_len -= 4
        if len(data) < part_len:
            raise ProtocolError("Truncated part payload.")
        if sec_level == 1:
            return self.parse_signed(part_len, data)
        if sec_level == 2:
            return self.parse_encrypted(part_len, data)

    def parse_signed(self, part_len, data):
        """
        Parse the signed message

        Layout: 32-byte HMAC-SHA256 signature, then username, then payload;
        the signature covers username + payload.
        """
        if part_len <= 32:
            raise ProtocolError("Truncated signed part.")
        sig, data = data[:32], data[32:]
        uname_len = part_len - 32
        uname = data[:uname_len].decode()
        if uname not in self.auth_db:
            raise ProtocolError("Signed packet, unknown user '%s'" % uname)
        password = self.auth_db[uname].encode()
        sig2 = hmac.new(password, msg=data, digestmod=sha256).digest()
        if not self._hashes_match(sig, sig2):
            raise ProtocolError("Bad signature from user '%s'" % uname)
        data = data[uname_len:]
        return data

    def parse_encrypted(self, part_len, data):
        """
        Parse the encrypted message

        Layout: username length, username, 16-byte IV, then AES-256-OFB
        ciphertext of (20-byte SHA1 checksum + payload).
        """
        if part_len != len(data):
            raise ProtocolError("Enc pkt size disagrees with header.")
        if len(data) <= 38:
            raise ProtocolError("Truncated encrypted part.")
        uname_len, data = struct.unpack("!H", data[:2])[0], data[2:]
        if len(data) <= uname_len + 36:
            raise ProtocolError("Truncated encrypted part.")
        uname, data = data[:uname_len].decode(), data[uname_len:]
        if uname not in self.auth_db:
            raise ProtocolError("Couldn't decrypt, unknown user '%s'" % uname)
        ival, data = data[:16], data[16:]
        password = self.auth_db[uname].encode()
        key = sha256(password).digest()
        # OFB is a stream mode: pad for the cipher call, drop pad after
        pad_bytes = 16 - (len(data) % 16)
        data += b'\0' * pad_bytes
        data = AES.new(key, IV=ival, mode=AES.MODE_OFB).decrypt(data)
        data = data[:-pad_bytes]
        tag, data = data[:20], data[20:]
        tag2 = sha1(data).digest()
        if not self._hashes_match(tag, tag2):
            raise ProtocolError("Bad checksum on enc pkt for '%s'" % uname)
        return data

    @staticmethod
    def _hashes_match(val_a, val_b):
        """Constant time comparison of bytes """
        if len(val_a) != len(val_b):
            return False
        diff = 0
        for val_x, val_y in zip(val_a, val_b):
            diff |= val_x ^ val_y
        return diff == 0
class CollectDConverter(object):
    """
    Handle all conversions.
    Coversion: Convert the sample received from collectd to an
    appropriate format - for easy processing
    """
    def __init__(self):
        self.converters = dict(DEFAULT_CONVERTERS)

    def convert(self, sample):
        """
        Convert one parsed sample into a flat tuple.

        :return: (host, metric_name, value_type, value, int_time),
                 or None when the sample should be ignored
        """
        default = self.converters["_default"]
        handler = self.converters.get(sample["plugin"], default)
        try:
            name_parts = handler(sample)
            if name_parts is None:
                return  # treat None as "ignore sample"
            name = '.'.join(name_parts)
        except (AttributeError, IndexError, MemoryError, RuntimeError):
            LOG.exception("Exception in sample handler %s (%s):",
                          sample["plugin"], handler)
            return
        host = sample.get("host", "")
        return (
            host,
            name,
            sample["value_type"],
            sample["value"],
            int(sample["time"]),
        )

    def _add_converter(self, name, inst, source="unknown"):
        """
        Add new converter types

        On a name clash the converter with the higher PRIORITY wins.
        """
        if name not in self.converters:
            LOG.info("Converter: %s from %s", name, source)
            self.converters[name] = inst
            return
        kpriority = getattr(inst, "PRIORITY", 0)
        ipriority = getattr(self.converters[name], "PRIORITY", 0)
        if kpriority > ipriority:
            LOG.info("Replacing: %s", name)
            LOG.info("Converter: %s from %s", name, source)
            self.converters[name] = inst
            return
        LOG.info("Ignoring: %s (%s) from %s (priority: %s vs %s)",
                 name, inst, source, kpriority, ipriority)
class CollectDHandler(object):
    """Wraps all CollectD parsing functionality in a class"""

    def __init__(self):
        self.crypto = CollectDCrypto()
        collectd_types = []
        collectd_counter_eq_derive = False
        self.parser = CollectDParser(collectd_types,
                                     collectd_counter_eq_derive)
        self.converter = CollectDConverter()
        self.prev_samples = {}  # (host, name) -> (value, time) of last sample
        self.last_sample = None

    def parse(self, data):
        """
        Parse the samples from collectd

        Yields (host, name, value, time) tuples; invalid packets are
        logged and dropped rather than raised to the caller.
        """
        try:
            data = self.crypto.parse(data)
        except ProtocolError as error:
            LOG.error("Protocol error in CollectDCrypto: %s", error)
            return
        try:
            for sample in self.parser.parse(data):
                self.last_sample = sample
                stype = sample["type"]
                vname = sample["value_name"]
                sample = self.converter.convert(sample)
                if sample is None:
                    continue
                host, name, vtype, val, time = sample
                if not name.strip():
                    continue
                val = self.calculate(host, name, vtype, val, time)
                val = self.check_range(stype, vname, val)
                if val is not None:
                    yield host, name, val, time
        except ProtocolError as error:
            LOG.error("Protocol error: %s", error)
            if self.last_sample is not None:
                LOG.info("Last sample: %s", self.last_sample)

    def check_range(self, stype, vname, val):
        """
        Check the value range

        :return: val when inside the types.db [vmin, vmax] range, else None
        """
        if val is None:
            return
        try:
            vmin, vmax = self.parser.types.type_ranges[stype][vname]
        except KeyError:
            LOG.error("Couldn't find vmin, vmax in CollectDTypes")
            return val
        if vmin is not None and val < vmin:
            LOG.debug("Invalid value %s (<%s) for %s", val, vmin, vname)
            LOG.debug("Last sample: %s", self.last_sample)
            return
        if vmax is not None and val > vmax:
            LOG.debug("Invalid value %s (>%s) for %s", val, vmax, vname)
            LOG.debug("Last sample: %s", self.last_sample)
            return
        return val

    def calculate(self, host, name, vtype, val, time):
        """
        Perform calculations for handlers

        Dispatch on the collectd data source type.
        """
        handlers = {
            0: self._calc_counter,  # counter
            1: lambda _host, _name, v, _time: v,  # gauge
            2: self._calc_derive,  # derive
            3: self._calc_absolute  # absolute
        }
        if vtype not in handlers:
            LOG.error("Invalid value type %s for %s", vtype, name)
            LOG.info("Last sample: %s", self.last_sample)
            return
        return handlers[vtype](host, name, val, time)

    def _calc_counter(self, host, name, val, time):
        """
        Calculating counter values

        :return: rate since the previous sample, or None on the first one
        """
        key = (host, name)
        if key not in self.prev_samples:
            self.prev_samples[key] = (val, time)
            return
        pval, ptime = self.prev_samples[key]
        self.prev_samples[key] = (val, time)
        if time <= ptime:
            LOG.error("Invalid COUNTER update for: %s:%s", key[0], key[1])
            LOG.info("Last sample: %s", self.last_sample)
            return
        if val < pval:
            # this is supposed to handle counter wrap around
            # see https://collectd.org/wiki/index.php/Data_source
            LOG.debug("COUNTER wrap-around for: %s:%s (%s -> %s)",
                      host, name, pval, val)
            if pval < 0x100000000:
                val += 0x100000000  # 2**32
            else:
                val += 0x10000000000000000  # 2**64
        return float(val - pval) / (time - ptime)

    def _calc_derive(self, host, name, val, time):
        """
        Calculating derived values

        :return: absolute rate since the previous sample, or None on first
        """
        key = (host, name)
        if key not in self.prev_samples:
            self.prev_samples[key] = (val, time)
            return
        pval, ptime = self.prev_samples[key]
        self.prev_samples[key] = (val, time)
        if time <= ptime:
            LOG.debug("Invalid DERIVE update for: %s:%s", key[0], key[1])
            LOG.debug("Last sample: %s", self.last_sample)
            return
        return float(abs(val - pval)) / (time - ptime)

    def _calc_absolute(self, host, name, val, time):
        """
        Calculating absolute values

        :return: value divided by elapsed time, or None on the first sample
        """
        key = (host, name)
        if key not in self.prev_samples:
            self.prev_samples[key] = (val, time)
            return
        _, ptime = self.prev_samples[key]
        self.prev_samples[key] = (val, time)
        if time <= ptime:
            LOG.error("Invalid ABSOLUTE update for: %s:%s", key[0], key[1])
            LOG.info("Last sample: %s", self.last_sample)
            return
        return float(val) / (time - ptime)
class CollectDServer(UDPServer):
    """Single processes CollectDServer"""

    def __init__(self, queue):
        """
        :param queue: multiprocessing queue receiving parsed samples
        """
        super(CollectDServer, self).__init__(settings.getValue('COLLECTD_IP'),
                                             settings.getValue('COLLECTD_PORT'))
        self.handler = CollectDHandler()
        self.queue = queue

    def handle(self, data, addr):
        """Parse one datagram and enqueue every resulting sample."""
        for sample in self.handler.parse(data):
            self.queue.put(sample)
        return True  # keep the receive loop running

    def pre_shutdown(self):
        """Log the server shutdown."""
        LOG.info("Shutting down CollectDServer")
766 def get_collectd_server(queue):
767 """Get the collectd server """
768 server = CollectDServer