Support customizing sport and dport for UDP traffic
[nfvbench.git] / nfvbench / traffic_gen / trex.py
1 # Copyright 2016 Cisco Systems, Inc.  All rights reserved.
2 #
3 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
4 #    not use this file except in compliance with the License. You may obtain
5 #    a copy of the License at
6 #
7 #         http://www.apache.org/licenses/LICENSE-2.0
8 #
9 #    Unless required by applicable law or agreed to in writing, software
10 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 #    License for the specific language governing permissions and limitations
13 #    under the License.
14
15 from collections import defaultdict
16 from itertools import count
17 from nfvbench.log import LOG
18 from nfvbench.specs import ChainType
19 from nfvbench.traffic_server import TRexTrafficServer
20 from nfvbench.utils import timeout
21 from nfvbench.utils import TimeoutError
22 import os
23 import random
24 import time
25 import traceback
26 from traffic_base import AbstractTrafficGenerator
27 from traffic_base import TrafficGeneratorException
28 import traffic_utils as utils
29
30
31 from trex_stl_lib.api import CTRexVmInsFixHwCs
32 from trex_stl_lib.api import Dot1Q
33 from trex_stl_lib.api import Ether
34 from trex_stl_lib.api import IP
35 from trex_stl_lib.api import STLClient
36 from trex_stl_lib.api import STLError
37 from trex_stl_lib.api import STLFlowLatencyStats
38 from trex_stl_lib.api import STLFlowStats
39 from trex_stl_lib.api import STLPktBuilder
40 from trex_stl_lib.api import STLScVmRaw
41 from trex_stl_lib.api import STLStream
42 from trex_stl_lib.api import STLTXCont
43 from trex_stl_lib.api import STLVmFixChecksumHw
44 from trex_stl_lib.api import STLVmFlowVar
45 from trex_stl_lib.api import STLVmFlowVarRepetableRandom
46 from trex_stl_lib.api import STLVmWrFlowVar
47 from trex_stl_lib.api import UDP
48 from trex_stl_lib.services.trex_stl_service_arp import STLServiceARP
49
50
class TRex(AbstractTrafficGenerator):
    """Traffic generator driver backed by a TRex stateless (STL) server.

    Builds UDP streams per chain (plus optional latency streams), pushes them
    to the TRex ports, starts/stops traffic and converts the TRex pg_id
    statistics into per-port tx/rx results.
    """

    # Packet rate (pps) of every latency-measurement stream
    LATENCY_PPS = 1000

    def __init__(self, runner):
        AbstractTrafficGenerator.__init__(self, runner)
        self.client = None
        # source of unique packet-group ids (pg_id) for TRex flow statistics
        self.id = count()
        # port handle -> list of pg_ids of the latency streams on that port
        self.latencies = defaultdict(list)
        # port handle -> pg_id shared by the regular streams on that port
        self.stream_ids = defaultdict(list)
        self.port_handle = []
        # port handle -> list of STLStream objects to add to that port
        self.streamblock = defaultdict(list)
        # per-direction rate strings passed as ``mult`` to client.start()
        self.rates = []
        # port handle -> list of destination MACs resolved by ARP
        self.arps = {}

    def get_version(self):
        """Return the version information reported by the TRex server."""
        return self.client.get_server_version()

    def extract_stats(self, in_stats):
        """Convert raw TRex pg_id statistics into per-port tx/rx results.

        :param in_stats: statistics dict as returned by ``client.get_pgid_stats()``
        :return: dict keyed by port handle, each entry holding 'tx' and 'rx'
                 counters (rx also carries min/max/avg latency in usec, NaN when
                 the port has no latency stream), plus 'total_tx_rate' = packets
                 sent on all ports divided by ``config.duration_sec``
        """
        utils.nan_replace(in_stats)
        LOG.debug(in_stats)

        result = {}
        for ph in self.port_handle:
            stats = self.__combine_stats(in_stats, ph)
            result[ph] = {
                'tx': {
                    'total_pkts': stats['tx_pkts']['total'],
                    'total_pkt_bytes': stats['tx_bytes']['total'],
                    'pkt_rate': stats['tx_pps']['total'],
                    'pkt_bit_rate': stats['tx_bps']['total']
                },
                'rx': {
                    'total_pkts': stats['rx_pkts']['total'],
                    'total_pkt_bytes': stats['rx_bytes']['total'],
                    'pkt_rate': stats['rx_pps']['total'],
                    'pkt_bit_rate': stats['rx_bps']['total'],
                    'dropped_pkts': stats['tx_pkts']['total'] - stats['rx_pkts']['total']
                }
            }

            lat = self.__combine_latencies(in_stats, ph)
            result[ph]['rx']['max_delay_usec'] = lat.get('total_max', float('nan'))
            result[ph]['rx']['min_delay_usec'] = lat.get('total_min', float('nan'))
            result[ph]['rx']['avg_delay_usec'] = lat.get('average', float('nan'))

        # Results are keyed by port handle, which is not necessarily 0 and 1:
        # sum over the real handles instead of hard-coded keys.
        total_tx_pkts = sum(result[ph]['tx']['total_pkts'] for ph in self.port_handle)
        result["total_tx_rate"] = total_tx_pkts / self.config.duration_sec
        return result

    def __combine_stats(self, in_stats, port_handle):
        """Traverses TRex result dictionary and combines stream stats. Used for combining latency
        and regular streams together.

        Returns a nested dict ``result[stat_type][key]`` (e.g.
        ``result['tx_pkts']['total']``) where missing entries read as 0.0.
        """
        result = defaultdict(lambda: defaultdict(float))

        pg_ids = [self.stream_ids[port_handle]] + self.latencies[port_handle]
        for pg_id in pg_ids:
            record = in_stats['flow_stats'][pg_id]
            for stat_type, stat_type_values in record.iteritems():
                for key, value in stat_type_values.iteritems():
                    result[stat_type][key] += value

        return result

    def __combine_latencies(self, in_stats, port_handle):
        """Traverses TRex result dictionary and combines chosen latency stats.

        :return: {} when the port has no latency stream; otherwise a dict with
                 'total_max' (max over streams), 'total_min' (min over streams),
                 'average' (unweighted mean of the per-stream averages) and
                 'dropped_pkts' (sum over streams)
        """
        lat_ids = self.latencies[port_handle]
        if not lat_ids:
            return {}

        result = defaultdict(float)
        result['total_min'] = float("inf")
        for lat_id in lat_ids:
            lat = in_stats['latency'][lat_id]
            result['dropped_pkts'] += lat['err_cntrs']['dropped']
            result['total_max'] = max(lat['latency']['total_max'], result['total_max'])
            result['total_min'] = min(lat['latency']['total_min'], result['total_min'])
            result['average'] += lat['latency']['average']

        result['average'] /= len(lat_ids)

        return result

    def create_pkt(self, stream_cfg, l2frame_size):
        """Build the TRex packet template of one stream.

        :param stream_cfg: stream configuration dict (MACs, optional VLAN tag,
                           optional UDP ports, IP address ranges and step)
        :param l2frame_size: requested L2 frame size in bytes, CRC included;
                             values below 64 are raised to 64
        :return: STLPktBuilder whose field-engine program rewrites the IP
                 source/destination addresses and fixes checksums in hardware
        """
        # 46 = 14 (Ethernet II) + 4 (CRC Checksum) + 20 (IPv4) + 8 (UDP)
        headers_size = 46

        pkt_base = Ether(src=stream_cfg['mac_src'], dst=stream_cfg['mac_dst'])

        if stream_cfg['vlan_tag'] is not None:
            pkt_base /= Dot1Q(vlan=stream_cfg['vlan_tag'])
            # The 4-byte 802.1Q tag is part of the L2 frame: shrink the payload
            # so that the frame keeps the requested size.
            headers_size += 4

        udp_args = {}
        if stream_cfg['udp_src_port']:
            udp_args['sport'] = int(stream_cfg['udp_src_port'])
        if stream_cfg['udp_dst_port']:
            udp_args['dport'] = int(stream_cfg['udp_dst_port'])
        pkt_base /= IP() / UDP(**udp_args)

        payload = 'x' * (max(64, int(l2frame_size)) - headers_size)

        if stream_cfg['ip_addrs_step'] == 'random':
            src_fv = STLVmFlowVarRepetableRandom(
                name="ip_src",
                min_value=stream_cfg['ip_src_addr'],
                max_value=stream_cfg['ip_src_addr_max'],
                size=4,
                seed=random.randint(0, 32767),
                limit=stream_cfg['ip_src_count'])
            dst_fv = STLVmFlowVarRepetableRandom(
                name="ip_dst",
                min_value=stream_cfg['ip_dst_addr'],
                max_value=stream_cfg['ip_dst_addr_max'],
                size=4,
                seed=random.randint(0, 32767),
                limit=stream_cfg['ip_dst_count'])
        else:
            # NOTE(review): in this branch the source IP is pinned to
            # ip_src_addr (min == max), unlike the random branch which walks
            # up to ip_src_addr_max -- confirm this asymmetry is intended.
            src_fv = STLVmFlowVar(
                name="ip_src",
                min_value=stream_cfg['ip_src_addr'],
                max_value=stream_cfg['ip_src_addr'],
                size=4,
                op="inc",
                step=stream_cfg['ip_addrs_step'])
            dst_fv = STLVmFlowVar(
                name="ip_dst",
                min_value=stream_cfg['ip_dst_addr'],
                max_value=stream_cfg['ip_dst_addr_max'],
                size=4,
                op="inc",
                step=stream_cfg['ip_addrs_step'])

        vm_param = [
            src_fv,
            STLVmWrFlowVar(fv_name="ip_src", pkt_offset="IP.src"),
            dst_fv,
            STLVmWrFlowVar(fv_name="ip_dst", pkt_offset="IP.dst"),
            STLVmFixChecksumHw(l3_offset="IP",
                               l4_offset="UDP",
                               l4_type=CTRexVmInsFixHwCs.L4_TYPE_UDP)
        ]

        return STLPktBuilder(pkt=pkt_base / payload, vm=STLScVmRaw(vm_param))

    def generate_streams(self, port_handle, stream_cfg, l2frame, isg=0.0, latency=True):
        """Build the streams of one chain for one port.

        :param port_handle: TRex port the streams are meant for (selects the pg_id)
        :param stream_cfg: stream configuration passed to create_pkt()
        :param l2frame: L2 frame size in bytes, or 'IMIX' for one stream per
                        IMIX frame size
        :param isg: ``isg`` value given to the IMIX latency stream (unused otherwise)
        :param latency: also create a latency stream; its new pg_id is appended
                        to ``self.latencies[port_handle]``
        :return: list of STLStream objects
        """
        idx_lat = None
        streams = []
        if l2frame == 'IMIX':
            for t, (ratio, l2_frame_size) in enumerate(zip(self.imix_ratios, self.imix_l2_sizes)):
                pkt = self.create_pkt(stream_cfg, l2_frame_size)
                streams.append(STLStream(packet=pkt,
                                         isg=0.1 * t,
                                         flow_stats=STLFlowStats(
                                             pg_id=self.stream_ids[port_handle]),
                                         mode=STLTXCont(pps=ratio)))

            if latency:
                idx_lat = next(self.id)
                # the latency stream reuses the packet of the last IMIX size
                sl = STLStream(packet=pkt,
                               isg=isg,
                               flow_stats=STLFlowLatencyStats(pg_id=idx_lat),
                               mode=STLTXCont(pps=self.LATENCY_PPS))
                streams.append(sl)
        else:
            pkt = self.create_pkt(stream_cfg, l2frame)
            streams.append(STLStream(packet=pkt,
                                     flow_stats=STLFlowStats(pg_id=self.stream_ids[port_handle]),
                                     mode=STLTXCont()))

            if latency:
                idx_lat = next(self.id)
                streams.append(STLStream(packet=pkt,
                                         flow_stats=STLFlowLatencyStats(pg_id=idx_lat),
                                         mode=STLTXCont(pps=self.LATENCY_PPS)))

        if latency:
            self.latencies[port_handle].append(idx_lat)

        return streams

    def init(self):
        """No initialization needed for TRex."""
        pass

    @timeout(5)
    def __connect(self, client):
        """Connect ``client`` to its server, bounded by the @timeout(5) decorator."""
        client.connect()

    def __connect_after_start(self):
        """Connect to a freshly started TRex server, retrying once per second.

        Re-raises the last connection exception after
        ``config.generic_retry_count`` failed attempts.
        """
        # after start, Trex may take a bit of time to initialize
        # so we need to retry a few times
        for it in xrange(self.config.generic_retry_count):
            try:
                time.sleep(1)
                self.client.connect()
                break
            except Exception as ex:
                if it == (self.config.generic_retry_count - 1):
                    # bare raise keeps the original traceback (Python 2)
                    raise
                LOG.info("Retrying connection to TRex (%s)...", ex.message)

    def connect(self):
        """Connect to the TRex server and reset the configured ports.

        When the server IP is 127.0.0.1 and the first attempt fails, a local
        TRex server is started and the connection retried. If that also fails,
        the content of /tmp/trex.log (when it exists) becomes the error message.

        :raises TrafficGeneratorException: when no connection can be established
        """
        LOG.info("Connecting to TRex...")
        server_ip = self.config.generator_config.ip

        # Connect to TRex server
        self.client = STLClient(server=server_ip)
        try:
            self.__connect(self.client)
        except (TimeoutError, STLError) as e:
            if server_ip == '127.0.0.1':
                try:
                    self.__start_server()
                    self.__connect_after_start()
                except (TimeoutError, STLError) as e:
                    LOG.error('Cannot connect to TRex')
                    LOG.error(traceback.format_exc())
                    logpath = '/tmp/trex.log'
                    if os.path.isfile(logpath):
                        # Wait for TRex to finish writing error message
                        last_size = 0
                        for _ in xrange(self.config.generic_retry_count):
                            size = os.path.getsize(logpath)
                            if size == last_size:
                                # probably not writing anymore
                                break
                            last_size = size
                            time.sleep(1)
                        with open(logpath, 'r') as f:
                            message = f.read()
                    else:
                        message = e.message
                    raise TrafficGeneratorException(message)
            else:
                raise TrafficGeneratorException(e.message)

        ports = list(self.config.generator_config.ports)
        self.port_handle = ports
        # Prepare the ports
        self.client.reset(ports)

    def set_mode(self):
        """Use L3 mode for EXT chains with ARP enabled, L2 mode otherwise."""
        if self.config.service_chain == ChainType.EXT and not self.config.no_arp:
            self.__set_l3_mode()
        else:
            self.__set_l2_mode()

    def __set_l3_mode(self):
        """Configure every port in L3 mode (TG gateway IP -> destination gateway IP).

        Service mode is enabled for the duration of the call and always
        disabled again on exit, even if an unexpected error is raised.
        """
        self.client.set_service_mode(ports=self.port_handle, enabled=True)
        try:
            for port, device in zip(self.port_handle, self.config.generator_config.devices):
                try:
                    self.client.set_l3_mode(port=port,
                                            src_ipv4=device.tg_gateway_ip,
                                            dst_ipv4=device.dst.gateway_ip,
                                            vlan=device.vlan_tag if device.vlan_tagging else None)
                except STLError:
                    # TRex tries to resolve ARP already, doesn't have to be successful yet
                    continue
        finally:
            self.client.set_service_mode(ports=self.port_handle, enabled=False)

    def __set_l2_mode(self):
        """Configure every port in L2 mode towards its stream destination MACs.

        set_l2_mode() is called once per stream config of the port.
        NOTE(review): with several chains presumably only the last dst MAC
        stays effective -- confirm against TRex semantics.
        Service mode is always disabled again on exit.
        """
        self.client.set_service_mode(ports=self.port_handle, enabled=True)
        try:
            service_chain = self.config.generator_config.service_chain
            for port, device in zip(self.port_handle, self.config.generator_config.devices):
                for cfg in device.get_stream_configs(service_chain):
                    self.client.set_l2_mode(port=port, dst_mac=cfg['mac_dst'])
        finally:
            self.client.set_service_mode(ports=self.port_handle, enabled=False)

    def __start_server(self):
        """Start a TRex server with the current generator configuration."""
        server = TRexTrafficServer()
        server.run_server(self.config.generator_config)

    def resolve_arp(self):
        """Resolve the destination MAC of every chain on every port using ARP.

        Each port is retried up to ``config.generic_retry_count`` times, waiting
        ``config.generic_poll_sec`` after each incomplete resolution. Resolved
        MACs are stored in ``self.arps[port]``. Service mode is always disabled
        again on exit.

        :return: True when all chains of all ports were resolved
        """
        self.client.set_service_mode(ports=self.port_handle)
        LOG.info('Polling ARP until successful')
        resolved = 0
        service_chain = self.config.generator_config.service_chain
        retry_count = self.config.generic_retry_count
        try:
            for port, device in zip(self.port_handle, self.config.generator_config.devices):
                ctx = self.client.create_service_ctx(port=port)

                arps = [
                    STLServiceARP(ctx,
                                  src_ip=cfg['ip_src_tg_gw'],
                                  dst_ip=cfg['mac_discovery_gw'],
                                  vlan=device.vlan_tag if device.vlan_tagging else None)
                    for cfg in device.get_stream_configs(service_chain)
                ]

                # attempt counts per port so that the progress log stays within
                # 1..retry_count for every port
                for attempt in xrange(1, retry_count + 1):
                    try:
                        ctx.run(arps)
                    except STLError:
                        LOG.error(traceback.format_exc())
                        continue

                    records = [arp.get_record() for arp in arps]
                    self.arps[port] = [rec.dst_mac for rec in records
                                       if rec.dst_mac is not None]

                    if len(self.arps[port]) == self.config.service_chain_count:
                        resolved += 1
                        LOG.info('ARP resolved successfully for port {}'.format(port))
                        break

                    failed = [rec.dst_ip for rec in records if rec.dst_mac is None]
                    LOG.info('Retrying ARP for: {} ({} / {})'.format(
                        failed, attempt, retry_count))
                    time.sleep(self.config.generic_poll_sec)
        finally:
            self.client.set_service_mode(ports=self.port_handle, enabled=False)
        return resolved == len(self.port_handle)

    def config_interface(self):
        """No interface configuration needed for TRex."""
        pass

    def __is_rate_enough(self, l2frame_size, rates, bidirectional, latency):
        """Check if rate provided by user is above requirements. Applies only if latency is True.

        The requested packet rate must carry, per chain and per direction, one
        latency stream (LATENCY_PPS) plus at least 1 pps of regular traffic.

        :return: {'result': True} when latency is off; otherwise the required
                 rate as converted by utils.convert_rates() (rate_pps, rate_bps,
                 rate_percent, ...) plus 'result' (True if the rate is enough)
        """
        if not latency:
            return {'result': True}

        intf_speed = self.config.generator_config.intf_speed
        if bidirectional:
            mult = 2
            checked_rates = rates
        else:
            mult = 1
            checked_rates = rates[:1]
        # convert_rates() returns a dict: compare the packet rates, not the dicts
        # (a dict >= int comparison is always True in Python 2)
        total_rate = sum(int(utils.convert_rates(l2frame_size, rate, intf_speed)['rate_pps'])
                         for rate in checked_rates)
        # rate must be enough for latency stream and at least 1 pps for base stream per chain
        required_rate = (self.LATENCY_PPS + 1) * self.config.service_chain_count * mult
        result = utils.convert_rates(l2frame_size,
                                     {'rate_pps': required_rate},
                                     intf_speed * mult)
        result['result'] = total_rate >= required_rate
        return result

    def create_traffic(self, l2frame_size, rates, bidirectional, latency=True):
        """Create the streams of every chain and add them to the TRex ports.

        :param l2frame_size: L2 frame size in bytes, or 'IMIX'
        :param rates: list of rates, one per direction; reverse streams are
                      only created when more than one rate is given
        :param bidirectional: used by the minimum-rate check and to enable
                              latency streams on the reverse port
        :param latency: add one latency stream per chain
        :raises TrafficGeneratorException: when the requested rate is too low
                                           for the latency streams
        """
        r = self.__is_rate_enough(l2frame_size, rates, bidirectional, latency)
        if not r['result']:
            raise TrafficGeneratorException(
                'Required rate in total is at least one of: \n{pps}pps \n{bps}bps \n{load}%.'
                .format(pps=r['rate_pps'],
                        bps=r['rate_bps'],
                        load=r['rate_percent']))

        stream_cfgs = [d.get_stream_configs(self.config.generator_config.service_chain)
                       for d in self.config.generator_config.devices]
        self.rates = [utils.to_rate_str(rate) for rate in rates]

        for ph in self.port_handle:
            # generate one pg_id for each direction
            self.stream_ids[ph] = next(self.id)

        for i, (fwd_stream_cfg, rev_stream_cfg) in enumerate(zip(*stream_cfgs)):
            fwd_port = self.port_handle[0]
            if self.config.service_chain == ChainType.EXT and not self.config.no_arp:
                fwd_stream_cfg['mac_dst'] = self.arps[fwd_port][i]
                rev_stream_cfg['mac_dst'] = self.arps[self.port_handle[1]][i]

            # streamblock is keyed by port handle, matching add_streams() below
            self.streamblock[fwd_port].extend(self.generate_streams(fwd_port,
                                                                    fwd_stream_cfg,
                                                                    l2frame_size,
                                                                    latency=latency))
            if len(self.rates) > 1:
                rev_port = self.port_handle[1]
                self.streamblock[rev_port].extend(
                    self.generate_streams(rev_port,
                                          rev_stream_cfg,
                                          l2frame_size,
                                          isg=10.0,
                                          latency=bidirectional and latency))

        for ph in self.port_handle:
            self.client.add_streams(self.streamblock[ph], ports=ph)
            LOG.info('Created traffic stream for port %s.', ph)

    def modify_rate(self, rate, reverse):
        """Record a new rate for one direction (used by the next start_traffic()).

        :param rate: new rate, converted with utils.to_rate_str()
        :param reverse: True for the reverse direction (second port)
        """
        port_index = int(reverse)
        port = self.port_handle[port_index]
        rate_str = utils.to_rate_str(rate)
        self.rates[port_index] = rate_str
        LOG.info('Modified traffic stream for %s, new rate=%s.', port, rate_str)

    def clear_streamblock(self):
        """Forget all streams, pg_ids and rates, and reset the TRex ports."""
        self.streamblock = defaultdict(list)
        self.latencies = defaultdict(list)
        self.stream_ids = defaultdict(list)
        self.rates = []
        self.client.reset(self.port_handle)
        LOG.info('Cleared all existing streams.')

    def get_stats(self):
        """Fetch the pg_id statistics from TRex and return extract_stats() of them."""
        stats = self.client.get_pgid_stats()
        return self.extract_stats(stats)

    def get_macs(self):
        """Return the source MAC address of each port, in port_handle order."""
        return [self.client.get_port_attr(port=port)['src_mac'] for port in self.port_handle]

    def clear_stats(self):
        """Clear TRex statistics, only once ports have been configured."""
        if self.port_handle:
            self.client.clear_stats()

    def start_traffic(self):
        """Start each port at its rate for ``config.duration_sec`` seconds."""
        for port, rate in zip(self.port_handle, self.rates):
            self.client.start(ports=port, mult=rate, duration=self.config.duration_sec, force=True)

    def stop_traffic(self):
        """Stop traffic on all ports."""
        self.client.stop(ports=self.port_handle)

    def cleanup(self):
        """Reset the ports and disconnect from TRex, ignoring TRex errors."""
        if self.client:
            try:
                self.client.reset(self.port_handle)
                self.client.disconnect()
            except STLError:
                # TRex does not like a reset while in disconnected state
                pass