7d64f0f2bf827233fbff3809407e237675d6309d
[nfvbench.git] / nfvbench / traffic_gen / trex.py
1 # Copyright 2016 Cisco Systems, Inc.  All rights reserved.
2 #
3 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
4 #    not use this file except in compliance with the License. You may obtain
5 #    a copy of the License at
6 #
7 #         http://www.apache.org/licenses/LICENSE-2.0
8 #
9 #    Unless required by applicable law or agreed to in writing, software
10 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 #    License for the specific language governing permissions and limitations
13 #    under the License.
14
15 import os
16 import random
17 import time
18 import traceback
19
20 from collections import defaultdict
21 from itertools import count
22 from nfvbench.log import LOG
23 from nfvbench.specs import ChainType
24 from nfvbench.traffic_server import TRexTrafficServer
25 from nfvbench.utils import cast_integer
26 from nfvbench.utils import timeout
27 from nfvbench.utils import TimeoutError
28 from traffic_base import AbstractTrafficGenerator
29 from traffic_base import TrafficGeneratorException
30 import traffic_utils as utils
31
32 # pylint: disable=import-error
33 from trex_stl_lib.api import CTRexVmInsFixHwCs
34 from trex_stl_lib.api import Dot1Q
35 from trex_stl_lib.api import Ether
36 from trex_stl_lib.api import IP
37 from trex_stl_lib.api import STLClient
38 from trex_stl_lib.api import STLError
39 from trex_stl_lib.api import STLFlowLatencyStats
40 from trex_stl_lib.api import STLFlowStats
41 from trex_stl_lib.api import STLPktBuilder
42 from trex_stl_lib.api import STLScVmRaw
43 from trex_stl_lib.api import STLStream
44 from trex_stl_lib.api import STLTXCont
45 from trex_stl_lib.api import STLVmFixChecksumHw
46 from trex_stl_lib.api import STLVmFlowVar
47 from trex_stl_lib.api import STLVmFlowVarRepetableRandom
48 from trex_stl_lib.api import STLVmWrFlowVar
49 from trex_stl_lib.api import UDP
50 from trex_stl_lib.services.trex_stl_service_arp import STLServiceARP
51
52
53 # pylint: enable=import-error
54
55
56 class TRex(AbstractTrafficGenerator):
57     LATENCY_PPS = 1000
58
59     def __init__(self, runner):
60         AbstractTrafficGenerator.__init__(self, runner)
61         self.client = None
62         self.id = count()
63         self.latencies = defaultdict(list)
64         self.stream_ids = defaultdict(list)
65         self.port_handle = []
66         self.streamblock = defaultdict(list)
67         self.rates = []
68         self.arps = {}
69         self.capture_id = None
70         self.packet_list = []
71
72     def get_version(self):
73         return self.client.get_server_version()
74
75     def extract_stats(self, in_stats):
76         utils.nan_replace(in_stats)
77         LOG.debug(in_stats)
78
79         result = {}
80         for ph in self.port_handle:
81             stats = self.__combine_stats(in_stats, ph)
82             result[ph] = {
83                 'tx': {
84                     'total_pkts': cast_integer(stats['tx_pkts']['total']),
85                     'total_pkt_bytes': cast_integer(stats['tx_bytes']['total']),
86                     'pkt_rate': cast_integer(stats['tx_pps']['total']),
87                     'pkt_bit_rate': cast_integer(stats['tx_bps']['total'])
88                 },
89                 'rx': {
90                     'total_pkts': cast_integer(stats['rx_pkts']['total']),
91                     'total_pkt_bytes': cast_integer(stats['rx_bytes']['total']),
92                     'pkt_rate': cast_integer(stats['rx_pps']['total']),
93                     'pkt_bit_rate': cast_integer(stats['rx_bps']['total']),
94                     'dropped_pkts': cast_integer(
95                         stats['tx_pkts']['total'] - stats['rx_pkts']['total'])
96                 }
97             }
98
99             lat = self.__combine_latencies(in_stats, ph)
100             result[ph]['rx']['max_delay_usec'] = cast_integer(
101                 lat['total_max']) if 'total_max' in lat else float('nan')
102             result[ph]['rx']['min_delay_usec'] = cast_integer(
103                 lat['total_min']) if 'total_min' in lat else float('nan')
104             result[ph]['rx']['avg_delay_usec'] = cast_integer(
105                 lat['average']) if 'average' in lat else float('nan')
106         total_tx_pkts = result[0]['tx']['total_pkts'] + result[1]['tx']['total_pkts']
107         result["total_tx_rate"] = cast_integer(total_tx_pkts / self.config.duration_sec)
108         return result
109
110     def __combine_stats(self, in_stats, port_handle):
111         """Traverses TRex result dictionary and combines stream stats. Used for combining latency
112         and regular streams together.
113         """
114         result = defaultdict(lambda: defaultdict(float))
115
116         for pg_id in [self.stream_ids[port_handle]] + self.latencies[port_handle]:
117             record = in_stats['flow_stats'][pg_id]
118             for stat_type, stat_type_values in record.iteritems():
119                 for ph, value in stat_type_values.iteritems():
120                     result[stat_type][ph] += value
121
122         return result
123
124     def __combine_latencies(self, in_stats, port_handle):
125         """Traverses TRex result dictionary and combines chosen latency stats."""
126         if not self.latencies[port_handle]:
127             return {}
128
129         result = defaultdict(float)
130         result['total_min'] = float("inf")
131         for lat_id in self.latencies[port_handle]:
132             lat = in_stats['latency'][lat_id]
133             result['dropped_pkts'] += lat['err_cntrs']['dropped']
134             result['total_max'] = max(lat['latency']['total_max'], result['total_max'])
135             result['total_min'] = min(lat['latency']['total_min'], result['total_min'])
136             result['average'] += lat['latency']['average']
137
138         result['average'] /= len(self.latencies[port_handle])
139
140         return result
141
    def create_pkt(self, stream_cfg, l2frame_size):
        """Build the TRex packet template, with its field-engine program, for one stream.

        :param stream_cfg: stream configuration dict; reads 'mac_src', 'mac_dst',
                           'vlan_tag', 'udp_src_port', 'udp_dst_port', 'ip_addrs_step',
                           'ip_src_addr', 'ip_src_addr_max', 'ip_src_count',
                           'ip_dst_addr', 'ip_dst_addr_max' and 'ip_dst_count'
        :param l2frame_size: L2 frame size in bytes, counting the 4-byte CRC (anything
                             int() accepts)
        :return: STLPktBuilder for an Ether[/Dot1Q]/IP/UDP packet padded with 'x' bytes,
                 whose IP source/destination are rewritten per packet by flow variables
                 and whose checksums are fixed with STLVmFixChecksumHw
        """
        pkt_base = Ether(src=stream_cfg['mac_src'], dst=stream_cfg['mac_dst'])
        if stream_cfg['vlan_tag'] is not None:
            # 50 = 14 (Ethernet II) + 4 (Vlan tag) + 4 (CRC Checksum) + 20 (IPv4) + 8 (UDP)
            pkt_base /= Dot1Q(vlan=stream_cfg['vlan_tag'])
            l2payload_size = int(l2frame_size) - 50
        else:
            # 46 = 14 (Ethernet II) + 4 (CRC Checksum) + 20 (IPv4) + 8 (UDP)
            l2payload_size = int(l2frame_size) - 46
        # a frame size below the header overhead yields a negative size and an empty payload
        payload = 'x' * l2payload_size
        # falsy ports (None, 0, '') leave the UDP() default for that field
        udp_args = {}
        if stream_cfg['udp_src_port']:
            udp_args['sport'] = int(stream_cfg['udp_src_port'])
        if stream_cfg['udp_dst_port']:
            udp_args['dport'] = int(stream_cfg['udp_dst_port'])
        pkt_base /= IP() / UDP(**udp_args)

        if stream_cfg['ip_addrs_step'] == 'random':
            # random addresses within [addr, addr_max], limited to *_count values; each
            # flow variable gets its own seed (source drawn first, then destination)
            src_fv = STLVmFlowVarRepetableRandom(
                name="ip_src",
                min_value=stream_cfg['ip_src_addr'],
                max_value=stream_cfg['ip_src_addr_max'],
                size=4,
                seed=random.randint(0, 32767),
                limit=stream_cfg['ip_src_count'])
            dst_fv = STLVmFlowVarRepetableRandom(
                name="ip_dst",
                min_value=stream_cfg['ip_dst_addr'],
                max_value=stream_cfg['ip_dst_addr_max'],
                size=4,
                seed=random.randint(0, 32767),
                limit=stream_cfg['ip_dst_count'])
        else:
            # NOTE(review): the source range uses 'ip_src_addr' as both min and max, so
            # the source IP stays fixed in this mode while only the destination IP is
            # stepped; the 'random' branch above does vary the source up to
            # 'ip_src_addr_max'. Confirm this asymmetry is intended.
            src_fv = STLVmFlowVar(
                name="ip_src",
                min_value=stream_cfg['ip_src_addr'],
                max_value=stream_cfg['ip_src_addr'],
                size=4,
                op="inc",
                step=stream_cfg['ip_addrs_step'])
            dst_fv = STLVmFlowVar(
                name="ip_dst",
                min_value=stream_cfg['ip_dst_addr'],
                max_value=stream_cfg['ip_dst_addr_max'],
                size=4,
                op="inc",
                step=stream_cfg['ip_addrs_step'])

        # field-engine program: write each flow variable into its IP header field, then
        # fix the IP/UDP checksums of the rewritten packet
        vm_param = [
            src_fv,
            STLVmWrFlowVar(fv_name="ip_src", pkt_offset="IP.src"),
            dst_fv,
            STLVmWrFlowVar(fv_name="ip_dst", pkt_offset="IP.dst"),
            STLVmFixChecksumHw(l3_offset="IP",
                               l4_offset="UDP",
                               l4_type=CTRexVmInsFixHwCs.L4_TYPE_UDP)
        ]

        return STLPktBuilder(pkt=pkt_base / payload, vm=STLScVmRaw(vm_param))
