cabf1cb1d31d91d05e071f6090ef711f1980eadd
[nfvbench.git] / nfvbench / traffic_gen / trex.py
1 # Copyright 2016 Cisco Systems, Inc.  All rights reserved.
2 #
3 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
4 #    not use this file except in compliance with the License. You may obtain
5 #    a copy of the License at
6 #
7 #         http://www.apache.org/licenses/LICENSE-2.0
8 #
9 #    Unless required by applicable law or agreed to in writing, software
10 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 #    License for the specific language governing permissions and limitations
13 #    under the License.
14
15 import os
16 import random
17 import time
18 import traceback
19
20 from collections import defaultdict
21 from itertools import count
22 from nfvbench.log import LOG
23 from nfvbench.specs import ChainType
24 from nfvbench.traffic_server import TRexTrafficServer
25 from nfvbench.utils import cast_integer
26 from nfvbench.utils import timeout
27 from nfvbench.utils import TimeoutError
28 from traffic_base import AbstractTrafficGenerator
29 from traffic_base import TrafficGeneratorException
30 import traffic_utils as utils
31
32 # pylint: disable=import-error
33 from trex_stl_lib.api import CTRexVmInsFixHwCs
34 from trex_stl_lib.api import Dot1Q
35 from trex_stl_lib.api import Ether
36 from trex_stl_lib.api import IP
37 from trex_stl_lib.api import STLClient
38 from trex_stl_lib.api import STLError
39 from trex_stl_lib.api import STLFlowLatencyStats
40 from trex_stl_lib.api import STLFlowStats
41 from trex_stl_lib.api import STLPktBuilder
42 from trex_stl_lib.api import STLScVmRaw
43 from trex_stl_lib.api import STLStream
44 from trex_stl_lib.api import STLTXCont
45 from trex_stl_lib.api import STLVmFixChecksumHw
46 from trex_stl_lib.api import STLVmFlowVar
47 from trex_stl_lib.api import STLVmFlowVarRepetableRandom
48 from trex_stl_lib.api import STLVmWrFlowVar
49 from trex_stl_lib.api import UDP
50 from trex_stl_lib.services.trex_stl_service_arp import STLServiceARP
51
52
53 # pylint: enable=import-error
54
55
class TRex(AbstractTrafficGenerator):
    """Traffic generator implementation driven through the TRex stateless (STL) client API."""

    # rate (in pps) given to every latency measurement stream, see generate_streams()
    LATENCY_PPS = 1000
