Simple mapper/scheduler/partitioner functions implemented
[domino.git] / lib / thrift / protocol / TCompactProtocol.py
1 #
2 # Licensed to the Apache Software Foundation (ASF) under one
3 # or more contributor license agreements. See the NOTICE file
4 # distributed with this work for additional information
5 # regarding copyright ownership. The ASF licenses this file
6 # to you under the Apache License, Version 2.0 (the
7 # "License"); you may not use this file except in compliance
8 # with the License. You may obtain a copy of the License at
9 #
10 #   http://www.apache.org/licenses/LICENSE-2.0
11 #
12 # Unless required by applicable law or agreed to in writing,
13 # software distributed under the License is distributed on an
14 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 # KIND, either express or implied. See the License for the
16 # specific language governing permissions and limitations
17 # under the License.
18 #
19
20 from TProtocol import *
21 from struct import pack, unpack
22
# Names exported by a star-import of this module.
__all__ = [
  'TCompactProtocol',
  'TCompactProtocolFactory',
]
24
# States of the TCompactProtocol driver's state machine.  The driver keeps
# the current one in ``self.state`` and checks it before every operation.
# Write-side states come first (0-4), read-side states after (5-8).
(CLEAR,
 FIELD_WRITE,
 VALUE_WRITE,
 CONTAINER_WRITE,
 BOOL_WRITE,
 FIELD_READ,
 CONTAINER_READ,
 VALUE_READ,
 BOOL_READ) = range(9)
34
35
def make_helper(v_from, container):
  """Build a decorator that only lets a method run in one of two states.

  The returned decorator wraps ``func(self, ...)``. Before delegating, it
  asserts that ``self.state`` equals ``v_from`` or ``container``. The
  assertion message is the tuple ``(self.state, v_from, container)``.
  Because this is an ``assert``, the check disappears under ``python -O``.

  Args:
    v_from: state in which a plain value may be read/written.
    container: state in which the value is a collection element.

  Returns:
    A decorator. It preserves the wrapped function's metadata
    (``__name__``, ``__doc__``, ...) via ``functools.wraps``.
  """
  import functools

  def helper(func):
    @functools.wraps(func)
    def nested(self, *args, **kwargs):
      assert self.state in (v_from, container), (self.state, v_from, container)
      return func(self, *args, **kwargs)
    return nested
  return helper
# State guards for the scalar read*/write* methods. A scalar may be handled
# either as a direct value (VALUE_*) or as a collection element (CONTAINER_*).
reader = make_helper(VALUE_READ, CONTAINER_READ)
writer = make_helper(VALUE_WRITE, CONTAINER_WRITE)
45
46
def makeZigZag(n, bits):
  """Zigzag-encode the signed ``bits``-wide integer ``n``.

  Small magnitudes of either sign map to small non-negative numbers
  (0 -> 0, -1 -> 1, 1 -> 2, ...), which keeps the following varint short.
  ``checkIntegerLimits`` (imported from TProtocol) runs first; presumably
  it rejects values outside the signed ``bits`` range -- confirm there.
  """
  checkIntegerLimits(n, bits)
  # For in-range n this is 0 when n >= 0 and -1 (all ones) when n < 0.
  sign_fill = n >> (bits - 1)
  return (n << 1) ^ sign_fill
50
51
def fromZigZag(n):
  """Undo zigzag encoding, mapping unsigned ``n`` back to a signed int.

  Even inputs decode to ``n // 2``; odd inputs decode to ``-(n // 2) - 1``.
  """
  magnitude = n >> 1
  if n & 1:
    return ~magnitude
  return magnitude
54
55
def writeVarint(trans, n):
  """Write the non-negative integer ``n`` to ``trans`` as a base-128 varint.

  Seven bits are emitted per byte, least-significant group first. Every
  byte except the last has its high bit (0x80) set as a continuation flag.
  The whole encoding is handed to ``trans.write`` in a single call.

  Raises:
    ValueError: if ``n`` is negative. A negative int never shifts down to
      zero, so encoding it would otherwise never terminate.
  """
  if n < 0:
    raise ValueError('writeVarint cannot encode a negative value: %d' % n)
  out = bytearray()
  while n & ~0x7f:
    out.append((n & 0x7f) | 0x80)
    n >>= 7
  out.append(n)
  # bytes(bytearray) is a byte string on both Python 2 (str) and 3 (bytes).
  trans.write(bytes(out))
