Simple mapper/scheduler/partitioner functions implemented
[domino.git] / lib / thrift / server / TNonblockingServer.py
1 #
2 # Licensed to the Apache Software Foundation (ASF) under one
3 # or more contributor license agreements. See the NOTICE file
4 # distributed with this work for additional information
5 # regarding copyright ownership. The ASF licenses this file
6 # to you under the Apache License, Version 2.0 (the
7 # "License"); you may not use this file except in compliance
8 # with the License. You may obtain a copy of the License at
9 #
10 #   http://www.apache.org/licenses/LICENSE-2.0
11 #
12 # Unless required by applicable law or agreed to in writing,
13 # software distributed under the License is distributed on an
14 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 # KIND, either express or implied. See the License for the
16 # specific language governing permissions and limitations
17 # under the License.
18 #
19 """Implementation of non-blocking server.
20
21 The main idea of the server is to receive and send requests
22 only from the main thread.
23
24 The thread poool should be sized for concurrent tasks, not
25 maximum connections
26 """
27 import threading
28 import socket
29 import Queue
30 import select
31 import struct
32
33 import logging
34 logger = logging.getLogger(__name__)
35
36 from thrift.transport import TTransport
37 from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
38
39 __all__ = ['TNonblockingServer']
40
41
42 class Worker(threading.Thread):
43     """Worker is a small helper to process incoming connection."""
44
45     def __init__(self, queue):
46         threading.Thread.__init__(self)
47         self.queue = queue
48
49     def run(self):
50         """Process queries from task queue, stop if processor is None."""
51         while True:
52             try:
53                 processor, iprot, oprot, otrans, callback = self.queue.get()
54                 if processor is None:
55                     break
56                 processor.process(iprot, oprot)
57                 callback(True, otrans.getvalue())
58             except Exception:
59                 logger.exception("Exception while processing request")
60                 callback(False, '')
61
62 WAIT_LEN = 0
63 WAIT_MESSAGE = 1
64 WAIT_PROCESS = 2
65 SEND_ANSWER = 3
66 CLOSED = 4
67
68
69 def locked(func):
70     """Decorator which locks self.lock."""
71     def nested(self, *args, **kwargs):
72         self.lock.acquire()
73         try:
74             return func(self, *args, **kwargs)
75         finally:
76             self.lock.release()
77     return nested
78
79
80 def socket_exception(func):
81     """Decorator close object on socket.error."""
82     def read(self, *args, **kwargs):
83         try:
84             return func(self, *args, **kwargs)
85         except socket.error:
86             self.close()
87     return read
88
89
90 class Connection:
91     """Basic class is represented connection.
92
93     It can be in state:
94         WAIT_LEN --- connection is reading request len.
95         WAIT_MESSAGE --- connection is reading request.
96         WAIT_PROCESS --- connection has just read whole request and
97                          waits for call ready routine.
98         SEND_ANSWER --- connection is sending answer string (including length
99                         of answer).
100         CLOSED --- socket was closed and connection should be deleted.
101     """
102     def __init__(self, new_socket, wake_up):
103         self.socket = new_socket
104         self.socket.setblocking(False)
105         self.status = WAIT_LEN
106         self.len = 0
107         self.message = ''
108         self.lock = threading.Lock()
109         self.wake_up = wake_up
110
111     def _read_len(self):
112         """Reads length of request.
113
114         It's a safer alternative to self.socket.recv(4)
115         """
116         read = self.socket.recv(4 - len(self.message))
117         if len(read) == 0:
118             # if we read 0 bytes and self.message is empty, then
119             # the client closed the connection
120             if len(self.message) != 0:
121                 logger.error("can't read frame size from socket")
122             self.close()
123             return
124         self.message += read
125         if len(self.message) == 4:
126             self.len, = struct.unpack('!i', self.message)
127             if self.len < 0:
128                 logger.error("negative frame size, it seems client "
129                               "doesn't use FramedTransport")
130                 self.close()
131             elif self.len == 0:
132                 logger.error("empty frame, it's really strange")
133                 self.close()
134             else:
135                 self.message = ''
136                 self.status = WAIT_MESSAGE
137
138     @socket_exception
139     def read(self):
140         """Reads data from stream and switch state."""
141         assert self.status in (WAIT_LEN, WAIT_MESSAGE)
142         if self.status == WAIT_LEN:
143             self._read_len()
144             # go back to the main loop here for simplicity instead of
145             # falling through, even though there is a good chance that
146             # the message is already available
147         elif self.status == WAIT_MESSAGE:
148             read = self.socket.recv(self.len - len(self.message))
149             if len(read) == 0:
150                 logger.error("can't read frame from socket (get %d of "
151                               "%d bytes)" % (len(self.message), self.len))
152                 self.close()
153                 return
154             self.message += read
155             if len(self.message) == self.len:
156                 self.status = WAIT_PROCESS
157
158     @socket_exception
159     def write(self):
160         """Writes data from socket and switch state."""
161         assert self.status == SEND_ANSWER
162         sent = self.socket.send(self.message)
163         if sent == len(self.message):
164             self.status = WAIT_LEN
165             self.message = ''
166             self.len = 0
167         else:
168             self.message = self.message[sent:]
169
170     @locked
171     def ready(self, all_ok, message):
172         """Callback function for switching state and waking up main thread.
173
174         This function is the only function witch can be called asynchronous.
175
176         The ready can switch Connection to three states:
177             WAIT_LEN if request was oneway.
178             SEND_ANSWER if request was processed in normal way.
179             CLOSED if request throws unexpected exception.
180
181         The one wakes up main thread.
182         """
183         assert self.status == WAIT_PROCESS
184         if not all_ok:
185             self.close()
186             self.wake_up()
187             return
188         self.len = ''
189         if len(message) == 0:
190             # it was a oneway request, do not write answer
191             self.message = ''
192             self.status = WAIT_LEN
193         else:
194             self.message = struct.pack('!i', len(message)) + message
195             self.status = SEND_ANSWER
196         self.wake_up()
197
198     @locked
199     def is_writeable(self):
200         """Return True if connection should be added to write list of select"""
201         return self.status == SEND_ANSWER
202
203     # it's not necessary, but...
204     @locked
205     def is_readable(self):
206         """Return True if connection should be added to read list of select"""
207         return self.status in (WAIT_LEN, WAIT_MESSAGE)
208
209     @locked
210     def is_closed(self):
211         """Returns True if connection is closed."""
212         return self.status == CLOSED
213
214     def fileno(self):
215         """Returns the file descriptor of the associated socket."""
216         return self.socket.fileno()
217
218     def close(self):
219         """Closes connection"""
220         self.status = CLOSED
221         self.socket.close()
222
223
224 class TNonblockingServer:
225     """Non-blocking server."""
226
227     def __init__(self,
228                  processor,
229                  lsocket,
230                  inputProtocolFactory=None,
231                  outputProtocolFactory=None,
232                  threads=10):
233         self.processor = processor
234         self.socket = lsocket
235         self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory()
236         self.out_protocol = outputProtocolFactory or self.in_protocol
237         self.threads = int(threads)
238         self.clients = {}
239         self.tasks = Queue.Queue()
240         self._read, self._write = socket.socketpair()
241         self.prepared = False
242         self._stop = False
243
244     def setNumThreads(self, num):
245         """Set the number of worker threads that should be created."""
246         # implement ThreadPool interface
247         assert not self.prepared, "Can't change number of threads after start"
248         self.threads = num
249
250     def prepare(self):
251         """Prepares server for serve requests."""
252         if self.prepared:
253             return
254         self.socket.listen()
255         for _ in xrange(self.threads):
256             thread = Worker(self.tasks)
257             thread.setDaemon(True)
258             thread.start()
259         self.prepared = True
260
261     def wake_up(self):
262         """Wake up main thread.
263
264         The server usually waits in select call in we should terminate one.
265         The simplest way is using socketpair.
266
267         Select always wait to read from the first socket of socketpair.
268
269         In this case, we can just write anything to the second socket from
270         socketpair.
271         """
272         self._write.send('1')
273
274     def stop(self):
275         """Stop the server.
276
277         This method causes the serve() method to return.  stop() may be invoked
278         from within your handler, or from another thread.
279
280         After stop() is called, serve() will return but the server will still
281         be listening on the socket.  serve() may then be called again to resume
282         processing requests.  Alternatively, close() may be called after
283         serve() returns to close the server socket and shutdown all worker
284         threads.
285         """
286         self._stop = True
287         self.wake_up()
288
289     def _select(self):
290         """Does select on open connections."""
291         readable = [self.socket.handle.fileno(), self._read.fileno()]
292         writable = []
293         for i, connection in self.clients.items():
294             if connection.is_readable():
295                 readable.append(connection.fileno())
296             if connection.is_writeable():
297                 writable.append(connection.fileno())
298             if connection.is_closed():
299                 del self.clients[i]
300         return select.select(readable, writable, readable)
301
302     def handle(self):
303         """Handle requests.
304
305         WARNING! You must call prepare() BEFORE calling handle()
306         """
307         assert self.prepared, "You have to call prepare before handle"
308         rset, wset, xset = self._select()
309         for readable in rset:
310             if readable == self._read.fileno():
311                 # don't care i just need to clean readable flag
312                 self._read.recv(1024)
313             elif readable == self.socket.handle.fileno():
314                 client = self.socket.accept().handle
315                 self.clients[client.fileno()] = Connection(client,
316                                                            self.wake_up)
317             else:
318                 connection = self.clients[readable]
319                 connection.read()
320                 if connection.status == WAIT_PROCESS:
321                     itransport = TTransport.TMemoryBuffer(connection.message)
322                     otransport = TTransport.TMemoryBuffer()
323                     iprot = self.in_protocol.getProtocol(itransport)
324                     oprot = self.out_protocol.getProtocol(otransport)
325                     self.tasks.put([self.processor, iprot, oprot,
326                                     otransport, connection.ready])
327         for writeable in wset:
328             self.clients[writeable].write()
329         for oob in xset:
330             self.clients[oob].close()
331             del self.clients[oob]
332
333     def close(self):
334         """Closes the server."""
335         for _ in xrange(self.threads):
336             self.tasks.put([None, None, None, None, None])
337         self.socket.close()
338         self.prepared = False
339
340     def serve(self):
341         """Serve requests.
342
343         Serve requests forever, or until stop() is called.
344         """
345         self._stop = False
346         self.prepare()
347         while not self._stop:
348             self.handle()