2 # Licensed to the Apache Software Foundation (ASF) under one
3 # or more contributor license agreements. See the NOTICE file
4 # distributed with this work for additional information
5 # regarding copyright ownership. The ASF licenses this file
6 # to you under the Apache License, Version 2.0 (the
7 # "License"); you may not use this file except in compliance
8 # with the License. You may obtain a copy of the License at
10 # http://www.apache.org/licenses/LICENSE-2.0
12 # Unless required by applicable law or agreed to in writing,
13 # software distributed under the License is distributed on an
14 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import struct
from cStringIO import StringIO

from zope.interface import implements, Interface, Attribute
from twisted.internet.protocol import ServerFactory, ClientFactory, \
    connectionDone
from twisted.internet import defer
from twisted.internet.threads import deferToThread
from twisted.protocols import basic
from twisted.web import server, resource, http

from thrift.transport import TTransport
class TMessageSenderTransport(TTransport.TTransportBase):
    """Transport that buffers writes and hands each flushed message off whole.

    Bytes given to write() accumulate in an in-memory buffer. flush() takes
    everything buffered so far, starts a fresh buffer, and passes the
    accumulated bytes to sendMessage(), which subclasses implement.
    """

    def __init__(self):
        self.__wbuf = StringIO()

    def write(self, buf):
        """Append buf to the pending outgoing message."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send the buffered bytes as one message; return sendMessage's result.

        The buffer is replaced before sendMessage() runs, so the same bytes
        are never sent twice even if sendMessage() raises.
        """
        pending, self.__wbuf = self.__wbuf.getvalue(), StringIO()
        return self.sendMessage(pending)

    def sendMessage(self, message):
        """Deliver one complete message. Subclasses must override this."""
        raise NotImplementedError
class TCallbackTransport(TMessageSenderTransport):
    """Message transport that hands every flushed message to a callable."""

    def __init__(self, func):
        # func receives the complete message; whatever it returns is
        # propagated back out of flush().
        TMessageSenderTransport.__init__(self)
        self.func = func

    def sendMessage(self, message):
        deliver = self.func
        return deliver(message)
class ThriftClientProtocol(basic.Int32StringReceiver):
    """Client side of a framed Thrift connection over Twisted.

    Every Thrift message travels as one Int32StringReceiver frame. Once the
    connection is made, ``started`` fires with the client instance built
    from ``client_class``.
    """

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """
        client_class: Thrift client class, instantiated on connect as
            client_class(transport, oprot_factory)
        iprot_factory: protocol factory used to decode incoming frames
        oprot_factory: protocol factory handed to the client for encoding
            calls; defaults to iprot_factory
        """
        self._client_class = client_class
        self._iprot_factory = iprot_factory
        if oprot_factory is None:
            self._oprot_factory = iprot_factory
        else:
            self._oprot_factory = oprot_factory

        # Cache of message name -> bound client.recv_<name> method.
        self.recv_map = {}
        self.started = defer.Deferred()

    def dispatch(self, msg):
        """Send one encoded Thrift message as a length-prefixed frame."""
        self.sendString(msg)

    def connectionMade(self):
        tmo = TCallbackTransport(self.dispatch)
        self.client = self._client_class(tmo, self._oprot_factory)
        self.started.callback(self.client)

    def connectionLost(self, reason=connectionDone):
        """Fail every request still waiting for a reply.

        The pending-request table is copied and emptied before any errback
        runs. An errback may start a new call, which (presumably) registers
        a Deferred in ``_reqs``. Iterating the live dict while that happens
        would raise "dictionary changed size during iteration".
        """
        client = getattr(self, 'client', None)
        if client is None:
            # connectionMade never completed, so nothing can be pending.
            return
        pending = list(client._reqs.values())
        client._reqs.clear()
        for d in pending:
            tex = TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed')
            d.errback(tex)

    def stringReceived(self, frame):
        """Decode one reply frame and route it to client.recv_<name>."""
        tr = TTransport.TMemoryBuffer(frame)
        iprot = self._iprot_factory.getProtocol(tr)
        (fname, mtype, rseqid) = iprot.readMessageBegin()

        try:
            method = self.recv_map[fname]
        except KeyError:
            method = getattr(self.client, 'recv_' + fname)
            self.recv_map[fname] = method

        method(iprot, mtype, rseqid)
