Functest test script file update
[domino.git] / lib / thrift / transport / TTwisted.py
1 #
2 # Licensed to the Apache Software Foundation (ASF) under one
3 # or more contributor license agreements. See the NOTICE file
4 # distributed with this work for additional information
5 # regarding copyright ownership. The ASF licenses this file
6 # to you under the Apache License, Version 2.0 (the
7 # "License"); you may not use this file except in compliance
8 # with the License. You may obtain a copy of the License at
9 #
10 #   http://www.apache.org/licenses/LICENSE-2.0
11 #
12 # Unless required by applicable law or agreed to in writing,
13 # software distributed under the License is distributed on an
14 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 # KIND, either express or implied. See the License for the
16 # specific language governing permissions and limitations
17 # under the License.
18 #
19
20 import struct
21 from cStringIO import StringIO
22
23 from zope.interface import implements, Interface, Attribute
24 from twisted.internet.protocol import ServerFactory, ClientFactory, \
25     connectionDone
26 from twisted.internet import defer
27 from twisted.internet.threads import deferToThread
28 from twisted.protocols import basic
29 from twisted.web import server, resource, http
30
31 from thrift.transport import TTransport
32
33
class TMessageSenderTransport(TTransport.TTransportBase):
    """Write-buffering transport that emits one whole message per flush().

    Bytes passed to write() accumulate in an in-memory buffer.  flush()
    hands the accumulated bytes to sendMessage(), which subclasses must
    implement, and returns whatever sendMessage() returns.
    """

    def __init__(self):
        self.__wbuf = StringIO()

    def write(self, buf):
        """Append ``buf`` to the message currently being built."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send the buffered bytes via sendMessage() and start a new buffer."""
        # The right-hand side is evaluated before assignment, so the old
        # buffer's contents are captured before it is replaced.
        pending, self.__wbuf = self.__wbuf.getvalue(), StringIO()
        return self.sendMessage(pending)

    def sendMessage(self, message):
        """Deliver one complete serialized message; subclasses override."""
        raise NotImplementedError
49
50
class TCallbackTransport(TMessageSenderTransport):
    """Message transport that forwards each flushed message to a callable.

    The callable receives the complete serialized message.  Its return
    value becomes the return value of flush().
    """

    def __init__(self, func):
        TMessageSenderTransport.__init__(self)
        # Invoked once per flush() with the full message bytes.
        self.func = func

    def sendMessage(self, message):
        """Pass ``message`` to the stored callable and return its result."""
        deliver = self.func
        return deliver(message)
59
60
class ThriftClientProtocol(basic.Int32StringReceiver):
    """Client side of a length-prefixed Thrift connection over Twisted.

    Once connected, an instance of ``client_class`` is built on top of a
    TCallbackTransport whose flushes are sent as single Int32-prefixed
    frames.  ``self.started`` fires with that client.  Each incoming
    frame is decoded with the input protocol factory and routed to the
    client's ``recv_<method name>`` handler.
    """

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """
        client_class: Thrift client class, constructed as
            ``client_class(transport, oprot_factory)``.
        iprot_factory: protocol factory used to decode incoming frames.
        oprot_factory: protocol factory for outgoing calls; defaults to
            ``iprot_factory``.
        """
        self._client_class = client_class
        self._iprot_factory = iprot_factory
        if oprot_factory is None:
            self._oprot_factory = iprot_factory
        else:
            self._oprot_factory = oprot_factory

        # Cache of method name -> bound recv_* handler.
        self.recv_map = {}
        self.started = defer.Deferred()
        # Created in connectionMade(); None until then, so connectionLost()
        # can tell whether there are any outstanding requests to fail.
        self.client = None

    def dispatch(self, msg):
        """Send one serialized message as a single length-prefixed frame."""
        self.sendString(msg)

    def connectionMade(self):
        """Build the Thrift client and fire ``self.started`` with it."""
        tmo = TCallbackTransport(self.dispatch)
        self.client = self._client_class(tmo, self._oprot_factory)
        self.started.callback(self.client)

    def connectionLost(self, reason=connectionDone):
        """Fail every outstanding request with an END_OF_FILE transport error."""
        if self.client is None:
            # Connection ended before a client existed; nothing is pending.
            return
        reqs = self.client._reqs
        # Errbacks may issue new requests (adding entries to _reqs), so the
        # dict cannot be iterated directly; drain it until it stays empty.
        # popitem() also removes each failed request so it cannot be
        # errbacked twice.
        while reqs:
            _, d = reqs.popitem()
            tex = TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed')
            d.errback(tex)

    def stringReceived(self, frame):
        """Decode one reply frame and hand it to the matching recv_* method."""
        tr = TTransport.TMemoryBuffer(frame)
        iprot = self._iprot_factory.getProtocol(tr)
        (fname, mtype, rseqid) = iprot.readMessageBegin()

        try:
            method = self.recv_map[fname]
        except KeyError:
            method = getattr(self.client, 'recv_' + fname)
            self.recv_map[fname] = method

        method(iprot, mtype, rseqid)
