4c9f492512ca51e841f455723263a0bc6a52a355
[nfvbench.git] / nfvbench / traffic_gen / trex.py
1 # Copyright 2016 Cisco Systems, Inc.  All rights reserved.
2 #
3 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
4 #    not use this file except in compliance with the License. You may obtain
5 #    a copy of the License at
6 #
7 #         http://www.apache.org/licenses/LICENSE-2.0
8 #
9 #    Unless required by applicable law or agreed to in writing, software
10 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 #    License for the specific language governing permissions and limitations
13 #    under the License.
14
15 import os
16 import random
17 import time
18 import traceback
19
20 from collections import defaultdict
21 from itertools import count
22 from nfvbench.log import LOG
23 from nfvbench.specs import ChainType
24 from nfvbench.traffic_server import TRexTrafficServer
25 from nfvbench.utils import cast_integer
26 from nfvbench.utils import timeout
27 from nfvbench.utils import TimeoutError
28 from traffic_base import AbstractTrafficGenerator
29 from traffic_base import TrafficGeneratorException
30 import traffic_utils as utils
31
32 # pylint: disable=import-error
33 from trex_stl_lib.api import CTRexVmInsFixHwCs
34 from trex_stl_lib.api import Dot1Q
35 from trex_stl_lib.api import Ether
36 from trex_stl_lib.api import IP
37 from trex_stl_lib.api import STLClient
38 from trex_stl_lib.api import STLError
39 from trex_stl_lib.api import STLFlowLatencyStats
40 from trex_stl_lib.api import STLFlowStats
41 from trex_stl_lib.api import STLPktBuilder
42 from trex_stl_lib.api import STLScVmRaw
43 from trex_stl_lib.api import STLStream
44 from trex_stl_lib.api import STLTXCont
45 from trex_stl_lib.api import STLVmFixChecksumHw
46 from trex_stl_lib.api import STLVmFlowVar
47 from trex_stl_lib.api import STLVmFlowVarRepetableRandom
48 from trex_stl_lib.api import STLVmWrFlowVar
49 from trex_stl_lib.api import UDP
50 from trex_stl_lib.services.trex_stl_service_arp import STLServiceARP
51
52
53 # pylint: enable=import-error
54
55
56 class TRex(AbstractTrafficGenerator):
57     LATENCY_PPS = 1000
58
59     def __init__(self, runner):
60         AbstractTrafficGenerator.__init__(self, runner)
61         self.client = None
62         self.id = count()
63         self.latencies = defaultdict(list)
64         self.stream_ids = defaultdict(list)
65         self.port_handle = []
66         self.streamblock = defaultdict(list)
67         self.rates = []
68         self.arps = {}
69         self.capture_id = None
70         self.packet_list = []
71
72     def get_version(self):
73         return self.client.get_server_version()
74
75     def extract_stats(self, in_stats):
76         utils.nan_replace(in_stats)
77         LOG.debug(in_stats)
78
79         result = {}
80         for ph in self.port_handle:
81             stats = self.__combine_stats(in_stats, ph)
82             result[ph] = {
83                 'tx': {
84                     'total_pkts': cast_integer(stats['tx_pkts']['total']),
85                     'total_pkt_bytes': cast_integer(stats['tx_bytes']['total']),
86                     'pkt_rate': cast_integer(stats['tx_pps']['total']),
87                     'pkt_bit_rate': cast_integer(stats['tx_bps']['total'])
88                 },
89                 'rx': {
90                     'total_pkts': cast_integer(stats['rx_pkts']['total']),
91                     'total_pkt_bytes': cast_integer(stats['rx_bytes']['total']),
92                     'pkt_rate': cast_integer(stats['rx_pps']['total']),
93                     'pkt_bit_rate': cast_integer(stats['rx_bps']['total']),
94                     'dropped_pkts': cast_integer(
95                         stats['tx_pkts']['total'] - stats['rx_pkts']['total'])
96                 }
97             }
98
99             lat = self.__combine_latencies(in_stats, ph)
100             result[ph]['rx']['max_delay_usec'] = cast_integer(
101                 lat['total_max']) if 'total_max' in lat else float('nan')
102             result[ph]['rx']['min_delay_usec'] = cast_integer(
103                 lat['total_min']) if 'total_min' in lat else float('nan')
104             result[ph]['rx']['avg_delay_usec'] = cast_integer(
105                 lat['average']) if 'average' in lat else float('nan')
106         total_tx_pkts = result[0]['tx']['total_pkts'] + result[1]['tx']['total_pkts']
107         result["total_tx_rate"] = cast_integer(total_tx_pkts / self.config.duration_sec)
108         return result
109
110     def __combine_stats(self, in_stats, port_handle):
111         """Traverses TRex result dictionary and combines stream stats. Used for combining latency
112         and regular streams together.
113         """
114         result = defaultdict(lambda: defaultdict(float))
115
116         for pg_id in [self.stream_ids[port_handle]] + self.latencies[port_handle]:
117             record = in_stats['flow_stats'][pg_id]
118             for stat_type, stat_type_values in record.iteritems():
119                 for ph, value in stat_type_values.iteritems():
120                     result[stat_type][ph] += value
121
122         return result
123
124     def __combine_latencies(self, in_stats, port_handle):
125         """Traverses TRex result dictionary and combines chosen latency stats."""
126         if not self.latencies[port_handle]:
127             return {}
128
129         result = defaultdict(float)
130         result['total_min'] = float("inf")
131         for lat_id in self.latencies[port_handle]:
132             lat = in_stats['latency'][lat_id]
133             result['dropped_pkts'] += lat['err_cntrs']['dropped']
134             result['total_max'] = max(lat['latency']['total_max'], result['total_max'])
135             result['total_min'] = min(lat['latency']['total_min'], result['total_min'])
136             result['average'] += lat['latency']['average']
137
138         result['average'] /= len(self.latencies[port_handle])
139
140         return result
141
142     def create_pkt(self, stream_cfg, l2frame_size):
143
144         pkt_base = Ether(src=stream_cfg['mac_src'], dst=stream_cfg['mac_dst'])
145         # TRex requires minimum payload size 16B
146         if stream_cfg['vlan_tag'] is not None:
147             # 50 = 14 (Ethernet II) + 4 (Vlan tag) + 4 (CRC Checksum) + 20 (IPv4) + 8 (UDP)
148             pkt_base /= Dot1Q(vlan=stream_cfg['vlan_tag'])
149             l2payload_size = max(max(64, int(l2frame_size)) - 50, 16)
150         else:
151             # 46 = 14 (Ethernet II) + 4 (CRC Checksum) + 20 (IPv4) + 8 (UDP)
152             l2payload_size = max(max(64, int(l2frame_size)) - 46, 16)
153         payload = 'x' * l2payload_size
154         udp_args = {}
155         if stream_cfg['udp_src_port']:
156             udp_args['sport'] = int(stream_cfg['udp_src_port'])
157         if stream_cfg['udp_dst_port']:
158             udp_args['dport'] = int(stream_cfg['udp_dst_port'])
159         pkt_base /= IP() / UDP(**udp_args)
160
161         if stream_cfg['ip_addrs_step'] == 'random':
162             src_fv = STLVmFlowVarRepetableRandom(
163                 name="ip_src",
164                 min_value=stream_cfg['ip_src_addr'],
165                 max_value=stream_cfg['ip_src_addr_max'],
166                 size=4,
167                 seed=random.randint(0, 32767),
168                 limit=stream_cfg['ip_src_count'])
169             dst_fv = STLVmFlowVarRepetableRandom(
170                 name="ip_dst",
171                 min_value=stream_cfg['ip_dst_addr'],
172                 max_value=stream_cfg['ip_dst_addr_max'],
173                 size=4,
174                 seed=random.randint(0, 32767),
175                 limit=stream_cfg['ip_dst_count'])
176         else:
177             src_fv = STLVmFlowVar(
178                 name="ip_src",
179                 min_value=stream_cfg['ip_src_addr'],
180                 max_value=stream_cfg['ip_src_addr'],
181                 size=4,
182                 op="inc",
183                 step=stream_cfg['ip_addrs_step'])
184             dst_fv = STLVmFlowVar(
185                 name="ip_dst",
186                 min_value=stream_cfg['ip_dst_addr'],
187                 max_value=stream_cfg['ip_dst_addr_max'],
188                 size=4,
189                 op="inc",
190                 step=stream_cfg['ip_addrs_step'])
191
192         vm_param = [
193             src_fv,
194             STLVmWrFlowVar(fv_name="ip_src", pkt_offset="IP.src"),
195             dst_fv,
196             STLVmWrFlowVar(fv_name="ip_dst", pkt_offset="IP.dst"),
197             STLVmFixChecksumHw(l3_offset="IP",
198                                l4_offset="UDP",
199                                l4_type=CTRexVmInsFixHwCs.L4_TYPE_UDP)
200         ]
201
202         return STLPktBuilder(pkt=pkt_base / payload, vm=STLScVmRaw(vm_param))
203
204     def generate_streams(self, port_handle, stream_cfg, l2frame, isg=0.0, latency=True):
205         idx_lat = None
206         streams = []
207         if l2frame == 'IMIX':
208             for t, (ratio, l2_frame_size) in enumerate(zip(self.imix_ratios, self.imix_l2_sizes)):
209                 pkt = self.create_pkt(stream_cfg, l2_frame_size)
210                 streams.append(STLStream(packet=pkt,
211                                          isg=0.1 * t,
212                                          flow_stats=STLFlowStats(
213                                              pg_id=self.stream_ids[port_handle]),
214                                          mode=STLTXCont(pps=ratio)))
215
216             if latency:
217                 idx_lat = self.id.next()
218                 sl = STLStream(packet=pkt,
219                                isg=isg,
220                                flow_stats=STLFlowLatencyStats(pg_id=idx_lat),
221                                mode=STLTXCont(pps=self.LATENCY_PPS))
222                 streams.append(sl)
223         else:
224             pkt = self.create_pkt(stream_cfg, l2frame)
225             streams.append(STLStream(packet=pkt,
226                                      flow_stats=STLFlowStats(pg_id=self.stream_ids[port_handle]),
227                                      mode=STLTXCont()))
228
229             if latency:
230                 idx_lat = self.id.next()
231                 streams.append(STLStream(packet=pkt,
232                                          flow_stats=STLFlowLatencyStats(pg_id=idx_lat),
233                                          mode=STLTXCont(pps=self.LATENCY_PPS)))
234
235         if latency:
236             self.latencies[port_handle].append(idx_lat)
237
238         return streams
239
240     def init(self):
241         pass
242
243     @timeout(5)
244     def __connect(self, client):
245         client.connect()
246
247     def __connect_after_start(self):
248         # after start, Trex may take a bit of time to initialize
249         # so we need to retry a few times
250         for it in xrange(self.config.generic_retry_count):
251             try:
252                 time.sleep(1)
253                 self.client.connect()
254                 break
255             except Exception as ex:
256                 if it == (self.config.generic_retry_count - 1):
257                     raise ex
258                 LOG.info("Retrying connection to TRex (%s)...", ex.message)
259
260     def connect(self):
261         LOG.info("Connecting to TRex...")
262         server_ip = self.config.generator_config.ip
263
264         # Connect to TRex server
265         self.client = STLClient(server=server_ip)
266         try:
267             self.__connect(self.client)
268         except (TimeoutError, STLError) as e:
269             if server_ip == '127.0.0.1':
270                 try:
271                     self.__start_server()
272                     self.__connect_after_start()
273                 except (TimeoutError, STLError) as e:
274                     LOG.error('Cannot connect to TRex')
275                     LOG.error(traceback.format_exc())
276                     logpath = '/tmp/trex.log'
277                     if os.path.isfile(logpath):
278                         # Wait for TRex to finish writing error message
279                         last_size = 0
280                         for _ in xrange(self.config.generic_retry_count):
281                             size = os.path.getsize(logpath)
282                             if size == last_size:
283                                 # probably not writing anymore
284                                 break
285                             last_size = size
286                             time.sleep(1)
287                         with open(logpath, 'r') as f:
288                             message = f.read()
289                     else:
290                         message = e.message
291                     raise TrafficGeneratorException(message)
292             else:
293                 raise TrafficGeneratorException(e.message)
294
295         ports = list(self.config.generator_config.ports)
296         self.port_handle = ports
297         # Prepare the ports
298         self.client.reset(ports)
299
300     def set_mode(self):
301         if self.config.service_chain == ChainType.EXT and not self.config.no_arp:
302             self.__set_l3_mode()
303         else:
304             self.__set_l2_mode()
305
306     def __set_l3_mode(self):
307         self.client.set_service_mode(ports=self.port_handle, enabled=True)
308         for port, device in zip(self.port_handle, self.config.generator_config.devices):
309             try:
310                 self.client.set_l3_mode(port=port,
311                                         src_ipv4=device.tg_gateway_ip,
312                                         dst_ipv4=device.dst.gateway_ip,
313                                         vlan=device.vlan_tag if device.vlan_tagging else None)
314             except STLError:
315                 # TRex tries to resolve ARP already, doesn't have to be successful yet
316                 continue
317         self.client.set_service_mode(ports=self.port_handle, enabled=False)
318
319     def __set_l2_mode(self):
320         self.client.set_service_mode(ports=self.port_handle, enabled=True)
321         for port, device in zip(self.port_handle, self.config.generator_config.devices):
322             for cfg in device.get_stream_configs(self.config.generator_config.service_chain):
323                 self.client.set_l2_mode(port=port, dst_mac=cfg['mac_dst'])
324         self.client.set_service_mode(ports=self.port_handle, enabled=False)
325
326     def __start_server(self):
327         server = TRexTrafficServer()
328         server.run_server(self.config.generator_config, self.config.vlan_tagging)
329
330     def resolve_arp(self):
331         self.client.set_service_mode(ports=self.port_handle)
332         LOG.info('Polling ARP until successful')
333         resolved = 0
334         attempt = 0
335         for port, device in zip(self.port_handle, self.config.generator_config.devices):
336             ctx = self.client.create_service_ctx(port=port)
337
338             arps = [
339                 STLServiceARP(ctx,
340                               src_ip=cfg['ip_src_tg_gw'],
341                               dst_ip=cfg['mac_discovery_gw'],
342                               vlan=device.vlan_tag if device.vlan_tagging else None)
343                 for cfg in device.get_stream_configs(self.config.generator_config.service_chain)
344             ]
345
346             for _ in xrange(self.config.generic_retry_count):
347                 attempt += 1
348                 try:
349                     ctx.run(arps)
350                 except STLError:
351                     LOG.error(traceback.format_exc())
352                     continue
353
354                 self.arps[port] = [arp.get_record().dst_mac for arp in arps
355                                    if arp.get_record().dst_mac is not None]
356
357                 if len(self.arps[port]) == self.config.service_chain_count:
358                     resolved += 1
359                     LOG.info('ARP resolved successfully for port %s', port)
360                     break
361                 else:
362                     failed = [arp.get_record().dst_ip for arp in arps
363                               if arp.get_record().dst_mac is None]
364                     LOG.info('Retrying ARP for: %s (%d / %d)',
365                              failed, attempt, self.config.generic_retry_count)
366                     time.sleep(self.config.generic_poll_sec)
367
368         self.client.set_service_mode(ports=self.port_handle, enabled=False)
369         return resolved == len(self.port_handle)
370
371     def config_interface(self):
372         pass
373
374     def __is_rate_enough(self, l2frame_size, rates, bidirectional, latency):
375         """Check if rate provided by user is above requirements. Applies only if latency is True."""
376         intf_speed = self.config.generator_config.intf_speed
377         if latency:
378             if bidirectional:
379                 mult = 2
380                 total_rate = 0
381                 for rate in rates:
382                     r = utils.convert_rates(l2frame_size, rate, intf_speed)
383                     total_rate += int(r['rate_pps'])
384             else:
385                 mult = 1
386                 total_rate = utils.convert_rates(l2frame_size, rates[0], intf_speed)
387             # rate must be enough for latency stream and at least 1 pps for base stream per chain
388             required_rate = (self.LATENCY_PPS + 1) * self.config.service_chain_count * mult
389             result = utils.convert_rates(l2frame_size,
390                                          {'rate_pps': required_rate},
391                                          intf_speed * mult)
392             result['result'] = total_rate >= required_rate
393             return result
394
395         return {'result': True}
396
397     def create_traffic(self, l2frame_size, rates, bidirectional, latency=True):
398         r = self.__is_rate_enough(l2frame_size, rates, bidirectional, latency)
399         if not r['result']:
400             raise TrafficGeneratorException(
401                 'Required rate in total is at least one of: \n{pps}pps \n{bps}bps \n{load}%.'
402                 .format(pps=r['rate_pps'],
403                         bps=r['rate_bps'],
404                         load=r['rate_percent']))
405
406         stream_cfgs = [d.get_stream_configs(self.config.generator_config.service_chain)
407                        for d in self.config.generator_config.devices]
408         self.rates = [utils.to_rate_str(rate) for rate in rates]
409
410         for ph in self.port_handle:
411             # generate one pg_id for each direction
412             self.stream_ids[ph] = self.id.next()
413
414         for i, (fwd_stream_cfg, rev_stream_cfg) in enumerate(zip(*stream_cfgs)):
415             if self.config.service_chain == ChainType.EXT and not self.config.no_arp:
416                 fwd_stream_cfg['mac_dst'] = self.arps[self.port_handle[0]][i]
417                 rev_stream_cfg['mac_dst'] = self.arps[self.port_handle[1]][i]
418
419             self.streamblock[0].extend(self.generate_streams(self.port_handle[0],
420                                                              fwd_stream_cfg,
421                                                              l2frame_size,
422                                                              latency=latency))
423             if len(self.rates) > 1:
424                 self.streamblock[1].extend(self.generate_streams(self.port_handle[1],
425                                                                  rev_stream_cfg,
426                                                                  l2frame_size,
427                                                                  isg=10.0,
428                                                                  latency=bidirectional and latency))
429
430         for ph in self.port_handle:
431             self.client.add_streams(self.streamblock[ph], ports=ph)
432             LOG.info('Created traffic stream for port %s.', ph)
433
434     def clear_streamblock(self):
435         self.streamblock = defaultdict(list)
436         self.latencies = defaultdict(list)
437         self.stream_ids = defaultdict(list)
438         self.rates = []
439         self.client.reset(self.port_handle)
440         LOG.info('Cleared all existing streams.')
441
442     def get_stats(self):
443         stats = self.client.get_pgid_stats()
444         return self.extract_stats(stats)
445
446     def get_macs(self):
447         return [self.client.get_port_attr(port=port)['src_mac'] for port in self.port_handle]
448
449     def clear_stats(self):
450         if self.port_handle:
451             self.client.clear_stats()
452
453     def start_traffic(self):
454         for port, rate in zip(self.port_handle, self.rates):
455             self.client.start(ports=port, mult=rate, duration=self.config.duration_sec, force=True)
456
457     def stop_traffic(self):
458         self.client.stop(ports=self.port_handle)
459
460     def start_capture(self):
461         if self.capture_id:
462             self.stop_capture()
463         self.client.set_service_mode(ports=self.port_handle)
464         self.capture_id = self.client.start_capture(rx_ports=self.port_handle)
465
466     def fetch_capture_packets(self):
467         if self.capture_id:
468             self.packet_list = []
469             self.client.fetch_capture_packets(capture_id=self.capture_id['id'],
470                                               output=self.packet_list)
471
472     def stop_capture(self):
473         if self.capture_id:
474             self.client.stop_capture(capture_id=self.capture_id['id'])
475             self.capture_id = None
476             self.client.set_service_mode(ports=self.port_handle, enabled=False)
477
478     def cleanup(self):
479         if self.client:
480             try:
481                 self.client.reset(self.port_handle)
482                 self.client.disconnect()
483             except STLError:
484                 # TRex does not like a reset while in disconnected state
485                 pass