202
203     def generate_streams(self, port_handle, stream_cfg, l2frame, isg=0.0, latency=True):
204         idx_lat = None
205         streams = []
206         if l2frame == 'IMIX':
207             min_size = 64 if stream_cfg['vlan_tag'] is None else 68
208             self.adjust_imix_min_size(min_size)
209             for t, (ratio, l2_frame_size) in enumerate(zip(self.imix_ratios, self.imix_l2_sizes)):
210                 pkt = self.create_pkt(stream_cfg, l2_frame_size)
211                 streams.append(STLStream(packet=pkt,
212                                          isg=0.1 * t,
213                                          flow_stats=STLFlowStats(
214                                              pg_id=self.stream_ids[port_handle]),
215                                          mode=STLTXCont(pps=ratio)))
216
217             if latency:
218                 idx_lat = self.id.next()
219                 sl = STLStream(packet=pkt,
220                                isg=isg,
221                                flow_stats=STLFlowLatencyStats(pg_id=idx_lat),
222                                mode=STLTXCont(pps=self.LATENCY_PPS))
223                 streams.append(sl)
224         else:
225             pkt = self.create_pkt(stream_cfg, l2frame)
226             streams.append(STLStream(packet=pkt,
227                                      flow_stats=STLFlowStats(pg_id=self.stream_ids[port_handle]),
228                                      mode=STLTXCont()))
229
230             if latency:
231                 idx_lat = self.id.next()
232                 streams.append(STLStream(packet=pkt,
233                                          flow_stats=STLFlowLatencyStats(pg_id=idx_lat),
234                                          mode=STLTXCont(pps=self.LATENCY_PPS)))
235
236         if latency:
237             self.latencies[port_handle].append(idx_lat)
238
239         return streams
240
241     def init(self):
242         pass
243
244     @timeout(5)
245     def __connect(self, client):
246         client.connect()
247
248     def __connect_after_start(self):
249         # after start, Trex may take a bit of time to initialize
250         # so we need to retry a few times
251         for it in xrange(self.config.generic_retry_count):
252             try:
253                 time.sleep(1)
254                 self.client.connect()
255                 break
256             except Exception as ex:
257                 if it == (self.config.generic_retry_count - 1):
258                     raise ex
259                 LOG.info("Retrying connection to TRex (%s)...", ex.message)
260
    def connect(self):
        """Connect to the TRex server at generator_config.ip and reset the configured ports.

        If the first connection attempt fails with TimeoutError or STLError and the server
        is local ('127.0.0.1'), a local TRex server is started and the connection retried.
        On success, self.port_handle is set to generator_config.ports and those ports are
        reset.

        :raises TrafficGeneratorException: if the connection cannot be established. After a
            failed local start, the message is the content of /tmp/trex.log when that file
            exists. Exceptions other than TimeoutError/STLError raised while starting the
            local server propagate unchanged.
        """
        LOG.info("Connecting to TRex...")
        server_ip = self.config.generator_config.ip

        # Connect to TRex server
        self.client = STLClient(server=server_ip)
        try:
            self.__connect(self.client)
        except (TimeoutError, STLError) as e:
            if server_ip == '127.0.0.1':
                # local server: try to start it ourselves, then reconnect
                try:
                    self.__start_server()
                    self.__connect_after_start()
                except (TimeoutError, STLError) as e:
                    LOG.error('Cannot connect to TRex')
                    LOG.error(traceback.format_exc())
                    logpath = '/tmp/trex.log'
                    if os.path.isfile(logpath):
                        # Wait for TRex to finish writing error message
                        # (poll the size once per second, up to generic_retry_count times,
                        # until it stops changing)
                        last_size = 0
                        for _ in xrange(self.config.generic_retry_count):
                            size = os.path.getsize(logpath)
                            if size == last_size:
                                # probably not writing anymore
                                break
                            last_size = size
                            time.sleep(1)
                        with open(logpath, 'r') as f:
                            message = f.read()
                    else:
                        # NOTE(review): e.message is Python 2 only (deprecated since 2.6)
                        message = e.message
                    raise TrafficGeneratorException(message)
            else:
                raise TrafficGeneratorException(e.message)

        ports = list(self.config.generator_config.ports)
        self.port_handle = ports
        # Prepare the ports
        self.client.reset(ports)