58
59     def __init__(self, runner):
60         AbstractTrafficGenerator.__init__(self, runner)
61         self.client = None
62         self.id = count()
63         self.latencies = defaultdict(list)
64         self.stream_ids = defaultdict(list)
65         self.port_handle = []
66         self.streamblock = defaultdict(list)
67         self.rates = []
68         self.arps = {}
69         self.capture_id = None
70         self.packet_list = []
71
72     def get_version(self):
73         return self.client.get_server_version()
74
75     def extract_stats(self, in_stats):
76         """Extract stats from dict returned by Trex API.
77
78         :param in_stats: dict as returned by TRex api
79         """
80         utils.nan_replace(in_stats)
81         LOG.debug(in_stats)
82
83         result = {}
84         # port_handles should have only 2 elements: [0, 1]
85         # so (1 - ph) will be the index for the far end port
86         for ph in self.port_handle:
87             stats = in_stats[ph]
88             far_end_stats = in_stats[1 - ph]
89             result[ph] = {
90                 'tx': {
91                     'total_pkts': cast_integer(stats['opackets']),
92                     'total_pkt_bytes': cast_integer(stats['obytes']),
93                     'pkt_rate': cast_integer(stats['tx_pps']),
94                     'pkt_bit_rate': cast_integer(stats['tx_bps'])
95                 },
96                 'rx': {
97                     'total_pkts': cast_integer(stats['ipackets']),
98                     'total_pkt_bytes': cast_integer(stats['ibytes']),
99                     'pkt_rate': cast_integer(stats['rx_pps']),
100                     'pkt_bit_rate': cast_integer(stats['rx_bps']),
101                     # how many pkts were dropped in RX direction
102                     # need to take the tx counter on the far end port
103                     'dropped_pkts': cast_integer(
104                         far_end_stats['opackets'] - stats['ipackets'])
105                 }
106             }
107
108             lat = self.__combine_latencies(in_stats, ph)
109             result[ph]['rx']['max_delay_usec'] = cast_integer(
110                 lat['total_max']) if 'total_max' in lat else float('nan')
111             result[ph]['rx']['min_delay_usec'] = cast_integer(
112                 lat['total_min']) if 'total_min' in lat else float('nan')
113             result[ph]['rx']['avg_delay_usec'] = cast_integer(
114                 lat['average']) if 'average' in lat else float('nan')
115         total_tx_pkts = result[0]['tx']['total_pkts'] + result[1]['tx']['total_pkts']
116         result["total_tx_rate"] = cast_integer(total_tx_pkts / self.config.duration_sec)
117         return result
118
119     def __combine_latencies(self, in_stats, port_handle):
120         """Traverses TRex result dictionary and combines chosen latency stats."""
121         if not self.latencies[port_handle]:
122             return {}
123
124         result = defaultdict(float)
125         result['total_min'] = float("inf")
126         for lat_id in self.latencies[port_handle]:
127             lat = in_stats['latency'][lat_id]
128             result['dropped_pkts'] += lat['err_cntrs']['dropped']
129             result['total_max'] = max(lat['latency']['total_max'], result['total_max'])
130             result['total_min'] = min(lat['latency']['total_min'], result['total_min'])
131             result['average'] += lat['latency']['average']
132
133         result['average'] /= len(self.latencies[port_handle])
134
135         return result
136
137     def create_pkt(self, stream_cfg, l2frame_size):
138
139         pkt_base = Ether(src=stream_cfg['mac_src'], dst=stream_cfg['mac_dst'])
140         if stream_cfg['vlan_tag'] is not None:
141             # 50 = 14 (Ethernet II) + 4 (Vlan tag) + 4 (CRC Checksum) + 20 (IPv4) + 8 (UDP)
142             pkt_base /= Dot1Q(vlan=stream_cfg['vlan_tag'])
143             l2payload_size = int(l2frame_size) - 50
144         else:
145             # 46 = 14 (Ethernet II) + 4 (CRC Checksum) + 20 (IPv4) + 8 (UDP)
146             l2payload_size = int(l2frame_size) - 46
147         payload = 'x' * l2payload_size
148         udp_args = {}
149         if stream_cfg['udp_src_port']:
150             udp_args['sport'] = int(stream_cfg['udp_src_port'])
151         if stream_cfg['udp_dst_port']:
152             udp_args['dport'] = int(stream_cfg['udp_dst_port'])
153         pkt_base /= IP() / UDP(**udp_args)
154
155         if stream_cfg['ip_addrs_step'] == 'random':
156             src_fv = STLVmFlowVarRepetableRandom(
157                 name="ip_src",
158                 min_value=stream_cfg['ip_src_addr'],
159                 max_value=stream_cfg['ip_src_addr_max'],
160                 size=4,
161                 seed=random.randint(0, 32767),
162                 limit=stream_cfg['ip_src_count'])
163             dst_fv = STLVmFlowVarRepetableRandom(
164                 name="ip_dst",
165                 min_value=stream_cfg['ip_dst_addr'],
166                 max_value=stream_cfg['ip_dst_addr_max'],
167                 size=4,
168                 seed=random.randint(0, 32767),
169                 limit=stream_cfg['ip_dst_count'])
170         else:
171             src_fv = STLVmFlowVar(
172                 name="ip_src",
173                 min_value=stream_cfg['ip_src_addr'],
174                 max_value=stream_cfg['ip_src_addr'],
175                 size=4,
176                 op="inc",
177                 step=stream_cfg['ip_addrs_step'])
178             dst_fv = STLVmFlowVar(
179                 name="ip_dst",
180                 min_value=stream_cfg['ip_dst_addr'],
181                 max_value=stream_cfg['ip_dst_addr_max'],
182                 size=4,
183                 op="inc",
184                 step=stream_cfg['ip_addrs_step'])
185
186         vm_param = [
187             src_fv,
188             STLVmWrFlowVar(fv_name="ip_src", pkt_offset="IP.src"),
189             dst_fv,
190             STLVmWrFlowVar(fv_name="ip_dst", pkt_offset="IP.dst"),
191             STLVmFixChecksumHw(l3_offset="IP",
192                                l4_offset="UDP",
193                                l4_type=CTRexVmInsFixHwCs.L4_TYPE_UDP)
194         ]
195
196         return STLPktBuilder(pkt=pkt_base / payload, vm=STLScVmRaw(vm_param))
197
198     def generate_streams(self, port_handle, stream_cfg, l2frame, isg=0.0, latency=True):
199         idx_lat = None
200         streams = []
201         if l2frame == 'IMIX':
202             min_size = 64 if stream_cfg['vlan_tag'] is None else 68
203             self.adjust_imix_min_size(min_size)
204             for t, (ratio, l2_frame_size) in enumerate(zip(self.imix_ratios, self.imix_l2_sizes)):
205                 pkt = self.create_pkt(stream_cfg, l2_frame_size)
206                 streams.append(STLStream(packet=pkt,
207                                          isg=0.1 * t,
208                                          flow_stats=STLFlowStats(
209                                              pg_id=self.stream_ids[port_handle]),
210                                          mode=STLTXCont(pps=ratio)))
211
212             if latency:
213                 idx_lat = self.id.next()
214                 pkt = self.create_pkt(stream_cfg, self.imix_avg_l2_size)
215                 sl = STLStream(packet=pkt,
216                                isg=isg,
217                                flow_stats=STLFlowLatencyStats(pg_id=idx_lat),
218                                mode=STLTXCont(pps=self.LATENCY_PPS))
219                 streams.append(sl)
220         else:
221             pkt = self.create_pkt(stream_cfg, l2frame)
222             streams.append(STLStream(packet=pkt,
223                                      flow_stats=STLFlowStats(pg_id=self.stream_ids[port_handle]),
224                                      mode=STLTXCont()))
225
226             if latency:
227                 idx_lat = self.id.next()
228                 streams.append(STLStream(packet=pkt,
229                                          flow_stats=STLFlowLatencyStats(pg_id=idx_lat),
230                                          mode=STLTXCont(pps=self.LATENCY_PPS)))
231
232         if latency:
233             self.latencies[port_handle].append(idx_lat)
234
235         return streams
236
237     def init(self):
238         pass
239
240     @timeout(5)
241     def __connect(self, client):
242         client.connect()
243
244     def __connect_after_start(self):
245         # after start, Trex may take a bit of time to initialize
246         # so we need to retry a few times
247         for it in xrange(self.config.generic_retry_count):
248             try:
249                 time.sleep(1)
250                 self.client.connect()
251                 break
252             except Exception as ex:
253                 if it == (self.config.generic_retry_count - 1):
254                     raise
255                 LOG.info("Retrying connection to TRex (%s)...", ex.message)
256
257     def connect(self):
258         LOG.info("Connecting to TRex...")
259         server_ip = self.config.generator_config.ip
260
261         # Connect to TRex server
262         self.client = STLClient(server=server_ip)
263         try:
264             self.__connect(self.client)
265         except (TimeoutError, STLError) as e:
266             if server_ip == '127.0.0.1':
267                 try:
268                     self.__start_server()
269                     self.__connect_after_start()
270                 except (TimeoutError, STLError) as e:
271                     LOG.error('Cannot connect to TRex')
272                     LOG.error(traceback.format_exc())
273                     logpath = '/tmp/trex.log'
274                     if os.path.isfile(logpath):
275                         # Wait for TRex to finish writing error message
276                         last_size = 0
277                         for _ in xrange(self.config.generic_retry_count):
278                             size = os.path.getsize(logpath)
279                             if size == last_size:
280                                 # probably not writing anymore
281                                 break
282                             last_size = size
283                             time.sleep(1)
284                         with open(logpath, 'r') as f:
285                             message = f.read()
286                     else:
287                         message = e.message
288                     raise TrafficGeneratorException(message)
289             else:
290                 raise TrafficGeneratorException(e.message)
291
292         ports = list(self.config.generator_config.ports)
293         self.port_handle = ports
294         # Prepare the ports
295         self.client.reset(ports)
296
297     def set_mode(self):
298         if self.config.service_chain == ChainType.EXT and not self.config.no_arp:
299             self.__set_l3_mode()
300         else:
301             self.__set_l2_mode()
302
303     def __set_l3_mode(self):
304         self.client.set_service_mode(ports=self.port_handle, enabled=True)
305         for port, device in zip(self.port_handle, self.config.generator_config.devices):
306             try:
307                 self.client.set_l3_mode(port=port,
308                                         src_ipv4=device.tg_gateway_ip,
309                                         dst_ipv4=device.dst.gateway_ip,
310                                         vlan=device.vlan_tag if device.vlan_tagging else None)
311             except STLError:
312                 # TRex tries to resolve ARP already, doesn't have to be successful yet
313                 continue
314         self.client.set_service_mode(ports=self.port_handle, enabled=False)
315
316     def __set_l2_mode(self):
317         self.client.set_service_mode(ports=self.port_handle, enabled=True)
318         for port, device in zip(self.port_handle, self.config.generator_config.devices):
319             for cfg in device.get_stream_configs(self.config.generator_config.service_chain):
320                 self.client.set_l2_mode(port=port, dst_mac=cfg['mac_dst'])
321         self.client.set_service_mode(ports=self.port_handle, enabled=False)
322
323     def __start_server(self):
324         server = TRexTrafficServer()
325         server.run_server(self.config.generator_config, self.config.vlan_tagging)
326
327     def resolve_arp(self):
328         self.client.set_service_mode(ports=self.port_handle)
329         LOG.info('Polling ARP until successful')
330         resolved = 0
331         attempt = 0
332         for port, device in zip(self.port_handle, self.config.generator_config.devices):
333             ctx = self.client.create_service_ctx(port=port)
334
335             arps = [
336                 STLServiceARP(ctx,
337                               src_ip=cfg['ip_src_tg_gw'],
338                               dst_ip=cfg['mac_discovery_gw'],
339                               vlan=device.vlan_tag if device.vlan_tagging else None)
340                 for cfg in device.get_stream_configs(self.config.generator_config.service_chain)
341             ]
342
343             for _ in xrange(self.config.generic_retry_count):
344                 attempt += 1
345                 try:
346                     ctx.run(arps)
347                 except STLError:
348                     LOG.error(traceback.format_exc())
349                     continue
350
351                 self.arps[port] = [arp.get_record().dst_mac for arp in arps
352                                    if arp.get_record().dst_mac is not None]
353
354                 if len(self.arps[port]) == self.config.service_chain_count:
355                     resolved += 1
356                     LOG.info('ARP resolved successfully for port %s', port)
357                     break
358                 else:
359                     failed = [arp.get_record().dst_ip for arp in arps
360                               if arp.get_record().dst_mac is None]
361                     LOG.info('Retrying ARP for: %s (%d / %d)',
362                              failed, attempt, self.config.generic_retry_count)
363                     time.sleep(self.config.generic_poll_sec)
364
365         self.client.set_service_mode(ports=self.port_handle, enabled=False)
366         return resolved == len(self.port_handle)
367
368     def config_interface(self):
369         pass
370
371     def __is_rate_enough(self, l2frame_size, rates, bidirectional, latency):
372         """Check if rate provided by user is above requirements. Applies only if latency is True."""
373         intf_speed = self.config.generator_config.intf_speed
374         if latency:
375             if bidirectional:
376                 mult = 2
377                 total_rate = 0
378                 for rate in rates:
379                     r = utils.convert_rates(l2frame_size, rate, intf_speed)
380                     total_rate += int(r['rate_pps'])
381             else:
382                 mult = 1
383                 total_rate = utils.convert_rates(l2frame_size, rates[0], intf_speed)
384             # rate must be enough for latency stream and at least 1 pps for base stream per chain
385             required_rate = (self.LATENCY_PPS + 1) * self.config.service_chain_count * mult
386             result = utils.convert_rates(l2frame_size,
387                                          {'rate_pps': required_rate},
388                                          intf_speed * mult)
389             result['result'] = total_rate >= required_rate
390             return result
391
392         return {'result': True}
393
394     def create_traffic(self, l2frame_size, rates, bidirectional, latency=True):
395         r = self.__is_rate_enough(l2frame_size, rates, bidirectional, latency)
396         if not r['result']:
397             raise TrafficGeneratorException(
398                 'Required rate in total is at least one of: \n{pps}pps \n{bps}bps \n{load}%.'
399                 .format(pps=r['rate_pps'],
400                         bps=r['rate_bps'],
401                         load=r['rate_percent']))
402
403         stream_cfgs = [d.get_stream_configs(self.config.generator_config.service_chain)
404                        for d in self.config.generator_config.devices]
405         self.rates = [utils.to_rate_str(rate) for rate in rates]
406
407         for ph in self.port_handle:
408             # generate one pg_id for each direction
409             self.stream_ids[ph] = self.id.next()
410
411         for i, (fwd_stream_cfg, rev_stream_cfg) in enumerate(zip(*stream_cfgs)):
412             if self.config.service_chain == ChainType.EXT and not self.config.no_arp:
413                 fwd_stream_cfg['mac_dst'] = self.arps[self.port_handle[0]][i]
414                 rev_stream_cfg['mac_dst'] = self.arps[self.port_handle[1]][i]
415
416             self.streamblock[0].extend(self.generate_streams(self.port_handle[0],
417                                                              fwd_stream_cfg,
418                                                              l2frame_size,
419                                                              latency=latency))
420             if len(self.rates) > 1:
421                 self.streamblock[1].extend(self.generate_streams(self.port_handle[1],
422                                                                  rev_stream_cfg,
423                                                                  l2frame_size,
424                                                                  isg=10.0,
425                                                                  latency=bidirectional and latency))
426
427         for ph in self.port_handle:
428             self.client.add_streams(self.streamblock[ph], ports=ph)
429             LOG.info('Created traffic stream for port %s.', ph)
430
431     def clear_streamblock(self):
432         self.streamblock = defaultdict(list)
433         self.latencies = defaultdict(list)
434         self.stream_ids = defaultdict(list)
435         self.rates = []
436         self.client.reset(self.port_handle)
437         LOG.info('Cleared all existing streams.')
438
439     def get_stats(self):
440         stats = self.client.get_stats()
441         return self.extract_stats(stats)
442
443     def get_macs(self):
444         return [self.client.get_port_attr(port=port)['src_mac'] for port in self.port_handle]
445
446     def clear_stats(self):
447         if self.port_handle:
448             self.client.clear_stats()
449
450     def start_traffic(self):
451         for port, rate in zip(self.port_handle, self.rates):
452             self.client.start(ports=port, mult=rate, duration=self.config.duration_sec, force=True)
453
454     def stop_traffic(self):
455         self.client.stop(ports=self.port_handle)
456
457     def start_capture(self):
458         """Capture all packets on both ports that are unicast to us."""
459         if self.capture_id:
460             self.stop_capture()
461         # Need to filter out unwanted packets so we do not end up counting
462         # src MACs of frames that are not unicast to us
463         src_mac_list = self.get_macs()
464         bpf_filter = "ether dst %s or ether dst %s" % (src_mac_list[0], src_mac_list[1])
465         # ports must be set in service in order to enable capture
466         self.client.set_service_mode(ports=self.port_handle)
467         self.capture_id = self.client.start_capture(rx_ports=self.port_handle,
468                                                     bpf_filter=bpf_filter)
469
470     def fetch_capture_packets(self):
471         if self.capture_id:
472             self.packet_list = []
473             self.client.fetch_capture_packets(capture_id=self.capture_id['id'],
474                                               output=self.packet_list)
475
476     def stop_capture(self):
477         if self.capture_id:
478             self.client.stop_capture(capture_id=self.capture_id['id'])
479             self.capture_id = None
480             self.client.set_service_mode(ports=self.port_handle, enabled=False)
481
482     def cleanup(self):
483         if self.client:
484             try:
485                 self.client.reset(self.port_handle)
486                 self.client.disconnect()
487             except STLError:
488                 # TRex does not like a reset while in disconnected state
489                 pass