Add unit tests for image comparison

nicofrand 2018-01-19 13:53:53 +01:00 committed by Phyks (Lucas Verney)
parent 4b41e6de2d
commit 9fa2177087
11 changed files with 130 additions and 24 deletions

View File

@@ -13,17 +13,15 @@ class MemoryCache(object):
     """
     A cache in memory.
     """
-    def __init__(self, on_miss):
-        """
-        Constructor
-
-        :param on_miss: Function to call to retrieve item when not already
-        cached.
-        """
+    @staticmethod
+    def on_miss(key):
+        raise NotImplementedError
+
+    def __init__(self):
         self.hits = 0
         self.misses = 0
         self.map = {}
-        self.on_miss = on_miss

     def get(self, key):
         """
@@ -77,11 +75,8 @@ class ImageCache(MemoryCache):
     A cache for images, stored in memory.
     """
     @staticmethod
-    def retrieve_photo(url):
+    def on_miss(url):
         """
         Helper to actually retrieve photos if not already cached.
         """
         return requests.get(url)
-
-    def __init__(self):
-        super(ImageCache, self).__init__(on_miss=ImageCache.retrieve_photo)
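The refactor above replaces the on_miss constructor callback with a static method that subclasses override. A minimal sketch of the new pattern, assuming MemoryCache.get(key) calls on_miss(key) on a cache miss and stores the result; the JsonCache class below is hypothetical and only illustrates the idea:

    # Hypothetical subclass, not part of this commit; reuses the module's
    # existing requests import.
    class JsonCache(MemoryCache):
        @staticmethod
        def on_miss(url):
            # Presumably invoked by MemoryCache.get() when the key is not
            # cached yet (assumption based on the old on_miss parameter).
            return requests.get(url).json()

    cache = JsonCache()
    data = cache.get("https://example.org/api/flats")  # fetched once, then served from cache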

View File

@@ -67,31 +67,51 @@ def get_or_compute_photo_hash(photo, photo_cache):
     return photo["hash"]


-def find_number_common_photos(flat1_photos, flat2_photos, photo_cache):
+def compare_photos(photo1, photo2, photo_cache, hash_threshold=10):
+    """
+    Compares two photos with average hash method.
+
+    :param photo1: First photo url.
+    :param photo2: Second photo url.
+    :param photo_cache: An instance of ``ImageCache`` to use to cache images.
+    :param hash_threshold: The hash threshold between two images. Usually two
+    different photos have a hash difference of 30.
+    :return: ``True`` if the photos are identical, else ``False``.
+    """
+    try:
+        hash1 = get_or_compute_photo_hash(photo1, photo_cache)
+        hash2 = get_or_compute_photo_hash(photo2, photo_cache)
+        return hash1 - hash2 < hash_threshold
+    except (IOError, requests.exceptions.RequestException):
+        return False
+
+
+def find_number_common_photos(
+    flat1_photos,
+    flat2_photos,
+    photo_cache,
+    hash_threshold=10
+):
     """
     Compute the number of common photos between the two lists of photos for the
     flats.

-    Fetch the photos and compare them with dHash method.
+    Fetch the photos and compare them with average hash method.

     :param flat1_photos: First list of flat photos. Each photo should be a
     ``dict`` with (at least) a ``url`` key.
-    :param flat2_photos: First list of flat photos. Each photo should be a
+    :param flat2_photos: Second list of flat photos. Each photo should be a
     ``dict`` with (at least) a ``url`` key.
     :param photo_cache: An instance of ``ImageCache`` to use to cache images.
+    :param hash_threshold: The hash threshold between two images.
     :return: The found number of common photos.
     """
     n_common_photos = 0

     for photo1, photo2 in itertools.product(flat1_photos, flat2_photos):
-        try:
-            hash1 = get_or_compute_photo_hash(photo1, photo_cache)
-            hash2 = get_or_compute_photo_hash(photo2, photo_cache)
-            if hash1 - hash2 == 0:
-                n_common_photos += 1
-        except (IOError, requests.exceptions.RequestException):
-            pass
+        if compare_photos(photo1, photo2, photo_cache, hash_threshold):
+            n_common_photos += 1

     return n_common_photos
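For reference, get_or_compute_photo_hash (not shown in this hunk) returns an average hash as described in the docstrings above, so the subtraction is a count of differing bits between two 64-bit hashes. A small standalone sketch of that comparison using the imagehash and pillow packages already listed in requirements.txt; the file paths are placeholders:

    # Standalone sketch of the average-hash comparison; 10 mirrors the default
    # hash_threshold above.
    import imagehash
    from PIL import Image

    hash1 = imagehash.average_hash(Image.open("photo1.jpg"))
    hash2 = imagehash.average_hash(Image.open("photo2.jpg"))

    # Subtracting two ImageHash objects yields the number of differing bits
    # (0 for identical hashes, around 30 for unrelated photos per the docstring).
    if hash1 - hash2 < 10:
        print("Photos considered identical")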
@@ -182,7 +202,7 @@ def detect(flats_list, key="id", merge=True, should_intersect=False):
     return unique_flats_list, duplicate_flats


-def get_duplicate_score(flat1, flat2, photo_cache):
+def get_duplicate_score(flat1, flat2, photo_cache, hash_threshold=10):
     """
     Compute the duplicate score between two flats. The higher the score, the
     more likely the two flats to be duplicates.
@@ -190,6 +210,7 @@ def get_duplicate_score(flat1, flat2, photo_cache):
     :param flat1: First flat dict.
     :param flat2: Second flat dict.
     :param photo_cache: An instance of ``ImageCache`` to use to cache images.
+    :param hash_threshold: The hash threshold between two images.
     :return: The duplicate score as ``int``.
     """
     n_common_items = 0
@@ -314,7 +335,12 @@ def deep_detect(flats_list, config):
             if flat2["id"] in matching_flats[flat1["id"]]:
                 continue

-            n_common_items = get_duplicate_score(flat1, flat2, photo_cache)
+            n_common_items = get_duplicate_score(
+                flat1,
+                flat2,
+                photo_cache,
+                config["duplicate_image_hash_threshold"]
+            )

             # Minimal score to consider they are duplicates
             if n_common_items >= config["duplicate_threshold"]:
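deep_detect now reads the hash threshold from the configuration, so a config passed to it needs the new duplicate_image_hash_threshold key alongside the existing duplicate_threshold. A minimal sketch with illustrative values only; the real defaults live in Flatisfy's configuration handling, not in this diff:

    # Illustrative values; only the two keys used in this hunk are shown.
    config = {
        "duplicate_threshold": 15,             # minimal score to flag two flats as duplicates (assumed value)
        "duplicate_image_hash_threshold": 10,  # maximal hash distance for two photos to be considered identical
    }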

Binary files (7 new test images added, not shown). Sizes: 122 KiB, 114 KiB, 24 KiB, 81 KiB, 40 KiB, 36 KiB, 25 KiB.

View File

@@ -9,6 +9,8 @@ import os
 import random
 import sys
 import unittest
+import requests
+import requests_mock

 from flatisfy import tools
 from flatisfy.filters import duplicates
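requests_mock is the new test dependency: inside a Mocker context, registered URLs are answered by the mock adapter instead of the network, which is what the LocalImageCache class below relies on. A tiny standalone sketch; the URL and payload are made up:

    import requests
    import requests_mock

    with requests_mock.Mocker() as mock:
        # Register fake bytes for an arbitrary URL; any scheme works under the Mocker.
        mock.get("mock://flatisfy/example.jpg", content=b"\xff\xd8\xff")
        response = requests.get("mock://flatisfy/example.jpg")
        assert response.content == b"\xff\xd8\xff"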
@@ -19,6 +21,22 @@ LOGGER = logging.getLogger(__name__)
 TESTS_DATA_DIR = os.path.dirname(os.path.realpath(__file__)) + "/test_files/"


+class LocalImageCache(ImageCache):
+    """
+    A local cache for images, stored in memory.
+    """
+    @staticmethod
+    def on_miss(path):
+        """
+        Helper to actually retrieve photos if not already cached.
+        """
+        url = "mock://flatisfy" + path
+        with requests_mock.Mocker() as mock:
+            with open(path, "rb") as fh:
+                mock.get(url, content=fh.read())
+            return requests.get(url)
+
+
 class TestTexts(unittest.TestCase):
     """
     Checks string normalizations.
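LocalImageCache above keeps ImageCache's requests-based interface while reading the fixture images from disk: on_miss registers the file's bytes under a mock:// URL and immediately fetches it, so callers still receive a requests.Response. A usage sketch, assuming MemoryCache.get returns (and caches) whatever on_miss produced; the file name is one of the fixtures added by this commit:

    cache = LocalImageCache()
    response = cache.get(TESTS_DATA_DIR + "127028739@seloger.jpg")
    assert response.status_code == 200   # a genuine requests.Response object
    jpeg_bytes = response.content        # raw bytes read from the local fixture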
@@ -118,6 +136,68 @@ class TestPhoneNumbers(unittest.TestCase):
         )


+class TestPhotos(unittest.TestCase):
+    IMAGE_CACHE = LocalImageCache()  # pylint: disable=invalid-name
+
+    def test_same_photo_twice(self):
+        """
+        Compares a photo against itself.
+        """
+        photo = {
+            "url": TESTS_DATA_DIR + "127028739@seloger.jpg"
+        }
+
+        self.assertTrue(duplicates.compare_photos(
+            photo,
+            photo,
+            TestPhotos.IMAGE_CACHE
+        ))
+
+    def test_different_photos(self):
+        """
+        Compares two different photos.
+        """
+        self.assertFalse(duplicates.compare_photos(
+            {"url": TESTS_DATA_DIR + "127028739@seloger.jpg"},
+            {"url": TESTS_DATA_DIR + "127028739-2@seloger.jpg"},
+            TestPhotos.IMAGE_CACHE
+        ))
+
+        self.assertFalse(duplicates.compare_photos(
+            {"url": TESTS_DATA_DIR + "127028739-2@seloger.jpg"},
+            {"url": TESTS_DATA_DIR + "127028739-3@seloger.jpg"},
+            TestPhotos.IMAGE_CACHE
+        ))
+
+    def test_matching_photos(self):
+        """
+        Compares two matching photos with different size and source.
+        """
+        self.assertTrue(duplicates.compare_photos(
+            {"url": TESTS_DATA_DIR + "127028739@seloger.jpg"},
+            {"url": TESTS_DATA_DIR + "14428129@explorimmo.jpg"},
+            TestPhotos.IMAGE_CACHE
+        ))
+
+        self.assertTrue(duplicates.compare_photos(
+            {"url": TESTS_DATA_DIR + "127028739-2@seloger.jpg"},
+            {"url": TESTS_DATA_DIR + "14428129-2@explorimmo.jpg"},
+            TestPhotos.IMAGE_CACHE
+        ))
+
+        self.assertTrue(duplicates.compare_photos(
+            {"url": TESTS_DATA_DIR + "127028739-3@seloger.jpg"},
+            {"url": TESTS_DATA_DIR + "14428129-3@explorimmo.jpg"},
+            TestPhotos.IMAGE_CACHE
+        ))
+
+        self.assertTrue(duplicates.compare_photos(
+            {"url": TESTS_DATA_DIR + "127028739@seloger.jpg"},
+            {"url": TESTS_DATA_DIR + "127028739-watermark@seloger.jpg"},
+            TestPhotos.IMAGE_CACHE
+        ))
+
+
 class TestDuplicates(unittest.TestCase):
     """
     Checks duplicates detection.
@@ -286,5 +366,9 @@ def run():
         suite = unittest.TestLoader().loadTestsFromTestCase(TestDuplicates)
         result = unittest.TextTestRunner(verbosity=2).run(suite)
         assert result.wasSuccessful()
+
+        suite = unittest.TestLoader().loadTestsFromTestCase(TestPhotos)
+        result = unittest.TextTestRunner(verbosity=2).run(suite)
+        assert result.wasSuccessful()
     except AssertionError:
         sys.exit(1)
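The new TestPhotos suite is wired into the same run() helper as the existing suites. To run only the photo-comparison tests outside of run(), the standard unittest machinery shown in the hunk works on its own; the flatisfy.tests module path is an assumption about the package layout:

    import unittest
    from flatisfy import tests

    suite = unittest.TestLoader().loadTestsFromTestCase(tests.TestPhotos)
    result = unittest.TextTestRunner(verbosity=2).run(suite)
    assert result.wasSuccessful()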

View File

@@ -10,6 +10,7 @@ future
 imagehash
 pillow
 requests
+requests_mock
 sqlalchemy
 titlecase
 unidecode