103
104
class ThriftSASLClientProtocol(ThriftClientProtocol):
    """Thrift client protocol that performs a SASL handshake before use.

    During the handshake every message is framed as
    ``[status: 1 byte][length: 4 bytes big-endian][body]``.  Once the
    server reports COMPLETE, normal Thrift traffic starts.  Each outgoing
    message is passed through ``sasl.wrap`` and each incoming one through
    ``sasl.unwrap``.  ``self.started`` fires with the client on success,
    or errbacks if the handshake fails.
    """

    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None,
            host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
        """
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.  If ``host`` is None, createSASLClient() must be
        called before the connection is made.
        """

        from puresasl.client import SASLClient
        # Fixed: this was previously stored as ``SASLCLient`` while
        # createSASLClient() reads ``SASLClient``, so it always failed.
        self.SASLClient = SASLClient
        # Old misspelled name kept as an alias for backward compatibility.
        self.SASLCLient = SASLClient

        ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)

        # Deferred fired with (status, body) for the next SASL message;
        # None once negotiation is over.
        self._sasl_negotiation_deferred = None
        # Status byte of the SASL message currently being received; None
        # while still waiting for that byte.
        self._sasl_negotiation_status = None
        self.client = None

        if host is not None:
            self.createSASLClient(host, service, mechanism, **sasl_kwargs)

    def createSASLClient(self, host, service, mechanism, **kwargs):
        """Create the puresasl client used for negotiation and wrapping."""
        self.sasl = self.SASLClient(host, service, mechanism, **kwargs)

    def dispatch(self, msg):
        """Wrap ``msg`` with SASL and send it as a doubly length-prefixed frame."""
        encoded = self.sasl.wrap(msg)
        # Inner 4-byte length of the wrapped payload; sendString() in the
        # base dispatch adds the outer frame length.
        len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded))
        ThriftClientProtocol.dispatch(self, len_and_encoded)

    @defer.inlineCallbacks
    def connectionMade(self):
        """Run the SASL handshake, then start the Thrift client.

        If the handshake fails (bad status, premature COMPLETE, lost
        connection, missing SASL client), ``self.started`` is errbacked
        and the connection is closed.  Without this, callers waiting on
        ``started`` would hang forever.
        """
        try:
            yield self._negotiateSASL()
        except Exception:
            self._sasl_negotiation_deferred = None
            # errback() with no argument wraps the exception being handled.
            self.started.errback()
            self.transport.loseConnection()
        else:
            self._sasl_negotiation_deferred = None
            ThriftClientProtocol.connectionMade(self)

    @defer.inlineCallbacks
    def _negotiateSASL(self):
        """Exchange SASL messages with the server until it reports COMPLETE."""
        self._sendSASLMessage(self.START, self.sasl.mechanism)
        # sasl.process may block (e.g. GSSAPI), so run it off the reactor.
        initial_message = yield deferToThread(self.sasl.process)
        self._sendSASLMessage(self.OK, initial_message)

        while True:
            status, challenge = yield self._receiveSASLMessage()
            if status == self.OK:
                response = yield deferToThread(self.sasl.process, challenge)
                self._sendSASLMessage(self.OK, response)
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    msg = "The server erroneously indicated that SASL " \
                          "negotiation was complete"
                    raise TTransport.TTransportException(message=msg)
                break
            else:
                msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
                raise TTransport.TTransportException(message=msg)

    def _sendSASLMessage(self, status, body):
        """Write one ``[status][length][body]`` negotiation message."""
        if body is None:
            body = ""
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)

    def _receiveSASLMessage(self):
        """Return a Deferred fired with ``(status, body)`` of the next message."""
        self._sasl_negotiation_deferred = defer.Deferred()
        self._sasl_negotiation_status = None
        return self._sasl_negotiation_deferred

    def connectionLost(self, reason=connectionDone):
        """Fail any pending negotiation step and any outstanding requests."""
        pending = self._sasl_negotiation_deferred
        if pending is not None and not pending.called:
            # Otherwise the handshake would wait forever for a reply that
            # can no longer arrive.
            self._sasl_negotiation_deferred = None
            pending.errback(TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed during SASL negotiation'))
        if self.client:
            ThriftClientProtocol.connectionLost(self, reason)

    def dataReceived(self, data):
        """Strip the SASL status byte (once per message) before framing."""
        if (self._sasl_negotiation_deferred is not None
                and self._sasl_negotiation_status is None and data):
            # First bytes of a negotiation message: (status, length, body).
            # Only the very first byte is the status.  Later TCP chunks of
            # the same message are passed through untouched, so
            # IntNStringReceiver can piece the body together.
            self._sasl_negotiation_status, = struct.unpack("B", data[:1])
            data = data[1:]
        ThriftClientProtocol.dataReceived(self, data)

    def stringReceived(self, frame):
        """Route a complete frame to the handshake or to the Thrift client."""
        if self._sasl_negotiation_deferred is not None:
            # the frame is just a SASL challenge
            response = (self._sasl_negotiation_status, frame)
            self._sasl_negotiation_deferred.callback(response)
        else:
            # there's a second 4 byte length prefix inside the frame
            decoded_frame = self.sasl.unwrap(frame[4:])
            ThriftClientProtocol.stringReceived(self, decoded_frame)
203             decoded_frame = self.sasl.unwrap(frame[4:])
204             ThriftClientProtocol.stringReceived(self, decoded_frame)
205
206
class ThriftServerProtocol(basic.Int32StringReceiver):
    """Server side of a length-prefixed Thrift connection.

    Each incoming frame is run through ``self.factory.processor``.  A
    non-empty serialized reply is sent back as one frame.  A processor
    failure closes the connection.
    """

    MAX_LENGTH = 2 ** 31 - 1

    def dispatch(self, msg):
        """Send ``msg`` as one length-prefixed frame."""
        self.sendString(msg)

    def processError(self, error):
        """Drop the connection when the processor's Deferred fails."""
        self.transport.loseConnection()

    def processOk(self, _, tmo):
        """Send whatever the processor wrote to ``tmo``, if anything."""
        reply = tmo.getvalue()
        # An empty buffer means the processor wrote no reply; send nothing.
        if reply:
            self.dispatch(reply)

    def stringReceived(self, frame):
        """Process one request frame with the factory's processor."""
        factory = self.factory
        out_buf = TTransport.TMemoryBuffer()
        iprot = factory.iprot_factory.getProtocol(TTransport.TMemoryBuffer(frame))
        oprot = factory.oprot_factory.getProtocol(out_buf)

        pending = factory.processor.process(iprot, oprot)
        pending.addCallbacks(self.processOk, self.processError,
                             callbackArgs=(out_buf,))
231         d.addCallbacks(self.processOk, self.processError,
232             callbackArgs=(tmo,))
233
234
class IThriftServerFactory(Interface):
    """Attributes a Thrift server factory exposes to its protocols."""

    processor = Attribute("Thrift processor")
    iprot_factory = Attribute("Input protocol factory")
    oprot_factory = Attribute("Output protocol factory")
241     oprot_factory = Attribute("Output protocol factory")
242
243
class IThriftClientFactory(Interface):
    """Attributes a Thrift client factory exposes to its protocols."""

    client_class = Attribute("Thrift client class")
    iprot_factory = Attribute("Input protocol factory")
    oprot_factory = Attribute("Output protocol factory")
249
250     oprot_factory = Attribute("Output protocol factory")
251
252
class ThriftServerFactory(ServerFactory):
    """Builds ThriftServerProtocol instances that share one processor.

    processor: Thrift processor whose ``process(iprot, oprot)`` returns a
        Deferred.
    iprot_factory: protocol factory for decoding requests.
    oprot_factory: protocol factory for encoding replies; falls back to
        ``iprot_factory`` when omitted.
    """

    implements(IThriftServerFactory)

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):
        self.processor = processor
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)
264         else:
265             self.oprot_factory = oprot_factory
266
267
class ThriftClientFactory(ClientFactory):
    """Builds client protocols (ThriftClientProtocol by default).

    client_class: Thrift client class handed to each protocol.
    iprot_factory: protocol factory for decoding replies.
    oprot_factory: protocol factory for encoding calls; falls back to
        ``iprot_factory`` when omitted.
    """

    implements(IThriftClientFactory)

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        self.client_class = client_class
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)

    def buildProtocol(self, addr):
        """Instantiate ``self.protocol`` and point it back at this factory."""
        proto = self.protocol(self.client_class, self.iprot_factory,
                              self.oprot_factory)
        proto.factory = self
        return proto
