Merge "Example as code, documentation template for sphinx build"
[domino.git] / lib / thrift / TTornado.py
1 #
2 # Licensed to the Apache Software Foundation (ASF) under one
3 # or more contributor license agreements. See the NOTICE file
4 # distributed with this work for additional information
5 # regarding copyright ownership. The ASF licenses this file
6 # to you under the Apache License, Version 2.0 (the
7 # "License"); you may not use this file except in compliance
8 # with the License. You may obtain a copy of the License at
9 #
10 #   http://www.apache.org/licenses/LICENSE-2.0
11 #
12 # Unless required by applicable law or agreed to in writing,
13 # software distributed under the License is distributed on an
14 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 # KIND, either express or implied. See the License for the
16 # specific language governing permissions and limitations
17 # under the License.
18 #
19
"""Tornado integration for Thrift: a framed stream transport and a TCP server."""

from __future__ import absolute_import
import socket
import struct

import logging
logger = logging.getLogger(__name__)

from thrift.transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer

from io import BytesIO
from collections import deque
from contextlib import contextmanager
from tornado import gen, iostream, ioloop, tcpserver, concurrent

# Public names exported by ``from ... import *``; _Lock is an internal helper.
__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
35
36
class _Lock(object):
    """FIFO coroutine lock built from a queue of Tornado futures.

    Each ``acquire`` call enqueues a new future and waits on the future queued
    just before it. ``release`` resolves the oldest queued future, which wakes
    the next waiter in line. ``acquire`` resolves to a context manager that
    releases the lock when its ``with`` block exits::

        with (yield lock.acquire()):
            ...
    """

    def __init__(self):
        # Leftmost future belongs to the current holder; the rest are
        # queued waiters in arrival order.
        self._waiters = deque()

    def acquired(self):
        """Return True while the lock is held (the waiter queue is non-empty)."""
        return len(self._waiters) > 0

    @gen.coroutine
    def acquire(self):
        """Wait for the lock; resolve to a context manager that releases it."""
        # The most recently queued future belongs to whoever is directly ahead
        # of us; we may proceed once that one is resolved by release().
        blocker = self._waiters[-1] if self.acquired() else None
        future = concurrent.Future()
        self._waiters.append(future)
        if blocker is not None:
            yield blocker

        raise gen.Return(self._lock_context())

    def release(self):
        """Release the lock, waking the next queued waiter (if any)."""
        assert self.acquired(), 'Lock not acquired'
        future = self._waiters.popleft()
        future.set_result(None)

    @contextmanager
    def _lock_context(self):
        """Context manager that releases the lock on exit, even on error."""
        try:
            yield
        finally:
            self.release()
