2 # Licensed to the Apache Software Foundation (ASF) under one
3 # or more contributor license agreements. See the NOTICE file
4 # distributed with this work for additional information
5 # regarding copyright ownership. The ASF licenses this file
6 # to you under the Apache License, Version 2.0 (the
7 # "License"); you may not use this file except in compliance
8 # with the License. You may obtain a copy of the License at
10 # http://www.apache.org/licenses/LICENSE-2.0
12 # Unless required by applicable law or agreed to in writing,
13 # software distributed under the License is distributed on an
14 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 # KIND, either express or implied. See the License for the
16 # specific language governing permissions and limitations
20 from cStringIO import StringIO
21 from struct import pack, unpack
22 from thrift.Thrift import TException
25 class TTransportException(TException):
26 """Custom Transport Exception class"""
34 def __init__(self, type=UNKNOWN, message=None):
35 TException.__init__(self, message)
40 """Base class for Thrift transport layer."""
54 def readAll(self, sz):
58 chunk = self.read(sz - have)
74 # This class should be thought of as an interface.
75 class CReadableTransport:
76 """base class for transports that are readable from C"""
78 # TODO(dreiss): Think about changing this interface to allow us to use
79 # a (Python, not c) StringIO instead, because it allows
80 # you to write after reading.
82 # NOTE: This is a classic class, so properties will NOT work
83 # correctly for setting.
85 def cstringio_buf(self):
86 """A cStringIO buffer that contains the current chunk we are reading."""
89 def cstringio_refill(self, partialread, reqlen):
90 """Refills cstringio_buf.
92 Returns the currently used buffer (which can but need not be the same as
93 the old cstringio_buf). partialread is what the C code has read from the
94 buffer, and should be inserted into the buffer before any more reads. The
95 return value must be a new, not borrowed reference. Something along the
96 lines of self._buf should be fine.
98 If reqlen bytes can't be read, throw EOFError.
103 class TServerTransportBase:
104 """Base class for Thrift server transports."""
116 class TTransportFactoryBase:
117 """Base class for a Transport Factory"""
119 def getTransport(self, trans):
123 class TBufferedTransportFactory:
124 """Factory transport that builds buffered transports"""
126 def getTransport(self, trans):
127 buffered = TBufferedTransport(trans)
131 class TBufferedTransport(TTransportBase, CReadableTransport):
132 """Class that wraps another transport and buffers its I/O.
134 The implementation uses a (configurable) fixed-size read buffer
135 but buffers all writes until a flush is performed.
137 DEFAULT_BUFFER = 4096
139 def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
141 self.__wbuf = StringIO()
142 self.__rbuf = StringIO("")
143 self.__rbuf_size = rbuf_size
146 return self.__trans.isOpen()
149 return self.__trans.open()
152 return self.__trans.close()
155 ret = self.__rbuf.read(sz)
159 self.__rbuf = StringIO(self.__trans.read(max(sz, self.__rbuf_size)))
160 return self.__rbuf.read(sz)
162 def write(self, buf):
164 self.__wbuf.write(buf)
165 except Exception as e:
166 # on exception reset wbuf so it doesn't contain a partial function call
167 self.__wbuf = StringIO()
171 out = self.__wbuf.getvalue()
172 # reset wbuf before write/flush to preserve state on underlying failure
173 self.__wbuf = StringIO()
174 self.__trans.write(out)
177 # Implement the CReadableTransport interface.
179 def cstringio_buf(self):
182 def cstringio_refill(self, partialread, reqlen):
183 retstring = partialread
184 if reqlen < self.__rbuf_size:
185 # try to make a read of as much as we can.
186 retstring += self.__trans.read(self.__rbuf_size)
188 # but make sure we do read reqlen bytes.
189 if len(retstring) < reqlen:
190 retstring += self.__trans.readAll(reqlen - len(retstring))
192 self.__rbuf = StringIO(retstring)
196 class TMemoryBuffer(TTransportBase, CReadableTransport):
197 """Wraps a cStringIO object as a TTransport.
199 NOTE: Unlike the C++ version of this class, you cannot write to it
200 then immediately read from it. If you want to read from a
201 TMemoryBuffer, you must pass a string to the constructor.
202 TODO(dreiss): Make this work like the C++ version.
205 def __init__(self, value=None):
206 """value -- a value to read from for stringio
208 If value is set, this will be a transport for reading,
209 otherwise, it is for writing"""
210 if value is not None:
211 self._buffer = StringIO(value)
213 self._buffer = StringIO()
216 return not self._buffer.closed
225 return self._buffer.read(sz)
# Append buf to the backing buffer object held in self._buffer.
# NOTE(review): when the constructor was given a value, self._buffer is
# built as StringIO(value); presumably a cStringIO object created that way
# is read-only, so this call would fail on a "reading" buffer -- confirm.
227 def write(self, buf):
228 self._buffer.write(buf)
234 return self._buffer.getvalue()
236 # Implement the CReadableTransport interface.
238 def cstringio_buf(self):
241 def cstringio_refill(self, partialread, reqlen):
242 # only one shot at reading...
246 class TFramedTransportFactory:
247 """Factory transport that builds framed transports"""
249 def getTransport(self, trans):
250 framed = TFramedTransport(trans)
254 class TFramedTransport(TTransportBase, CReadableTransport):
255 """Class that wraps another transport and frames its I/O when writing."""
257 def __init__(self, trans,):
259 self.__rbuf = StringIO()
260 self.__wbuf = StringIO()
263 return self.__trans.isOpen()
266 return self.__trans.open()
269 return self.__trans.close()
272 ret = self.__rbuf.read(sz)
277 return self.__rbuf.read(sz)
280 buff = self.__trans.readAll(4)
281 sz, = unpack('!i', buff)
282 self.__rbuf = StringIO(self.__trans.readAll(sz))
# Append buf to the pending write buffer (self.__wbuf). The lines shown
# here only buffer the data; they perform no write on the wrapped transport.
284 def write(self, buf):
285 self.__wbuf.write(buf)
288 wout = self.__wbuf.getvalue()
290 # reset wbuf before write/flush to preserve state on underlying failure
291 self.__wbuf = StringIO()
292 # N.B.: Doing this string concatenation is WAY cheaper than making
293 # two separate calls to the underlying socket object. Socket writes in
294 # Python turn out to be REALLY expensive, but Python seems to do a pretty
295 # good job of managing string buffer operations without excessive copies.
296 buf = pack("!i", wsz) + wout
297 self.__trans.write(buf)
300 # Implement the CReadableTransport interface.
302 def cstringio_buf(self):
305 def cstringio_refill(self, prefix, reqlen):
306 # self.__rbuf will already be empty here because fastbinary doesn't
307 # ask for a refill until the previous buffer is empty. Therefore,
308 # we can start reading new frames immediately.
309 while len(prefix) < reqlen:
311 prefix += self.__rbuf.getvalue()
312 self.__rbuf = StringIO(prefix)
316 class TFileObjectTransport(TTransportBase):
317 """Wraps a file-like object to make it work as a Thrift transport."""
# Keep a reference to the file-like object this transport wraps.
# fileobj -- the object later used for read()/write() calls; the lines
#            shown here only store it and do not open or read anything.
319 def __init__(self, fileobj):
320 self.fileobj = fileobj
329 return self.fileobj.read(sz)
# Forward buf directly to the wrapped object's write(). The visible body
# adds no buffering of its own and does not return write()'s result.
331 def write(self, buf):
332 self.fileobj.write(buf)
338 class TSaslClientTransport(TTransportBase, CReadableTransport):
349 def __init__(self, transport, host, service, mechanism='GSSAPI',
352 transport: an underlying transport to use, typically just a TSocket
353 host: the name of the server, from a SASL perspective
354 service: the name of the server's service, from a SASL perspective
355 mechanism: the name of the preferred mechanism to use
357 All other kwargs will be passed to the puresasl.client.SASLClient
361 from puresasl.client import SASLClient
363 self.transport = transport
364 self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
366 self.__wbuf = StringIO()
367 self.__rbuf = StringIO()
370 if not self.transport.isOpen():
371 self.transport.open()
373 self.send_sasl_msg(self.START, self.sasl.mechanism)
374 self.send_sasl_msg(self.OK, self.sasl.process())
377 status, challenge = self.recv_sasl_msg()
378 if status == self.OK:
379 self.send_sasl_msg(self.OK, self.sasl.process(challenge))
380 elif status == self.COMPLETE:
381 if not self.sasl.complete:
382 raise TTransportException("The server erroneously indicated "
383 "that SASL negotiation was complete")
387 raise TTransportException("Bad SASL negotiation status: %d (%s)"
388 % (status, challenge))
# Send one SASL negotiation message over the underlying transport.
#
# status -- status code for the message; packed as one unsigned byte
# body   -- message payload; must support len() and concatenation with the
#           packed header string, since both are used below
390 def send_sasl_msg(self, status, body):
# ">BI" = big-endian: 1-byte unsigned status, then 4-byte unsigned length.
391 header = pack(">BI", status, len(body))
# Header and body go out in a single write and are flushed immediately.
392 self.transport.write(header + body)
393 self.transport.flush()
395 def recv_sasl_msg(self):
396 header = self.transport.readAll(5)
397 status, length = unpack(">BI", header)
399 payload = self.transport.readAll(length)
402 return status, payload
# Append data to the pending write buffer (self.__wbuf). The lines shown
# here only buffer; they do not call self.sasl or self.transport.
404 def write(self, data):
405 self.__wbuf.write(data)
408 data = self.__wbuf.getvalue()
409 encoded = self.sasl.wrap(data)
410 self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
411 self.transport.flush()
412 self.__wbuf = StringIO()
415 ret = self.__rbuf.read(sz)
420 return self.__rbuf.read(sz)
# Read one length-prefixed frame from the underlying transport, pass it
# through self.sasl.unwrap() and install the result as the new read buffer.
422 def _read_frame(self):
# Frame header: 4 bytes holding the payload length.
423 header = self.transport.readAll(4)
# '!i' = network (big-endian) byte order, signed 32-bit integer.
424 length, = unpack('!i', header)
# NOTE(review): length is signed and is not range-checked in the lines
# shown here before being handed to readAll() -- confirm how a negative
# or very large value from the peer is handled.
425 encoded = self.transport.readAll(length)
# Replaces (does not append to) whatever self.__rbuf previously held.
426 self.__rbuf = StringIO(self.sasl.unwrap(encoded))
430 self.transport.close()
432 # based on TFramedTransport
434 def cstringio_buf(self):
437 def cstringio_refill(self, prefix, reqlen):
438 # self.__rbuf will already be empty here because fastbinary doesn't
439 # ask for a refill until the previous buffer is empty. Therefore,
440 # we can start reading new frames immediately.
441 while len(prefix) < reqlen:
443 prefix += self.__rbuf.getvalue()
444 self.__rbuf = StringIO(prefix)