class ThriftSASLClientProtocol(ThriftClientProtocol):
    """ThriftClientProtocol that runs a SASL handshake before any RPC.

    Outgoing negotiation messages are a 1-byte status, a 4-byte big-endian
    length and the payload. Once negotiation completes, each Thrift message
    is wrapped by the SASL session, prefixed with its own 4-byte length, and
    sent inside a normal Int32 frame.
    """

    # Negotiation status codes. NOTE(review): values follow the Thrift SASL
    # transport handshake; confirm they match the server in use.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None,
            host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
        """
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor. When host is None no SASL client is created here, and
        createSASLClient() must be called before the connection is made.
        """

        from puresasl.client import SASLClient
        # createSASLClient() reads self.SASLClient. The previously misspelled
        # attribute name SASLCLient is kept as an alias for compatibility.
        self.SASLClient = self.SASLCLient = SASLClient

        ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)

        self._sasl_negotiation_deferred = None
        self._sasl_negotiation_status = None
        self.client = None

        if host is not None:
            self.createSASLClient(host, service, mechanism, **sasl_kwargs)

    def createSASLClient(self, host, service, mechanism, **kwargs):
        self.sasl = self.SASLClient(host, service, mechanism, **kwargs)

    def dispatch(self, msg):
        """Wrap msg with the SASL session and send it with an inner length."""
        encoded = self.sasl.wrap(msg)
        len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded))
        ThriftClientProtocol.dispatch(self, len_and_encoded)

    @defer.inlineCallbacks
    def connectionMade(self):
        """Negotiate SASL, then start the wrapped Thrift client.

        If negotiation fails for any reason, ``started`` is errbacked with
        the failure and the connection is dropped. Previously the error
        vanished into this method's Deferred and ``started`` never fired.
        """
        try:
            self._sendSASLMessage(self.START, self.sasl.mechanism)
            initial_message = yield deferToThread(self.sasl.process)
            self._sendSASLMessage(self.OK, initial_message)

            while True:
                status, challenge = yield self._receiveSASLMessage()
                if status == self.OK:
                    response = yield deferToThread(self.sasl.process, challenge)
                    self._sendSASLMessage(self.OK, response)
                elif status == self.COMPLETE:
                    if not self.sasl.complete:
                        msg = "The server erroneously indicated that SASL " \
                              "negotiation was complete"
                        raise TTransport.TTransportException(msg, message=msg)
                    break
                else:
                    msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
                    raise TTransport.TTransportException(msg, message=msg)
        except Exception:
            # errback() with no argument captures the current exception.
            self.started.errback()
            self.transport.loseConnection()
            return

        self._sasl_negotiation_deferred = None
        ThriftClientProtocol.connectionMade(self)

    def _sendSASLMessage(self, status, body):
        """Write one negotiation message: status byte, 4-byte length, body."""
        if body is None:
            body = ""
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)

    def _receiveSASLMessage(self):
        """Return a Deferred fired with (status, payload) of the next message."""
        self._sasl_negotiation_deferred = defer.Deferred()
        self._sasl_negotiation_status = None
        return self._sasl_negotiation_deferred

    def connectionLost(self, reason=connectionDone):
        # self.client stays None until negotiation finished and the base
        # connectionMade ran; before that no requests can be pending.
        if self.client:
            ThriftClientProtocol.connectionLost(self, reason)

    def dataReceived(self, data):
        negotiating = self._sasl_negotiation_deferred
        if negotiating and self._sasl_negotiation_status is None and data:
            # First bytes of a SASL challenge (status, length, challenge):
            # save the status, let IntNStringReceiver piece the challenge
            # data together. The status byte is consumed only once per
            # expected message, so a challenge split across several reads
            # is not corrupted by stripping a byte from every chunk.
            self._sasl_negotiation_status, = struct.unpack("B", data[:1])
            ThriftClientProtocol.dataReceived(self, data[1:])
        else:
            # Normal frame, or the rest of a challenge already in progress:
            # let IntNStringReceiver piece it together.
            ThriftClientProtocol.dataReceived(self, data)

    def stringReceived(self, frame):
        if self._sasl_negotiation_deferred:
            # the frame is just a SASL challenge
            response = (self._sasl_negotiation_status, frame)
            self._sasl_negotiation_deferred.callback(response)
        else:
            # there's a second 4 byte length prefix inside the frame
            decoded_frame = self.sasl.unwrap(frame[4:])
            ThriftClientProtocol.stringReceived(self, decoded_frame)
