"""Bloom-filter implementations backed by the ``bitarray`` package.

Provides :class:`BloomFilter` (fixed capacity) and
:class:`ScalableBloomFilter` (grows by chaining filters).
"""

import math
import hashlib
from struct import unpack, pack, calcsize

# bitarray is a hard requirement -- fail loudly at import time.
try:
    import bitarray
except ImportError:
    raise ImportError('pybloom requires bitarray >= 0.3.4')

__version__ = '2.0'
__author__ = "Jay Baird <jay.baird@me.com>, Bob Ippolito <bob@redivi.com>,\
 Marius Eriksen <marius@monkey.org>,\
 Alex Brasetvik <alex@brasetvik.com>"
|
def make_hashfuncs(num_slices, num_bits):
    """Build a hashing closure for a filter with ``num_slices`` slices of
    ``num_bits`` bits each.

    The returned callable maps a key to a list of ``num_slices`` integers,
    one bit index (each ``< num_bits``) per slice, derived from salted
    cryptographic digests of the key.
    """
    # Smallest unsigned struct code whose range can index num_bits positions.
    for floor, code, width in ((1 << 31, 'Q', 8),
                               (1 << 15, 'I', 4),
                               (0, 'H', 2)):
        if num_bits >= floor:
            fmt_code, chunk_size = code, width
            break
    total_hash_bits = 8 * num_slices * chunk_size
    # Cheapest digest that still supplies enough bits for every slice
    # (md5 handles anything up to 128 bits).
    hashfn = next(fn for bits, fn in ((384, hashlib.sha512),
                                      (256, hashlib.sha384),
                                      (160, hashlib.sha256),
                                      (128, hashlib.sha1),
                                      (0, hashlib.md5))
                  if total_hash_bits > bits)
    chunks_per_digest = hashfn().digest_size // chunk_size
    fmt = fmt_code * chunks_per_digest
    # Number of distinct salts needed to cover all slices (ceiling division).
    num_salts = -(-num_slices // chunks_per_digest)
    salts = [hashfn(hashfn(pack('I', i)).digest()) for i in range(num_salts)]

    def _hasher(key):
        # Each salted digest is unpacked into fixed-width chunks; every
        # chunk, reduced mod num_bits, becomes one bit index.  Surplus
        # chunks from the final salt are discarded.
        data = str(key).encode('utf-8')
        indexes = []
        for salt in salts:
            digest = salt.copy()
            digest.update(data)
            indexes.extend(word % num_bits for word in unpack(fmt, digest.digest()))
        del indexes[num_slices:]
        return indexes

    return _hasher
|
class BloomFilter(object):
    """A space-efficient probabilistic set.

    Membership tests may yield false positives (at a rate bounded by
    ``error_rate``) but never false negatives.  The underlying bit array
    is partitioned into ``num_slices`` slices of ``bits_per_slice`` bits;
    each key sets exactly one bit in every slice.
    """

    # On-disk header layout used by tofile/fromfile: error_rate (double),
    # num_slices, bits_per_slice, capacity, count (u64 each), little-endian.
    FILE_FMT = '<dQQQQ'

    def __init__(self, capacity, error_rate=0.001):
        """Create a filter sized for ``capacity`` keys at ``error_rate``.

        :param capacity: maximum number of keys (> 0).
        :param error_rate: target false-positive probability, in (0, 1).
        :raises ValueError: on an out-of-range argument.
        """
        if not (0 < error_rate < 1):
            raise ValueError("Error_Rate must be between 0 and 1.")
        if not capacity > 0:
            raise ValueError("Capacity must be > 0")
        # Standard bloom-filter sizing: k = ceil(log2(1/p)) slices, and
        # per-slice width m/k = n*|ln p| / (k * ln(2)^2).
        num_slices = int(math.ceil(math.log(1.0 / error_rate, 2)))
        bits_per_slice = int(math.ceil(
            (capacity * abs(math.log(error_rate))) /
            (num_slices * (math.log(2) ** 2))))
        self._setup(error_rate, num_slices, bits_per_slice, capacity, 0)
        self.bitarray = bitarray.bitarray(self.num_bits, endian='little')
        self.bitarray.setall(False)

    def _setup(self, error_rate, num_slices, bits_per_slice, capacity, count):
        # Shared by __init__, fromfile and __setstate__ to (re)derive every
        # attribute, including the per-instance hashing closure.
        self.error_rate = error_rate
        self.num_slices = num_slices
        self.bits_per_slice = bits_per_slice
        self.capacity = capacity
        self.num_bits = num_slices * bits_per_slice
        self.count = count
        self.make_hashes = make_hashfuncs(self.num_slices, self.bits_per_slice)

    def __contains__(self, key):
        """Test membership.  ``key`` may also be a precomputed list of
        per-slice indexes as produced by ``self.make_hashes`` (used by
        ``add`` to avoid hashing twice)."""
        bits_per_slice = self.bits_per_slice
        bit_array = self.bitarray  # renamed local: avoid shadowing the module
        if not isinstance(key, list):
            hashes = self.make_hashes(key)
        else:
            hashes = key
        offset = 0
        for k in hashes:
            if not bit_array[offset + k]:
                return False
            offset += bits_per_slice
        return True

    def __len__(self):
        """Return the number of keys stored by this bloom filter."""
        return self.count

    def add(self, key, skip_check=False):
        """Add ``key``; return True if it was (probably) already present.

        :param skip_check: skip the membership pre-check when the caller
            already knows the key is new.
        :raises IndexError: once the filter has exceeded its capacity.
        """
        bit_array = self.bitarray
        bits_per_slice = self.bits_per_slice
        hashes = self.make_hashes(key)
        if not skip_check and hashes in self:
            return True
        if self.count > self.capacity:
            raise IndexError("BloomFilter is at capacity")
        offset = 0
        for k in hashes:
            # Fix: use the local binding (previously bound but unused --
            # the loop re-read self.bitarray on every iteration).
            bit_array[offset + k] = True
            offset += bits_per_slice
        self.count += 1
        return False

    def copy(self):
        """Return a copy of this bloom filter, bits and count included."""
        new_filter = BloomFilter(self.capacity, self.error_rate)
        new_filter.bitarray = self.bitarray.copy()
        # Fix: carry the element count over; previously a copy always
        # reported len() == 0 and could therefore be silently overfilled.
        new_filter.count = self.count
        return new_filter

    def union(self, other):
        """ Calculates the union of the two underlying bitarrays and returns
        a new bloom filter object."""
        if self.capacity != other.capacity or \
                self.error_rate != other.error_rate:
            raise ValueError("Unioning filters requires both filters to have \
both the same capacity and error rate")
        new_bloom = self.copy()
        new_bloom.bitarray = new_bloom.bitarray | other.bitarray
        return new_bloom

    def __or__(self, other):
        return self.union(other)

    def intersection(self, other):
        """ Calculates the intersection of the two underlying bitarrays and returns
        a new bloom filter object."""
        if self.capacity != other.capacity or \
                self.error_rate != other.error_rate:
            raise ValueError("Intersecting filters requires both filters to \
have equal capacity and error rate")
        new_bloom = self.copy()
        new_bloom.bitarray = new_bloom.bitarray & other.bitarray
        return new_bloom

    def __and__(self, other):
        return self.intersection(other)

    def tofile(self, f):
        """Write the bloom filter to file object `f'. Underlying bits
        are written as machine values. This is much more space
        efficient than pickling the object."""
        f.write(pack(self.FILE_FMT, self.error_rate, self.num_slices,
                     self.bits_per_slice, self.capacity, self.count))
        self.bitarray.tofile(f)

    @classmethod
    def fromfile(cls, f, n=-1):
        """Read a bloom filter from file-object `f' serialized with
        ``BloomFilter.tofile''. If `n' > 0 read only so many bytes."""
        headerlen = calcsize(cls.FILE_FMT)

        if 0 < n < headerlen:
            raise ValueError('n too small!')

        instance = cls(1)  # Bogus instantiation, we will `_setup'.
        instance._setup(*unpack(cls.FILE_FMT, f.read(headerlen)))
        instance.bitarray = bitarray.bitarray(endian='little')
        if n > 0:
            instance.bitarray.fromfile(f, n - headerlen)
        else:
            instance.bitarray.fromfile(f)
        # tofile pads the array out to a whole byte, so the stored length
        # may exceed num_bits by up to 7 bits.
        # Fix: len() instead of the deprecated bitarray.length() (removed
        # in bitarray 2.x); also renamed the local to stop shadowing the
        # `filter` builtin.
        if instance.num_bits != len(instance.bitarray) and \
                (instance.num_bits + (8 - instance.num_bits % 8)
                 != len(instance.bitarray)):
            raise ValueError('Bit length mismatch!')

        return instance

    def __getstate__(self):
        # The hashing closure is not picklable; drop it and rebuild it in
        # __setstate__.
        d = self.__dict__.copy()
        del d['make_hashes']
        return d

    def __setstate__(self, d):
        self.__dict__.update(d)
        self.make_hashes = make_hashfuncs(self.num_slices, self.bits_per_slice)
|
class ScalableBloomFilter(object):
    """A bloom filter that grows to fit an unbounded number of keys.

    Implements the scheme of Almeida et al., "Scalable Bloom Filters":
    as each underlying :class:`BloomFilter` fills, a new one is appended
    with ``scale`` times the capacity and ``ratio`` times the error rate,
    keeping the compound false-positive probability bounded.
    """

    SMALL_SET_GROWTH = 2  # slower, but takes up less memory
    LARGE_SET_GROWTH = 4  # faster, but takes up more memory faster
    # On-disk header: scale (int32), ratio (double), initial_capacity
    # (u64), error_rate (double), little-endian.
    FILE_FMT = '<idQd'

    def __init__(self, initial_capacity=100, error_rate=0.001,
                 mode=SMALL_SET_GROWTH):
        """Create an empty scalable filter.

        :param initial_capacity: capacity of the first underlying filter.
        :param error_rate: target overall false-positive probability.
        :param mode: capacity growth factor for successive filters.
        :raises ValueError: if ``error_rate`` is zero or negative.
        """
        if not error_rate or error_rate < 0:
            # Fix: the original message read "less than 0.", contradicting
            # the check it accompanies.
            raise ValueError("Error_Rate must be a decimal greater than 0.")
        # ratio=0.9 is the error-probability tightening factor per filter.
        self._setup(mode, 0.9, initial_capacity, error_rate)
        self.filters = []

    def _setup(self, mode, ratio, initial_capacity, error_rate):
        # Shared by __init__ and fromfile.
        self.scale = mode
        self.ratio = ratio
        self.initial_capacity = initial_capacity
        self.error_rate = error_rate

    def __contains__(self, key):
        """Membership test; newest (largest) sub-filters are checked first
        since recent keys live there."""
        for f in reversed(self.filters):
            if key in f:
                return True
        return False

    def add(self, key):
        """Add ``key``; return True if it was (probably) already present.

        Grows the filter chain when the newest sub-filter is full.
        """
        if key in self:
            return True
        if not self.filters:
            filter = BloomFilter(
                capacity=self.initial_capacity,
                error_rate=self.error_rate * (1.0 - self.ratio))
            self.filters.append(filter)
        else:
            filter = self.filters[-1]
            if filter.count >= filter.capacity:
                filter = BloomFilter(
                    capacity=filter.capacity * self.scale,
                    error_rate=filter.error_rate * self.ratio)
                self.filters.append(filter)
        # Membership was already checked against the whole chain above.
        filter.add(key, skip_check=True)
        return False

    @property
    def capacity(self):
        """Returns the total capacity for all filters in this SBF"""
        return sum(f.capacity for f in self.filters)

    @property
    def count(self):
        # Alias of len(): total elements stored across all sub-filters.
        return len(self)

    def tofile(self, f):
        """Serialize this ScalableBloomFilter into the file-object
        `f'."""
        f.write(pack(self.FILE_FMT, self.scale, self.ratio,
                     self.initial_capacity, self.error_rate))

        # Write #-of-filters
        f.write(pack('<l', len(self.filters)))

        if len(self.filters) > 0:
            # Then each filter directly, with a header describing
            # their lengths.
            headerpos = f.tell()
            headerfmt = '<' + 'Q' * (len(self.filters))
            # Fix: the placeholder must be bytes -- writing a str to a
            # binary stream raises TypeError on Python 3.
            f.write(b'.' * calcsize(headerfmt))
            filter_sizes = []
            for filter in self.filters:
                begin = f.tell()
                filter.tofile(f)
                filter_sizes.append(f.tell() - begin)

            # Back-patch the real sizes over the placeholder header.
            f.seek(headerpos)
            f.write(pack(headerfmt, *filter_sizes))
            # Fix: restore the position to end-of-stream; previously the
            # stream was left just past the header, so any subsequent
            # write would clobber the serialized filters.
            f.seek(0, 2)

    @classmethod
    def fromfile(cls, f):
        """Deserialize the ScalableBloomFilter in file object `f'."""
        filter = cls()
        filter._setup(*unpack(cls.FILE_FMT, f.read(calcsize(cls.FILE_FMT))))
        nfilters, = unpack('<l', f.read(calcsize('<l')))
        if nfilters > 0:
            header_fmt = '<' + 'Q' * nfilters
            header = f.read(calcsize(header_fmt))  # renamed: was shadowing `bytes`
            filter_lengths = unpack(header_fmt, header)
            for fl in filter_lengths:
                filter.filters.append(BloomFilter.fromfile(f, fl))
        else:
            filter.filters = []

        return filter

    def __len__(self):
        """Returns the total number of elements stored in this SBF"""
        return sum(f.count for f in self.filters)