2 # Licensed to the Apache Software Foundation (ASF) under one
3 # or more contributor license agreements. See the NOTICE file
4 # distributed with this work for additional information
5 # regarding copyright ownership. The ASF licenses this file
6 # to you under the Apache License, Version 2.0 (the
7 # "License"); you may not use this file except in compliance
8 # with the License. You may obtain a copy of the License at
10 # http://www.apache.org/licenses/LICENSE-2.0
12 # Unless required by applicable law or agreed to in writing,
13 # software distributed under the License is distributed on an
14 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 # KIND, either express or implied. See the License for the
16 # specific language governing permissions and limitations
20 from thrift.Thrift import *
23 class TProtocolException(TException):
24 """Custom Protocol Exception class"""
34 def __init__(self, type=UNKNOWN, message=None):
35 TException.__init__(self, message)
40 """Base class for Thrift protocol driver."""
42 def __init__(self, trans):
45 def writeMessageBegin(self, name, ttype, seqid):
48 def writeMessageEnd(self):
51 def writeStructBegin(self, name):
54 def writeStructEnd(self):
57 def writeFieldBegin(self, name, ttype, fid):
60 def writeFieldEnd(self):
63 def writeFieldStop(self):
66 def writeMapBegin(self, ktype, vtype, size):
69 def writeMapEnd(self):
72 def writeListBegin(self, etype, size):
75 def writeListEnd(self):
78 def writeSetBegin(self, etype, size):
81 def writeSetEnd(self):
84 def writeBool(self, bool_val):
87 def writeByte(self, byte):
90 def writeI16(self, i16):
93 def writeI32(self, i32):
96 def writeI64(self, i64):
99 def writeDouble(self, dub):
102 def writeString(self, str_val):
105 def readMessageBegin(self):
108 def readMessageEnd(self):
111 def readStructBegin(self):
114 def readStructEnd(self):
117 def readFieldBegin(self):
120 def readFieldEnd(self):
123 def readMapBegin(self):
126 def readMapEnd(self):
129 def readListBegin(self):
132 def readListEnd(self):
135 def readSetBegin(self):
138 def readSetEnd(self):
156 def readDouble(self):
159 def readString(self):
162 def skip(self, ttype):
163 if ttype == TType.STOP:
165 elif ttype == TType.BOOL:
167 elif ttype == TType.BYTE:
169 elif ttype == TType.I16:
171 elif ttype == TType.I32:
173 elif ttype == TType.I64:
175 elif ttype == TType.DOUBLE:
177 elif ttype == TType.STRING:
179 elif ttype == TType.STRUCT:
180 name = self.readStructBegin()
182 (name, ttype, id) = self.readFieldBegin()
183 if ttype == TType.STOP:
188 elif ttype == TType.MAP:
189 (ktype, vtype, size) = self.readMapBegin()
190 for i in xrange(size):
194 elif ttype == TType.SET:
195 (etype, size) = self.readSetBegin()
196 for i in xrange(size):
199 elif ttype == TType.LIST:
200 (etype, size) = self.readListBegin()
201 for i in xrange(size):
# tuple of: ( 'reader method' name, 'writer method' name, is_container bool )
207 (None, None, False), # 0 TType.STOP
208 (None, None, False), # 1 TType.VOID # TODO: handle void?
209 ('readBool', 'writeBool', False), # 2 TType.BOOL
210 ('readByte', 'writeByte', False), # 3 TType.BYTE and I08
211 ('readDouble', 'writeDouble', False), # 4 TType.DOUBLE
212 (None, None, False), # 5 undefined
213 ('readI16', 'writeI16', False), # 6 TType.I16
214 (None, None, False), # 7 undefined
215 ('readI32', 'writeI32', False), # 8 TType.I32
216 (None, None, False), # 9 undefined
217 ('readI64', 'writeI64', False), # 10 TType.I64
218 ('readString', 'writeString', False), # 11 TType.STRING and UTF7
219 ('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT
220 ('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP
221 ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET
222 ('readContainerList', 'writeContainerList', True), # 15 TType.LIST
223 (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types?
224 (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types?
227 def readFieldByTType(self, ttype, spec):
229 (r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype]
231 raise TProtocolException(type=TProtocolException.INVALID_DATA,
232 message='Invalid field type %d' % (ttype))
233 if r_handler is None:
234 raise TProtocolException(type=TProtocolException.INVALID_DATA,
235 message='Invalid field type %d' % (ttype))
236 reader = getattr(self, r_handler)
241 def readContainerList(self, spec):
243 ttype, tspec = spec[0], spec[1]
244 r_handler = self._TTYPE_HANDLERS[ttype][0]
245 reader = getattr(self, r_handler)
246 (list_type, list_len) = self.readListBegin()
248 # list values are simple types
249 for idx in xrange(list_len):
250 results.append(reader())
252 # this is like an inlined readFieldByTType
253 container_reader = self._TTYPE_HANDLERS[list_type][0]
254 val_reader = getattr(self, container_reader)
255 for idx in xrange(list_len):
256 val = val_reader(tspec)
261 def readContainerSet(self, spec):
263 ttype, tspec = spec[0], spec[1]
264 r_handler = self._TTYPE_HANDLERS[ttype][0]
265 reader = getattr(self, r_handler)
266 (set_type, set_len) = self.readSetBegin()
268 # set members are simple types
269 for idx in xrange(set_len):
270 results.add(reader())
272 container_reader = self._TTYPE_HANDLERS[set_type][0]
273 val_reader = getattr(self, container_reader)
274 for idx in xrange(set_len):
275 results.add(val_reader(tspec))
279 def readContainerStruct(self, spec):
280 (obj_class, obj_spec) = spec
285 def readContainerMap(self, spec):
287 key_ttype, key_spec = spec[0], spec[1]
288 val_ttype, val_spec = spec[2], spec[3]
289 (map_ktype, map_vtype, map_len) = self.readMapBegin()
290 # TODO: compare types we just decoded with thrift_spec and
291 # abort/skip if types disagree
292 key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0])
293 val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0])
# read map_len key/value pairs
295 for idx in xrange(map_len):
299 k_val = self.readFieldByTType(key_ttype, key_spec)
303 v_val = self.readFieldByTType(val_ttype, val_spec)
# this raises a TypeError with unhashable key types,
# e.g. this fails: d=dict(); d[[0,1]] = 2
306 results[k_val] = v_val
310 def readStruct(self, obj, thrift_spec):
311 self.readStructBegin()
313 (fname, ftype, fid) = self.readFieldBegin()
314 if ftype == TType.STOP:
317 field = thrift_spec[fid]
321 if field is not None and ftype == field[1]:
324 val = self.readFieldByTType(ftype, fspec)
325 setattr(obj, fname, val)
331 def writeContainerStruct(self, val, spec):
334 def writeContainerList(self, val, spec):
335 self.writeListBegin(spec[0], len(val))
336 r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
337 e_writer = getattr(self, w_handler)
343 e_writer(elem, spec[1])
346 def writeContainerSet(self, val, spec):
347 self.writeSetBegin(spec[0], len(val))
348 r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
349 e_writer = getattr(self, w_handler)
355 e_writer(elem, spec[1])
358 def writeContainerMap(self, val, spec):
361 ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type]
362 ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type]
363 k_writer = getattr(self, ktype_name)
364 v_writer = getattr(self, vtype_name)
365 self.writeMapBegin(k_type, v_type, len(val))
366 for m_key, m_val in val.iteritems():
367 if not k_is_container:
370 k_writer(m_key, spec[1])
371 if not v_is_container:
374 v_writer(m_val, spec[3])
377 def writeStruct(self, obj, thrift_spec):
378 self.writeStructBegin(obj.__class__.__name__)
379 for field in thrift_spec:
383 val = getattr(obj, fname)
385 # skip writing out unset fields
390 # get the writer method for this value
391 self.writeFieldBegin(fname, ftype, fid)
392 self.writeFieldByTType(ftype, val, fspec)
394 self.writeFieldStop()
395 self.writeStructEnd()
397 def writeFieldByTType(self, ttype, val, spec):
398 r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype]
399 writer = getattr(self, w_handler)
def checkIntegerLimits(i, bits):
    """Reject a value that does not fit a signed integer of the given width.

    Widths of 8, 16, 32 and 64 bits are each checked against their signed
    range. Any other ``bits`` value is not checked at all and the function
    simply returns ``None``.

    Raises:
        TProtocolException: constructed with
            ``TProtocolException.INVALID_DATA`` and a message naming the
            permitted range, when ``i`` lies outside the range for ``bits``.
    """
    ranges = (
        (8, -128, 127,
         "i8 requires -128 <= number <= 127"),
        (16, -32768, 32767,
         "i16 requires -32768 <= number <= 32767"),
        (32, -2147483648, 2147483647,
         "i32 requires -2147483648 <= number <= 2147483647"),
        (64, -9223372036854775808, 9223372036854775807,
         "i64 requires -9223372036854775808 <= number <= 9223372036854775807"),
    )
    for width, lowest, highest, complaint in ranges:
        # Two explicit comparisons (not a chained range test) so values that
        # compare False both ways are still let through.
        if bits == width and (i < lowest or i > highest):
            raise TProtocolException(TProtocolException.INVALID_DATA, complaint)
419 class TProtocolFactory:
420 def getProtocol(self, trans):