Functest test script file update
[domino.git] / lib / thrift / transport / TTransport.py
1 #
2 # Licensed to the Apache Software Foundation (ASF) under one
3 # or more contributor license agreements. See the NOTICE file
4 # distributed with this work for additional information
5 # regarding copyright ownership. The ASF licenses this file
6 # to you under the Apache License, Version 2.0 (the
7 # "License"); you may not use this file except in compliance
8 # with the License. You may obtain a copy of the License at
9 #
10 #   http://www.apache.org/licenses/LICENSE-2.0
11 #
12 # Unless required by applicable law or agreed to in writing,
13 # software distributed under the License is distributed on an
14 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 # KIND, either express or implied. See the License for the
16 # specific language governing permissions and limitations
17 # under the License.
18 #
19
20 from cStringIO import StringIO
21 from struct import pack, unpack
22 from thrift.Thrift import TException
23
24
25 class TTransportException(TException):
26   """Custom Transport Exception class"""
27
28   UNKNOWN = 0
29   NOT_OPEN = 1
30   ALREADY_OPEN = 2
31   TIMED_OUT = 3
32   END_OF_FILE = 4
33
34   def __init__(self, type=UNKNOWN, message=None):
35     TException.__init__(self, message)
36     self.type = type
37
38
39 class TTransportBase:
40   """Base class for Thrift transport layer."""
41
42   def isOpen(self):
43     pass
44
45   def open(self):
46     pass
47
48   def close(self):
49     pass
50
51   def read(self, sz):
52     pass
53
54   def readAll(self, sz):
55     buff = ''
56     have = 0
57     while (have < sz):
58       chunk = self.read(sz - have)
59       have += len(chunk)
60       buff += chunk
61
62       if len(chunk) == 0:
63         raise EOFError()
64
65     return buff
66
67   def write(self, buf):
68     pass
69
70   def flush(self):
71     pass
72
73
74 # This class should be thought of as an interface.
75 class CReadableTransport:
76   """base class for transports that are readable from C"""
77
78   # TODO(dreiss): Think about changing this interface to allow us to use
79   #               a (Python, not c) StringIO instead, because it allows
80   #               you to write after reading.
81
82   # NOTE: This is a classic class, so properties will NOT work
83   #       correctly for setting.
84   @property
85   def cstringio_buf(self):
86     """A cStringIO buffer that contains the current chunk we are reading."""
87     pass
88
89   def cstringio_refill(self, partialread, reqlen):
90     """Refills cstringio_buf.
91
92     Returns the currently used buffer (which can but need not be the same as
93     the old cstringio_buf). partialread is what the C code has read from the
94     buffer, and should be inserted into the buffer before any more reads.  The
95     return value must be a new, not borrowed reference.  Something along the
96     lines of self._buf should be fine.
97
98     If reqlen bytes can't be read, throw EOFError.
99     """
100     pass
101
102
103 class TServerTransportBase:
104   """Base class for Thrift server transports."""
105
106   def listen(self):
107     pass
108
109   def accept(self):
110     pass
111
112   def close(self):
113     pass
114
115
116 class TTransportFactoryBase:
117   """Base class for a Transport Factory"""
118
119   def getTransport(self, trans):
120     return trans
121
122
123 class TBufferedTransportFactory:
124   """Factory transport that builds buffered transports"""
125
126   def getTransport(self, trans):
127     buffered = TBufferedTransport(trans)
128     return buffered
129
130
131 class TBufferedTransport(TTransportBase, CReadableTransport):
132   """Class that wraps another transport and buffers its I/O.
133
134   The implementation uses a (configurable) fixed-size read buffer
135   but buffers all writes until a flush is performed.
136   """
137   DEFAULT_BUFFER = 4096
138
139   def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
140     self.__trans = trans
141     self.__wbuf = StringIO()
142     self.__rbuf = StringIO("")
143     self.__rbuf_size = rbuf_size
144
145   def isOpen(self):
146     return self.__trans.isOpen()
147
148   def open(self):
149     return self.__trans.open()
150
151   def close(self):
152     return self.__trans.close()
153
154   def read(self, sz):
155     ret = self.__rbuf.read(sz)
156     if len(ret) != 0:
157       return ret
158
159     self.__rbuf = StringIO(self.__trans.read(max(sz, self.__rbuf_size)))
160     return self.__rbuf.read(sz)
161
162   def write(self, buf):
163     try:
164       self.__wbuf.write(buf)
165     except Exception as e:
166       # on exception reset wbuf so it doesn't contain a partial function call
167       self.__wbuf = StringIO()
168       raise e
169
170   def flush(self):
171     out = self.__wbuf.getvalue()
172     # reset wbuf before write/flush to preserve state on underlying failure
173     self.__wbuf = StringIO()
174     self.__trans.write(out)
175     self.__trans.flush()
176
177   # Implement the CReadableTransport interface.
178   @property
179   def cstringio_buf(self):
180     return self.__rbuf
181
182   def cstringio_refill(self, partialread, reqlen):
183     retstring = partialread
184     if reqlen < self.__rbuf_size:
185       # try to make a read of as much as we can.
186       retstring += self.__trans.read(self.__rbuf_size)
187
188     # but make sure we do read reqlen bytes.
189     if len(retstring) < reqlen:
190       retstring += self.__trans.readAll(reqlen - len(retstring))
191
192     self.__rbuf = StringIO(retstring)
193     return self.__rbuf
194
195
196 class TMemoryBuffer(TTransportBase, CReadableTransport):
197   """Wraps a cStringIO object as a TTransport.
198
199   NOTE: Unlike the C++ version of this class, you cannot write to it
200         then immediately read from it.  If you want to read from a
201         TMemoryBuffer, you must either pass a string to the constructor.
202   TODO(dreiss): Make this work like the C++ version.
203   """
204
205   def __init__(self, value=None):
206     """value -- a value to read from for stringio
207
208     If value is set, this will be a transport for reading,
209     otherwise, it is for writing"""
210     if value is not None:
211       self._buffer = StringIO(value)
212     else:
213       self._buffer = StringIO()
214
215   def isOpen(self):
216     return not self._buffer.closed
217
218   def open(self):
219     pass
220
221   def close(self):
222     self._buffer.close()
223
224   def read(self, sz):
225     return self._buffer.read(sz)
226
227   def write(self, buf):
228     self._buffer.write(buf)
229
230   def flush(self):
231     pass
232
233   def getvalue(self):
234     return self._buffer.getvalue()
235
236   # Implement the CReadableTransport interface.
237   @property
238   def cstringio_buf(self):
239     return self._buffer
240
241   def cstringio_refill(self, partialread, reqlen):
242     # only one shot at reading...
243     raise EOFError()
244
245
246 class TFramedTransportFactory:
247   """Factory transport that builds framed transports"""
248
249   def getTransport(self, trans):
250     framed = TFramedTransport(trans)
251     return framed
252
253
254 class TFramedTransport(TTransportBase, CReadableTransport):
255   """Class that wraps another transport and frames its I/O when writing."""
256
257   def __init__(self, trans,):
258     self.__trans = trans
259     self.__rbuf = StringIO()
260     self.__wbuf = StringIO()
261
262   def isOpen(self):
263     return self.__trans.isOpen()
264
265   def open(self):
266     return self.__trans.open()
267
268   def close(self):
269     return self.__trans.close()
270
271   def read(self, sz):
272     ret = self.__rbuf.read(sz)
273     if len(ret) != 0:
274       return ret
275
276     self.readFrame()
277     return self.__rbuf.read(sz)
278
279   def readFrame(self):
280     buff = self.__trans.readAll(4)
281     sz, = unpack('!i', buff)
282     self.__rbuf = StringIO(self.__trans.readAll(sz))
283
284   def write(self, buf):
285     self.__wbuf.write(buf)
286
287   def flush(self):
288     wout = self.__wbuf.getvalue()
289     wsz = len(wout)
290     # reset wbuf before write/flush to preserve state on underlying failure
291     self.__wbuf = StringIO()
292     # N.B.: Doing this string concatenation is WAY cheaper than making
293     # two separate calls to the underlying socket object. Socket writes in
294     # Python turn out to be REALLY expensive, but it seems to do a pretty
295     # good job of managing string buffer operations without excessive copies
296     buf = pack("!i", wsz) + wout
297     self.__trans.write(buf)
298     self.__trans.flush()
299
300   # Implement the CReadableTransport interface.
301   @property
302   def cstringio_buf(self):
303     return self.__rbuf
304
305   def cstringio_refill(self, prefix, reqlen):
306     # self.__rbuf will already be empty here because fastbinary doesn't
307     # ask for a refill until the previous buffer is empty.  Therefore,
308     # we can start reading new frames immediately.
309     while len(prefix) < reqlen:
310       self.readFrame()
311       prefix += self.__rbuf.getvalue()
312     self.__rbuf = StringIO(prefix)
313     return self.__rbuf
314
315
316 class TFileObjectTransport(TTransportBase):
317   """Wraps a file-like object to make it work as a Thrift transport."""
318
319   def __init__(self, fileobj):
320     self.fileobj = fileobj
321
322   def isOpen(self):
323     return True
324
325   def close(self):
326     self.fileobj.close()
327
328   def read(self, sz):
329     return self.fileobj.read(sz)
330
331   def write(self, buf):
332     self.fileobj.write(buf)
333
334   def flush(self):
335     self.fileobj.flush()
336
337
338 class TSaslClientTransport(TTransportBase, CReadableTransport):
339   """
340   SASL transport 
341   """
342
343   START = 1
344   OK = 2
345   BAD = 3
346   ERROR = 4
347   COMPLETE = 5
348
349   def __init__(self, transport, host, service, mechanism='GSSAPI',
350       **sasl_kwargs):
351     """
352     transport: an underlying transport to use, typically just a TSocket
353     host: the name of the server, from a SASL perspective
354     service: the name of the server's service, from a SASL perspective
355     mechanism: the name of the preferred mechanism to use
356
357     All other kwargs will be passed to the puresasl.client.SASLClient
358     constructor.
359     """
360
361     from puresasl.client import SASLClient
362
363     self.transport = transport
364     self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
365
366     self.__wbuf = StringIO()
367     self.__rbuf = StringIO()
368
369   def open(self):
370     if not self.transport.isOpen():
371       self.transport.open()
372
373     self.send_sasl_msg(self.START, self.sasl.mechanism)
374     self.send_sasl_msg(self.OK, self.sasl.process())
375
376     while True:
377       status, challenge = self.recv_sasl_msg()
378       if status == self.OK:
379         self.send_sasl_msg(self.OK, self.sasl.process(challenge))
380       elif status == self.COMPLETE:
381         if not self.sasl.complete:
382           raise TTransportException("The server erroneously indicated "
383               "that SASL negotiation was complete")
384         else:
385           break
386       else:
387         raise TTransportException("Bad SASL negotiation status: %d (%s)"
388             % (status, challenge))
389
390   def send_sasl_msg(self, status, body):
391     header = pack(">BI", status, len(body))
392     self.transport.write(header + body)
393     self.transport.flush()
394
395   def recv_sasl_msg(self):
396     header = self.transport.readAll(5)
397     status, length = unpack(">BI", header)
398     if length > 0:
399       payload = self.transport.readAll(length)
400     else:
401       payload = ""
402     return status, payload
403
404   def write(self, data):
405     self.__wbuf.write(data)
406
407   def flush(self):
408     data = self.__wbuf.getvalue()
409     encoded = self.sasl.wrap(data)
410     self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
411     self.transport.flush()
412     self.__wbuf = StringIO()
413
414   def read(self, sz):
415     ret = self.__rbuf.read(sz)
416     if len(ret) != 0:
417       return ret
418
419     self._read_frame()
420     return self.__rbuf.read(sz)
421
422   def _read_frame(self):
423     header = self.transport.readAll(4)
424     length, = unpack('!i', header)
425     encoded = self.transport.readAll(length)
426     self.__rbuf = StringIO(self.sasl.unwrap(encoded))
427
428   def close(self):
429     self.sasl.dispose()
430     self.transport.close()
431
432   # based on TFramedTransport
433   @property
434   def cstringio_buf(self):
435     return self.__rbuf
436
437   def cstringio_refill(self, prefix, reqlen):
438     # self.__rbuf will already be empty here because fastbinary doesn't
439     # ask for a refill until the previous buffer is empty.  Therefore,
440     # we can start reading new frames immediately.
441     while len(prefix) < reqlen:
442       self._read_frame()
443       prefix += self.__rbuf.getvalue()
444     self.__rbuf = StringIO(prefix)
445     return self.__rbuf
446