2 # Licensed to the Apache Software Foundation (ASF) under one
3 # or more contributor license agreements. See the NOTICE file
4 # distributed with this work for additional information
5 # regarding copyright ownership. The ASF licenses this file
6 # to you under the Apache License, Version 2.0 (the
7 # "License"); you may not use this file except in compliance
8 # with the License. You may obtain a copy of the License at
10 # http://www.apache.org/licenses/LICENSE-2.0
12 # Unless required by applicable law or agreed to in writing,
13 # software distributed under the License is distributed on an
14 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 # KIND, either express or implied. See the License for the
16 # specific language governing permissions and limitations
# NOTE(review): this excerpt uses `logging` (module logger below), `socket`
# (TTornadoStreamTransport.open) and `struct` (frame length packing and
# unpacking further down), but no import for any of them is visible here --
# confirm they are imported in the full file.
20 from __future__ import absolute_import
# Module-level logger shared by the stream transport and the server.
25 logger = logging.getLogger(__name__)
27 from thrift.transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
29 from io import BytesIO
30 from collections import deque
31 from contextlib import contextmanager
32 from tornado import gen, iostream, ioloop, tcpserver, concurrent
# Public API of this module; the `_Lock` helper is not exported.
34 __all__ = ['TTornadoServer', 'TTornadoStreamTransport']
# NOTE(review): fragment of the private `_Lock` helper. Its `class` line,
# most method headers and the body of `_lock_context` are not visible in
# this excerpt; the comments below describe only the visible lines.
#
# Every acquirer appends a Future to `_waiters`. Because release pops from
# the left, the oldest entry appears to belong to the current holder and
# later entries to queued acquirers.
39 self._waiters = deque()
# Presumably the body of acquired(): true while any Future is queued.
42 return len(self._waiters) > 0
# A new acquirer remembers the most recently queued Future (if any) and
# presumably waits on it before proceeding (the wait itself is not visible),
# which would grant the lock in FIFO order.
46 blocker = self._waiters[-1] if self.acquired() else None
47 future = concurrent.Future()
48 self._waiters.append(future)
# The caller receives a context manager; it is used as
# `with (yield self._read_lock.acquire()):` by the stream transport.
52 raise gen.Return(self._lock_context())
# Release path (its def line is not visible).
# NOTE(review): this state check uses `assert`, so it disappears under
# `python -O`; releasing an unlocked lock would then fail in popleft()
# with IndexError instead.
55 assert self.acquired(), 'Lock not aquired'
# Pop the oldest Future and resolve it, unblocking whichever acquirer is
# waiting on it.
56 future = self._waiters.popleft()
57 future.set_result(None)
# NOTE(review): body not visible; given the `contextmanager` import and its
# use in a `with` statement, it presumably releases the lock when the
# `with` block exits -- confirm.
60 def _lock_context(self):
67 class TTornadoStreamTransport(TTransportBase):
68 """a framed, buffered transport over a Tornado stream"""
# Transport over a single Tornado IOStream that sends and receives
# length-prefixed frames.
69 def __init__(self, host, port, stream=None, io_loop=None):
# Falls back to the IOLoop that is current when the transport is built.
72 self.io_loop = io_loop or ioloop.IOLoop.current()
# Bytes passed to write() accumulate here until they are sent as one frame.
73 self.__wbuf = BytesIO()
# Serializes frame reads on the shared stream.
74 self._read_lock = _Lock()
# NOTE(review): the assignments of host, port and stream are not visible in
# this excerpt, although other methods read self.host, self.port and
# self.stream.
76 # servers provide a ready-to-go stream
def with_timeout(self, timeout, future):
    """Return a wrapper around *future* that times out after *timeout*.

    Thin delegation to ``tornado.gen.with_timeout`` bound to this
    transport's IOLoop; ``open()`` uses it to put a deadline on connecting.

    NOTE(review): the IOLoop is passed as the third positional argument.
    Newer Tornado releases dropped the ``io_loop`` parameter, and there that
    position means ``quiet_exceptions`` -- confirm which Tornado versions
    this module must support.
    """
    loop = self.io_loop
    return gen.with_timeout(timeout, future, loop)