66
67
def readVarint(trans):
  """Read one base-128 varint from ``trans`` and return it as an int.

  Bytes are pulled one at a time with ``trans.readAll(1)``. The low seven
  bits of each byte are accumulated least-significant group first. Decoding
  stops after the first byte whose high bit (0x80) is clear. The result is
  never negative, and no upper bound is imposed on its size.
  """
  value = 0
  offset = 0
  more = True
  while more:
    octet = ord(trans.readAll(1))
    value |= (octet & 0x7f) << offset
    more = bool(octet & 0x80)
    offset += 7
  return value
78
79
class CompactType:
  """Four-bit type codes written on the wire by the compact protocol.

  TRUE and FALSE double as the type code of a boolean struct field, so the
  field's value travels inside its header. TRUE is also the element type
  used for boolean collections.
  """
  STOP = 0
  TRUE = 1
  FALSE = 2
  BYTE = 3
  I16 = 4
  I32 = 5
  I64 = 6
  DOUBLE = 7
  BINARY = 8
  LIST = 9
  SET = 10
  MAP = 11
  STRUCT = 12
94
# Thrift TType -> compact-protocol wire type code.
CTYPES = {
  TType.STOP: CompactType.STOP,
  TType.BOOL: CompactType.TRUE,  # element type used for bool collections
  TType.BYTE: CompactType.BYTE,
  TType.I16: CompactType.I16,
  TType.I32: CompactType.I32,
  TType.I64: CompactType.I64,
  TType.DOUBLE: CompactType.DOUBLE,
  TType.STRING: CompactType.BINARY,
  TType.STRUCT: CompactType.STRUCT,
  TType.LIST: CompactType.LIST,
  TType.SET: CompactType.SET,
  TType.MAP: CompactType.MAP,
}