class ThriftServerProtocol(basic.Int32StringReceiver):
    """Server side of a framed Thrift connection.

    Each incoming frame is decoded with the factory's input protocol and
    handed to the factory's processor. Any non-empty output is written back
    as a frame.
    """

    MAX_LENGTH = 2 ** 31 - 1

    def dispatch(self, msg):
        self.sendString(msg)

    def processError(self, error):
        # Any processing failure closes the connection; the failure itself
        # is consumed here (nothing is returned).
        self.transport.loseConnection()

    def processOk(self, _, tmo):
        reply = tmo.getvalue()
        # An empty buffer means the processor wrote nothing (presumably a
        # oneway call), so no frame is sent.
        if reply:
            self.dispatch(reply)

    def stringReceived(self, frame):
        factory = self.factory
        request_buf = TTransport.TMemoryBuffer(frame)
        reply_buf = TTransport.TMemoryBuffer()

        d = factory.processor.process(
            factory.iprot_factory.getProtocol(request_buf),
            factory.oprot_factory.getProtocol(reply_buf))
        d.addCallbacks(self.processOk, self.processError,
                       callbackArgs=(reply_buf,))
class IThriftServerFactory(Interface):
    """Attributes ThriftServerProtocol reads from its factory."""

    processor = Attribute("Thrift processor")

    iprot_factory = Attribute("Input protocol factory")

    oprot_factory = Attribute("Output protocol factory")
class IThriftClientFactory(Interface):
    """Attributes a factory of Thrift client protocols exposes."""

    client_class = Attribute("Thrift client class")

    iprot_factory = Attribute("Input protocol factory")

    oprot_factory = Attribute("Output protocol factory")
class ThriftServerFactory(ServerFactory):
    """Builds ThriftServerProtocol instances that share one processor."""

    implements(IThriftServerFactory)

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):
        self.processor = processor
        self.iprot_factory = iprot_factory
        # Without an explicit output factory, replies are encoded with the
        # same factory used to decode requests.
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)
class ThriftClientFactory(ClientFactory):
    """Builds client protocols bound to one Thrift client class."""

    implements(IThriftClientFactory)

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        self.client_class = client_class
        self.iprot_factory = iprot_factory
        # Fall back to the input protocol factory for output as well.
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)

    def buildProtocol(self, addr):
        """Create a protocol for addr and link it back to this factory."""
        proto = self.protocol(self.client_class, self.iprot_factory,
                              self.oprot_factory)
        proto.factory = self
        return proto
class ThriftResource(resource.Resource):
    """twisted.web resource serving Thrift calls over HTTP POST.

    The POST body is decoded with the input protocol factory and handed to
    the processor. The processor's output is returned as the response body
    with content-type application/x-thrift.
    """

    allowedMethods = ('POST',)

    def __init__(self, processor, inputProtocolFactory,
                 outputProtocolFactory=None):
        resource.Resource.__init__(self)
        self.inputProtocolFactory = inputProtocolFactory
        if outputProtocolFactory is None:
            self.outputProtocolFactory = inputProtocolFactory
        else:
            self.outputProtocolFactory = outputProtocolFactory
        self.processor = processor

    def getChild(self, path, request):
        # Every sub-path is served by this same resource.
        return self

    def _cbProcess(self, _, request, tmo):
        msg = tmo.getvalue()
        request.setResponseCode(http.OK)
        request.setHeader("content-type", "application/x-thrift")
        request.write(msg)
        request.finish()

    def _ebProcess(self, failure, request):
        """Answer with HTTP 500 when the processor's Deferred fails.

        Without this errback the request would stay open until the client
        gave up, and the failure would go unreported.
        """
        from twisted.python import log
        log.err(failure, "Thrift processor failed")
        request.setResponseCode(http.INTERNAL_SERVER_ERROR)
        request.finish()

    def render_POST(self, request):
        request.content.seek(0, 0)
        data = request.content.read()
        tmi = TTransport.TMemoryBuffer(data)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.inputProtocolFactory.getProtocol(tmi)
        oprot = self.outputProtocolFactory.getProtocol(tmo)

        d = self.processor.process(iprot, oprot)
        # The errback covers only processor failures, not errors raised
        # inside _cbProcess itself.
        d.addCallbacks(self._cbProcess, self._ebProcess,
                       callbackArgs=(request, tmo),
                       errbackArgs=(request,))
        return server.NOT_DONE_YET