# Connect a fresh TCP stream to (self.host, self.port); the coroutine result
# is the transport itself.
# NOTE(review): the decorator, the `try:` line, the branch taken when no
# timeout is given, and the `message=` argument of the exception are not
# visible in this excerpt.
83 def open(self, timeout=None):
84 logger.debug('socket connecting')
# Always builds a new IPv4 TCP socket and rebinds self.stream to it.
85 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
86 self.stream = iostream.IOStream(sock)
89 connect = self.stream.connect((self.host, self.port))
# Only put a deadline on the connect when a timeout was requested.
90 if timeout is not None:
91 yield self.with_timeout(timeout, connect)
# Connection errors and connect timeouts are both reported as NOT_OPEN.
94 except (socket.error, IOError, ioloop.TimeoutError) as e:
95 message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
96 raise TTransportException(
97 type=TTransportException.NOT_OPEN,
# Coroutine-style return of the opened transport.
100 raise gen.Return(self)
def set_close_callback(self, callback):
    """Register *callback* to be invoked when the underlying stream closes.

    Should be called only after open() returns: the callback is handed
    straight to ``self.stream``, which ``open()`` creates for clients
    (server-side transports are constructed with a ready stream).
    """
    target_stream = self.stream
    target_stream.set_close_callback(callback)
# NOTE(review): the def line (and any following lines) of this method are
# not visible in this excerpt; it appears to be the transport's close path.
109 # don't raise if we intend to close
# Detach any registered close callback so a deliberate close does not
# trigger it.
110 self.stream.set_close_callback(None)
# NOTE(review): the def line of this method is not visible in this excerpt.
114 # The generated code for Tornado shouldn't do individual reads -- only
# whole frames, which are obtained through readFrame().
116 assert False, "you're doing it wrong"
# NOTE(review): `assert False` is removed under `python -O`, in which case
# this method silently returns None; raising an exception explicitly would
# keep the guard in optimized runs.
# Context manager (it is used as `with self.io_exception_context():`) that
# translates stream/socket errors into TTransportException.
# NOTE(review): the decorator, the `try:`/`yield` lines and the `message=`
# arguments are not visible in this excerpt.
119 def io_exception_context(self):
# Socket and IO failures -- including Tornado's StreamClosedError, which
# subclasses IOError -- are reported as END_OF_FILE.
122 except (socket.error, IOError) as e:
123 raise TTransportException(
124 type=TTransportException.END_OF_FILE,
# Overflowing the IOStream's buffer limit is reported as UNKNOWN.
126 except iostream.StreamBufferFullError as e:
127 raise TTransportException(
128 type=TTransportException.UNKNOWN,
# NOTE(review): the def/decorator lines of this method are not visible in
# this excerpt; the server calls it as `trans.readFrame()`.
133 # IOStream processes reads one at a time
134 with (yield self._read_lock.acquire()):
135 with self.io_exception_context():
# Frame layout: 4-byte big-endian signed length prefix, then the payload.
136 frame_header = yield self.stream.read_bytes(4)
# StreamClosedError raised here is translated to END_OF_FILE by the
# enclosing io_exception_context.
137 if len(frame_header) == 0:
138 raise iostream.StreamClosedError('Read zero bytes from stream')
139 frame_length, = struct.unpack('!i', frame_header)
# NOTE(review): frame_length comes from the peer and is neither checked for
# being negative nor capped before being passed to read_bytes; consider
# rejecting invalid lengths explicitly.
140 frame = yield self.stream.read_bytes(frame_length)
141 raise gen.Return(frame)
def write(self, buf):
    """Append *buf* to the pending output buffer.

    Nothing reaches the stream here; the buffered bytes are sent later as a
    single length-prefixed frame.
    """
    pending = self.__wbuf
    pending.write(buf)
# NOTE(review): the def line of this method is not visible in this excerpt.
# Sends everything buffered by write() as one length-prefixed frame.
147 frame = self.__wbuf.getvalue()
148 # reset wbuf before write/flush to preserve state on underlying failure
# 4-byte big-endian signed length prefix, matching the '!i' header parsed
# on the read side.
149 frame_length = struct.pack('!i', len(frame))
150 self.__wbuf = BytesIO()
# NOTE(review): the with-block exits as soon as stream.write() returns, so
# only errors raised synchronously by that call are translated. Failures
# reported later through whatever it returns (presumably a Future) are not
# translated -- confirm callers handle them.
151 with self.io_exception_context():
152 return self.stream.write(frame_length + frame)
# Tornado TCP server that reads framed Thrift requests from each accepted
# connection and dispatches them to `processor`.
155 class TTornadoServer(tcpserver.TCPServer):
# NOTE(review): the continuation of this signature (presumably declaring the
# *args/**kwargs forwarded below) and the fallback arm of the
# _oprot_factory expression are not visible in this excerpt.
156 def __init__(self, processor, iprot_factory, oprot_factory=None,
158 super(TTornadoServer, self).__init__(*args, **kwargs)
160 self._processor = processor
161 self._iprot_factory = iprot_factory
# Uses oprot_factory when one is given; the fallback when it is None is not
# visible here (presumably iprot_factory -- confirm).
162 self._oprot_factory = (oprot_factory if oprot_factory is not None
# Serve one client connection: read frames until the stream closes, run each
# through the processor, then log the disconnect.
# NOTE(review): the decorator, the derivation of `host`/`port` from
# `address`, and the `try:`/`except` lines around the loop are not visible in
# this excerpt.
166 def handle_stream(self, stream, address):
# NOTE(review): relies on the server exposing `self.io_loop`, which only
# older Tornado TCPServer releases provide -- confirm the supported version.
168 trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
169 io_loop=self.io_loop)
# The output protocol wraps the connection transport directly.
170 oprot = self._oprot_factory.getProtocol(trans)
173 while not trans.stream.closed():
174 frame = yield trans.readFrame()
# Each request frame is decoded from an in-memory buffer.
175 tr = TMemoryBuffer(frame)
176 iprot = self._iprot_factory.getProtocol(tr)
177 yield self._processor.process(iprot, oprot)
# NOTE(review): a normal client disconnect surfaces from readFrame() as a
# TTransportException of type END_OF_FILE (StreamClosedError is an IOError).
# If the enclosing handler (not visible) catches it, that routine disconnect
# is logged here as an exception. Consider leaving the loop on END_OF_FILE
# instead.
179 logger.exception('thrift exception in handle_stream')
182 logger.info('client disconnected %s:%d', host, port)