# Inverse mapping (wire type code -> TType). Both boolean codes decode to
# TType.BOOL, so FALSE is added explicitly on top of the inversion. The
# generator expression leaves no loop variables behind at module level.
TTYPES = dict((ctype, ttype) for ttype, ctype in CTYPES.items())
TTYPES[CompactType.FALSE] = TType.BOOL
115
116
class TCompactProtocol(TProtocolBase):
  """Compact implementation of the Thrift protocol driver.

  Integers are zigzag/varint encoded. Field ids are sent as a delta from
  the previous field id of the same struct whenever that delta is 1..15.
  Boolean struct fields carry their value in the field header's type code.

  The driver is a state machine. ``self.state`` holds one of the
  module-level CLEAR / FIELD_* / VALUE_* / CONTAINER_* / BOOL_* constants.
  Each method asserts the state it expects. The ``__structs`` and
  ``__containers`` stacks remember what to restore when a nested struct or
  collection ends. The checks are ``assert`` statements, so protocol
  misuse is only detected when assertions are enabled.
  """

  PROTOCOL_ID = 0x82
  VERSION = 1
  VERSION_MASK = 0x1f
  TYPE_MASK = 0xe0
  TYPE_BITS = 0x07
  TYPE_SHIFT_AMOUNT = 5

  def __init__(self, trans):
    TProtocolBase.__init__(self, trans)
    self.state = CLEAR
    self.__last_fid = 0       # id of the previous field in the current struct
    self.__bool_fid = None    # field id of a bool field awaiting writeBool
    self.__bool_value = None  # bool decoded from the last field header
    self.__structs = []       # stack of (state, last_fid) of enclosing structs
    self.__containers = []    # stack of states of enclosing collections

  def __writeVarint(self, n):
    writeVarint(self.trans, n)

  def writeMessageBegin(self, name, type, seqid):
    """Write the message header: protocol id, version/type byte, seqid, name."""
    assert self.state == CLEAR
    self.__writeUByte(self.PROTOCOL_ID)
    # Version sits in the low 5 bits, the message type in the top 3.
    self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
    self.__writeVarint(seqid)
    self.__writeString(name)
    self.state = VALUE_WRITE

  def writeMessageEnd(self):
    assert self.state == VALUE_WRITE
    self.state = CLEAR

  def writeStructBegin(self, name):
    """Enter a struct; the enclosing state and field id are saved on a stack."""
    assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
    self.__structs.append((self.state, self.__last_fid))
    self.state = FIELD_WRITE
    self.__last_fid = 0

  def writeStructEnd(self):
    assert self.state == FIELD_WRITE
    self.state, self.__last_fid = self.__structs.pop()

  def writeFieldStop(self):
    self.__writeByte(0)

  def __writeFieldHeader(self, ctype, fid):
    """Write a field header for wire type ``ctype`` and field id ``fid``."""
    delta = fid - self.__last_fid
    if 0 < delta <= 15:
      # Short form: delta in the high nibble, type in the low nibble.
      self.__writeUByte(delta << 4 | ctype)
    else:
      # Long form: bare type byte followed by the zigzag-encoded field id.
      self.__writeByte(ctype)
      self.__writeI16(fid)
    self.__last_fid = fid

  def writeFieldBegin(self, name, type, fid):
    assert self.state == FIELD_WRITE, self.state
    if type == TType.BOOL:
      # The header is deferred to writeBool, which folds the value into it.
      self.state = BOOL_WRITE
      self.__bool_fid = fid
    else:
      self.state = VALUE_WRITE
      self.__writeFieldHeader(CTYPES[type], fid)

  def writeFieldEnd(self):
    assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
    self.state = FIELD_WRITE

  def __writeUByte(self, byte):
    self.trans.write(pack('!B', byte))

  def __writeByte(self, byte):
    self.trans.write(pack('!b', byte))

  def __writeI16(self, i16):
    self.__writeVarint(makeZigZag(i16, 16))

  def __writeSize(self, i32):
    self.__writeVarint(i32)

  def writeCollectionBegin(self, etype, size):
    """Write a list/set header: element type plus size."""
    assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
    if size <= 14:
      # Small sizes share the byte with the element type.
      self.__writeUByte(size << 4 | CTYPES[etype])
    else:
      # 0xf in the high nibble means "size follows as a varint".
      self.__writeUByte(0xf0 | CTYPES[etype])
      self.__writeSize(size)
    self.__containers.append(self.state)
    self.state = CONTAINER_WRITE
  writeSetBegin = writeCollectionBegin
  writeListBegin = writeCollectionBegin

  def writeMapBegin(self, ktype, vtype, size):
    """Write a map header; an empty map is encoded as a single zero byte."""
    assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
    if size == 0:
      self.__writeByte(0)
    else:
      self.__writeSize(size)
      self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
    self.__containers.append(self.state)
    self.state = CONTAINER_WRITE

  def writeCollectionEnd(self):
    assert self.state == CONTAINER_WRITE, self.state
    self.state = self.__containers.pop()
  writeMapEnd = writeCollectionEnd
  writeSetEnd = writeCollectionEnd
  writeListEnd = writeCollectionEnd

  def writeBool(self, bool):
    """Write a bool as a field (folded into its header) or as an element."""
    if self.state == BOOL_WRITE:
      if bool:
        ctype = CompactType.TRUE
      else:
        ctype = CompactType.FALSE
      self.__writeFieldHeader(ctype, self.__bool_fid)
    elif self.state == CONTAINER_WRITE:
      if bool:
        self.__writeByte(CompactType.TRUE)
      else:
        self.__writeByte(CompactType.FALSE)
    else:
      raise AssertionError("Invalid state in compact protocol")

  writeByte = writer(__writeByte)
  writeI16 = writer(__writeI16)

  @writer
  def writeI32(self, i32):
    self.__writeVarint(makeZigZag(i32, 32))

  @writer
  def writeI64(self, i64):
    self.__writeVarint(makeZigZag(i64, 64))

  @writer
  def writeDouble(self, dub):
    # Compact protocol doubles are 8 bytes, little-endian.
    self.trans.write(pack('<d', dub))

  def __writeString(self, s):
    """Write a varint length prefix followed by the raw bytes of ``s``.

    Text (anything that is not ``bytes``/``bytearray``, e.g. Py2 ``unicode``
    or Py3 ``str``) is UTF-8 encoded first. This makes the length prefix
    count bytes rather than code points.
    """
    if not isinstance(s, (bytes, bytearray)):
      s = s.encode('utf-8')
    self.__writeSize(len(s))
    self.trans.write(s)
  writeString = writer(__writeString)

  def readFieldBegin(self):
    """Read a field header; returns (None, ttype, fid), or (None, 0, 0) at STOP."""
    assert self.state == FIELD_READ, self.state
    header = self.__readUByte()
    ctype = header & 0x0f
    if ctype == CompactType.STOP:
      return (None, 0, 0)
    delta = header >> 4
    if delta == 0:
      # Long form: the field id follows as a zigzag varint.
      fid = self.__readI16()
    else:
      fid = self.__last_fid + delta
    self.__last_fid = fid
    if ctype == CompactType.TRUE:
      self.state = BOOL_READ
      self.__bool_value = True
    elif ctype == CompactType.FALSE:
      self.state = BOOL_READ
      self.__bool_value = False
    else:
      self.state = VALUE_READ
    return (None, self.__getTType(ctype), fid)

  def readFieldEnd(self):
    assert self.state in (VALUE_READ, BOOL_READ), self.state
    self.state = FIELD_READ

  def __readUByte(self):
    result, = unpack('!B', self.trans.readAll(1))
    return result

  def __readByte(self):
    result, = unpack('!b', self.trans.readAll(1))
    return result

  def __readVarint(self):
    return readVarint(self.trans)

  def __readZigZag(self):
    return fromZigZag(self.__readVarint())

  def __readSize(self):
    result = self.__readVarint()
    if result < 0:
      raise TException("Length < 0")
    return result

  def readMessageBegin(self):
    """Read the message header and return (name, type, seqid).

    Raises:
      TProtocolException: BAD_VERSION on a wrong protocol id or version.
    """
    assert self.state == CLEAR
    proto_id = self.__readUByte()
    if proto_id != self.PROTOCOL_ID:
      raise TProtocolException(TProtocolException.BAD_VERSION,
          'Bad protocol id in the message: %d' % proto_id)
    ver_type = self.__readUByte()
    msg_type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
    version = ver_type & self.VERSION_MASK
    if version != self.VERSION:
      raise TProtocolException(TProtocolException.BAD_VERSION,
          'Bad version: %d (expect %d)' % (version, self.VERSION))
    seqid = self.__readVarint()
    name = self.__readString()
    return (name, msg_type, seqid)

  def readMessageEnd(self):
    assert self.state == CLEAR
    assert len(self.__structs) == 0

  def readStructBegin(self):
    assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
    self.__structs.append((self.state, self.__last_fid))
    self.state = FIELD_READ
    self.__last_fid = 0

  def readStructEnd(self):
    assert self.state == FIELD_READ
    self.state, self.__last_fid = self.__structs.pop()

  def readCollectionBegin(self):
    """Read a list/set header and return (element ttype, size)."""
    assert self.state in (VALUE_READ, CONTAINER_READ), self.state
    size_type = self.__readUByte()
    size = size_type >> 4
    etype = self.__getTType(size_type)
    if size == 15:
      # Long form: the real size follows as a varint.
      size = self.__readSize()
    self.__containers.append(self.state)
    self.state = CONTAINER_READ
    return etype, size
  readSetBegin = readCollectionBegin
  readListBegin = readCollectionBegin

  def readMapBegin(self):
    """Read a map header and return (key ttype, value ttype, size)."""
    assert self.state in (VALUE_READ, CONTAINER_READ), self.state
    size = self.__readSize()
    types = 0
    if size > 0:
      types = self.__readUByte()
    # An empty map carries no type byte, so both types decode from 0.
    vtype = self.__getTType(types)
    ktype = self.__getTType(types >> 4)
    self.__containers.append(self.state)
    self.state = CONTAINER_READ
    return (ktype, vtype, size)

  def readCollectionEnd(self):
    assert self.state == CONTAINER_READ, self.state
    self.state = self.__containers.pop()
  readSetEnd = readCollectionEnd
  readListEnd = readCollectionEnd
  readMapEnd = readCollectionEnd

  def readBool(self):
    """Return a bool read as a struct field or as a collection element."""
    if self.state == BOOL_READ:
      # Already decoded from the field header by readFieldBegin.
      return self.__bool_value
    elif self.state == CONTAINER_READ:
      return self.__readByte() == CompactType.TRUE
    else:
      raise AssertionError("Invalid state in compact protocol: %d" %
                           self.state)

  readByte = reader(__readByte)
  __readI16 = __readZigZag
  readI16 = reader(__readZigZag)
  readI32 = reader(__readZigZag)
  readI64 = reader(__readZigZag)

  @reader
  def readDouble(self):
    buff = self.trans.readAll(8)
    val, = unpack('<d', buff)
    return val

  def __readString(self):
    """Read a varint length prefix and return that many raw bytes."""
    size = self.__readSize()
    return self.trans.readAll(size)
  readString = reader(__readString)

  def __getTType(self, byte):
    """Map the low nibble of ``byte`` (a compact type code) to a TType."""
    return TTYPES[byte & 0x0f]
398
399
class TCompactProtocolFactory:
  """Creates TCompactProtocol drivers for a given transport."""

  def __init__(self):
    # Stateless: there is nothing to configure.
    pass

  def getProtocol(self, trans):
    """Return a new TCompactProtocol wrapping ``trans``."""
    protocol = TCompactProtocol(trans)
    return protocol