300
301     def set_mode(self):
302         if self.config.service_chain == ChainType.EXT and not self.config.no_arp:
303             self.__set_l3_mode()
304         else:
305             self.__set_l2_mode()
306
307     def __set_l3_mode(self):
308         self.client.set_service_mode(ports=self.port_handle, enabled=True)
309         for port, device in zip(self.port_handle, self.config.generator_config.devices):
310             try:
311                 self.client.set_l3_mode(port=port,
312                                         src_ipv4=device.tg_gateway_ip,
313                                         dst_ipv4=device.dst.gateway_ip,
314                                         vlan=device.vlan_tag if device.vlan_tagging else None)
315             except STLError:
316                 # TRex tries to resolve ARP already, doesn't have to be successful yet
317                 continue
318         self.client.set_service_mode(ports=self.port_handle, enabled=False)
319
320     def __set_l2_mode(self):
321         self.client.set_service_mode(ports=self.port_handle, enabled=True)
322         for port, device in zip(self.port_handle, self.config.generator_config.devices):
323             for cfg in device.get_stream_configs(self.config.generator_config.service_chain):
324                 self.client.set_l2_mode(port=port, dst_mac=cfg['mac_dst'])
325         self.client.set_service_mode(ports=self.port_handle, enabled=False)
326
327     def __start_server(self):
328         server = TRexTrafficServer()
329         server.run_server(self.config.generator_config, self.config.vlan_tagging)
330
331     def resolve_arp(self):
332         self.client.set_service_mode(ports=self.port_handle)
333         LOG.info('Polling ARP until successful')
334         resolved = 0
335         attempt = 0
336         for port, device in zip(self.port_handle, self.config.generator_config.devices):
337             ctx = self.client.create_service_ctx(port=port)
338
339             arps = [
340                 STLServiceARP(ctx,
341                               src_ip=cfg['ip_src_tg_gw'],
342                               dst_ip=cfg['mac_discovery_gw'],
343                               vlan=device.vlan_tag if device.vlan_tagging else None)
344                 for cfg in device.get_stream_configs(self.config.generator_config.service_chain)
345             ]
346
347             for _ in xrange(self.config.generic_retry_count):
348                 attempt += 1
349                 try:
350                     ctx.run(arps)
351                 except STLError:
352                     LOG.error(traceback.format_exc())
353                     continue
354
355                 self.arps[port] = [arp.get_record().dst_mac for arp in arps
356                                    if arp.get_record().dst_mac is not None]
357
358                 if len(self.arps[port]) == self.config.service_chain_count:
359                     resolved += 1
360                     LOG.info('ARP resolved successfully for port %s', port)
361                     break
362                 else:
363                     failed = [arp.get_record().dst_ip for arp in arps
364                               if arp.get_record().dst_mac is None]
365                     LOG.info('Retrying ARP for: %s (%d / %d)',
366                              failed, attempt, self.config.generic_retry_count)
367                     time.sleep(self.config.generic_poll_sec)
368
369         self.client.set_service_mode(ports=self.port_handle, enabled=False)
370         return resolved == len(self.port_handle)
371
372     def config_interface(self):
373         pass
374
375     def __is_rate_enough(self, l2frame_size, rates, bidirectional, latency):
376         """Check if rate provided by user is above requirements. Applies only if latency is True."""
377         intf_speed = self.config.generator_config.intf_speed
378         if latency:
379             if bidirectional:
380                 mult = 2
381                 total_rate = 0
382                 for rate in rates:
383                     r = utils.convert_rates(l2frame_size, rate, intf_speed)
384                     total_rate += int(r['rate_pps'])
385             else:
386                 mult = 1
387                 total_rate = utils.convert_rates(l2frame_size, rates[0], intf_speed)
388             # rate must be enough for latency stream and at least 1 pps for base stream per chain
389             required_rate = (self.LATENCY_PPS + 1) * self.config.service_chain_count * mult
390             result = utils.convert_rates(l2frame_size,
391                                          {'rate_pps': required_rate},
392                                          intf_speed * mult)
393             result['result'] = total_rate >= required_rate
394             return result
395
396         return {'result': True}
397
398     def create_traffic(self, l2frame_size, rates, bidirectional, latency=True):
399         r = self.__is_rate_enough(l2frame_size, rates, bidirectional, latency)
400         if not r['result']:
401             raise TrafficGeneratorException(
402                 'Required rate in total is at least one of: \n{pps}pps \n{bps}bps \n{load}%.'
403                 .format(pps=r['rate_pps'],
404                         bps=r['rate_bps'],
405                         load=r['rate_percent']))
406
407         stream_cfgs = [d.get_stream_configs(self.config.generator_config.service_chain)
408                        for d in self.config.generator_config.devices]
409         self.rates = [utils.to_rate_str(rate) for rate in rates]
410
411         for ph in self.port_handle:
412             # generate one pg_id for each direction
413             self.stream_ids[ph] = self.id.next()
414
415         for i, (fwd_stream_cfg, rev_stream_cfg) in enumerate(zip(*stream_cfgs)):
416             if self.config.service_chain == ChainType.EXT and not self.config.no_arp:
417                 fwd_stream_cfg['mac_dst'] = self.arps[self.port_handle[0]][i]
418                 rev_stream_cfg['mac_dst'] = self.arps[self.port_handle[1]][i]
419
420             self.streamblock[0].extend(self.generate_streams(self.port_handle[0],
421                                                              fwd_stream_cfg,
422                                                              l2frame_size,
423                                                              latency=latency))
424             if len(self.rates) > 1:
425                 self.streamblock[1].extend(self.generate_streams(self.port_handle[1],
426                                                                  rev_stream_cfg,
427                                                                  l2frame_size,
428                                                                  isg=10.0,
429                                                                  latency=bidirectional and latency))
430
431         for ph in self.port_handle:
432             self.client.add_streams(self.streamblock[ph], ports=ph)
433             LOG.info('Created traffic stream for port %s.', ph)
434
435     def clear_streamblock(self):
436         self.streamblock = defaultdict(list)
437         self.latencies = defaultdict(list)
438         self.stream_ids = defaultdict(list)
439         self.rates = []
440         self.client.reset(self.port_handle)
441         LOG.info('Cleared all existing streams.')
442
443     def get_stats(self):
444         stats = self.client.get_pgid_stats()
445         return self.extract_stats(stats)
446
447     def get_macs(self):
448         return [self.client.get_port_attr(port=port)['src_mac'] for port in self.port_handle]
449
450     def clear_stats(self):
451         if self.port_handle:
452             self.client.clear_stats()
453
454     def start_traffic(self):
455         for port, rate in zip(self.port_handle, self.rates):
456             self.client.start(ports=port, mult=rate, duration=self.config.duration_sec, force=True)
457
458     def stop_traffic(self):
459         self.client.stop(ports=self.port_handle)
460
461     def start_capture(self):
462         if self.capture_id:
463             self.stop_capture()
464         self.client.set_service_mode(ports=self.port_handle)
465         self.capture_id = self.client.start_capture(rx_ports=self.port_handle)
466
467     def fetch_capture_packets(self):
468         if self.capture_id:
469             self.packet_list = []
470             self.client.fetch_capture_packets(capture_id=self.capture_id['id'],
471                                               output=self.packet_list)
472
473     def stop_capture(self):
474         if self.capture_id:
475             self.client.stop_capture(capture_id=self.capture_id['id'])
476             self.capture_id = None
477             self.client.set_service_mode(ports=self.port_handle, enabled=False)
478
479     def cleanup(self):
480         if self.client:
481             try:
482                 self.client.reset(self.port_handle)
483                 self.client.disconnect()
484             except STLError:
485                 # TRex does not like a reset while in disconnected state
486                 pass