65
66
class TTornadoStreamTransport(TTransportBase):
    """A framed, buffered transport over a Tornado stream.

    Writes are buffered in memory until ``flush()``. ``flush()`` sends the
    buffer as one frame, prefixed by its length packed as a 4-byte big-endian
    signed int. Reads are frame-at-a-time via ``readFrame()``; byte-level
    ``read()`` is deliberately unsupported.
    """

    def __init__(self, host, port, stream=None, io_loop=None):
        """
        :param host: peer host (used by ``open()`` and in error messages).
        :param port: peer port (used by ``open()`` and in error messages).
        :param stream: an already-connected stream; servers pass one in,
            clients leave it ``None`` and call ``open()``.
        :param io_loop: IOLoop used for timeouts; defaults to
            ``IOLoop.current()``.
        """
        self.host = host
        self.port = port
        self.io_loop = io_loop or ioloop.IOLoop.current()
        self.__wbuf = BytesIO()
        # Serializes readFrame() so concurrent readers cannot interleave one
        # frame's header with another frame's body.
        self._read_lock = _Lock()

        # servers provide a ready-to-go stream
        self.stream = stream

    def with_timeout(self, timeout, future):
        """Wrap *future* with ``gen.with_timeout`` on this transport's IOLoop."""
        return gen.with_timeout(timeout, future, self.io_loop)

    @gen.coroutine
    def open(self, timeout=None):
        """Connect a new IPv4 TCP stream to ``(host, port)``.

        :param timeout: optional value forwarded to ``with_timeout``; when
            ``None`` the connect is awaited without a timeout.
        :returns: (as the coroutine result) this transport.
        :raises TTransportException: ``NOT_OPEN`` if connecting fails or
            times out.
        """
        logger.debug('socket connecting')
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = iostream.IOStream(sock)

        try:
            connect = self.stream.connect((self.host, self.port))
            if timeout is not None:
                yield self.with_timeout(timeout, connect)
            else:
                yield connect
        except (socket.error, IOError, gen.TimeoutError, ioloop.TimeoutError) as e:
            # Tornado < 5 raises gen.TimeoutError from with_timeout, a class
            # distinct from ioloop.TimeoutError there; from 5.0 both names
            # alias tornado.util.TimeoutError, so listing both is safe.
            # Close the stream so a failed/timed-out connect does not leave
            # the socket open with nobody owning it.
            self.stream.close()
            message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
            raise TTransportException(
                type=TTransportException.NOT_OPEN,
                message=message)

        raise gen.Return(self)

    def set_close_callback(self, callback):
        """
        Should be called only after open() returns
        """
        self.stream.set_close_callback(callback)

    def close(self):
        """Close the stream without invoking its close callback.

        A no-op when the transport was never given or opened a stream.
        """
        if self.stream is None:
            return
        # don't raise if we intend to close
        self.stream.set_close_callback(None)
        self.stream.close()

    def read(self, _):
        """Unsupported: Tornado-generated code reads whole frames instead.

        Raises ``AssertionError`` explicitly rather than via ``assert`` so the
        guard still fires when Python runs with ``-O``.
        """
        # The generated code for Tornado shouldn't do individual reads -- only
        # frames at a time
        raise AssertionError("you're doing it wrong")

    @contextmanager
    def io_exception_context(self):
        """Translate low-level stream errors into ``TTransportException``.

        ``socket.error``/``IOError`` become ``END_OF_FILE``;
        ``iostream.StreamBufferFullError`` becomes ``UNKNOWN``.
        """
        try:
            yield
        except (socket.error, IOError) as e:
            raise TTransportException(
                type=TTransportException.END_OF_FILE,
                message=str(e))
        except iostream.StreamBufferFullError as e:
            raise TTransportException(
                type=TTransportException.UNKNOWN,
                message=str(e))

    @gen.coroutine
    def readFrame(self):
        """Read one length-prefixed frame and return its payload bytes.

        :raises TTransportException: ``END_OF_FILE`` on socket/IO errors while
            reading; ``UNKNOWN`` on read-buffer overflow or a negative frame
            length from the peer.
        """
        # IOStream processes reads one at a time
        with (yield self._read_lock.acquire()):
            with self.io_exception_context():
                frame_header = yield self.stream.read_bytes(4)
                if len(frame_header) == 0:
                    raise iostream.StreamClosedError('Read zero bytes from stream')
                frame_length, = struct.unpack('!i', frame_header)
                if frame_length < 0:
                    # '!i' is signed: a negative length means a corrupt or
                    # hostile peer and must never reach read_bytes().
                    raise TTransportException(
                        type=TTransportException.UNKNOWN,
                        message='Invalid frame length: {}'.format(frame_length))
                frame = yield self.stream.read_bytes(frame_length)
                raise gen.Return(frame)

    def write(self, buf):
        """Append *buf* to the pending output frame."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send all buffered writes as one frame.

        Returns whatever ``stream.write`` returns for the framed payload.
        """
        frame = self.__wbuf.getvalue()
        # reset wbuf before write/flush to preserve state on underlying failure
        frame_length = struct.pack('!i', len(frame))
        self.__wbuf = BytesIO()
        with self.io_exception_context():
            return self.stream.write(frame_length + frame)
153
154
class TTornadoServer(tcpserver.TCPServer):
    """Tornado TCP server that feeds framed Thrift requests to a processor."""

    def __init__(self, processor, iprot_factory, oprot_factory=None,
                 *args, **kwargs):
        """
        :param processor: object whose ``process(iprot, oprot)`` result is
            yielded once per received frame.
        :param iprot_factory: protocol factory used to decode each frame.
        :param oprot_factory: protocol factory for responses; defaults to
            *iprot_factory*.
        Remaining arguments are forwarded to ``TCPServer``.
        """
        super(TTornadoServer, self).__init__(*args, **kwargs)

        self._processor = processor
        self._iprot_factory = iprot_factory
        self._oprot_factory = (oprot_factory if oprot_factory is not None
                               else iprot_factory)

    @gen.coroutine
    def handle_stream(self, stream, address):
        """Serve requests from one client connection until it disconnects."""
        # IPv6 peers report (host, port, flowinfo, scope_id) and Unix-socket
        # peers may report a plain string, so never unpack exactly two items.
        if isinstance(address, tuple) and len(address) >= 2:
            host, port = address[:2]
        else:
            host, port = address, None
        trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
                                        io_loop=self.io_loop)
        oprot = self._oprot_factory.getProtocol(trans)

        try:
            while not trans.stream.closed():
                try:
                    frame = yield trans.readFrame()
                except TTransportException as e:
                    # EOF on read is an ordinary client disconnect, not an
                    # error worth a logged traceback.
                    if e.type == TTransportException.END_OF_FILE:
                        break
                    raise
                tr = TMemoryBuffer(frame)
                iprot = self._iprot_factory.getProtocol(tr)
                yield self._processor.process(iprot, oprot)
        except Exception:
            logger.exception('thrift exception in handle_stream')
            trans.close()

        # %s (not %d) so a missing port (non-IP peer) still formats.
        logger.info('client disconnected %s:%s', host, port)