285         p.factory = self
286         return p
287
288
class ThriftResource(resource.Resource):
    """twisted.web resource serving Thrift calls sent as HTTP POST bodies.

    The request body is decoded with ``inputProtocolFactory`` and run
    through ``processor``.  The serialized reply is returned with content
    type ``application/x-thrift``.  If the processor's Deferred fails,
    the failure is logged and the request finishes with HTTP 500.
    """

    allowedMethods = ('POST',)

    def __init__(self, processor, inputProtocolFactory,
        outputProtocolFactory=None):
        resource.Resource.__init__(self)
        self.inputProtocolFactory = inputProtocolFactory
        if outputProtocolFactory is None:
            self.outputProtocolFactory = inputProtocolFactory
        else:
            self.outputProtocolFactory = outputProtocolFactory
        self.processor = processor

    def getChild(self, path, request):
        # Every sub-path is handled by this same resource.
        return self

    def _cbProcess(self, _, request, tmo):
        """Write the processor's serialized reply and finish the request."""
        msg = tmo.getvalue()
        request.setResponseCode(http.OK)
        request.setHeader("content-type", "application/x-thrift")
        request.write(msg)
        request.finish()

    def _ebProcess(self, failure, request):
        """Log a processor failure and finish the request with HTTP 500.

        Without this errback a failing processor would leave the HTTP
        request open forever and the error unhandled.
        """
        from twisted.python import log
        log.err(failure, "Thrift processor failed")
        request.setResponseCode(http.INTERNAL_SERVER_ERROR)
        request.finish()

    def render_POST(self, request):
        """Decode the POST body, dispatch it, and reply asynchronously."""
        request.content.seek(0, 0)
        data = request.content.read()
        tmi = TTransport.TMemoryBuffer(data)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.inputProtocolFactory.getProtocol(tmi)
        oprot = self.outputProtocolFactory.getProtocol(tmo)

        d = self.processor.process(iprot, oprot)
        # Attach success and failure handlers side by side, so the errback
        # covers only processor failures, not errors raised by _cbProcess.
        d.addCallbacks(self._cbProcess, self._ebProcess,
                       callbackArgs=(request, tmo),
                       errbackArgs=(request,))
        return server.NOT_DONE_YET
323         d.addCallback(self._cbProcess, request, tmo)
324         return server.NOT_DONE_YET