Use cosine similarity as well in client code

This commit is contained in:
Lucas Verney 2016-05-25 00:11:39 +02:00
parent 40b7df9140
commit bdf1f64577
4 changed files with 81 additions and 24 deletions

2
bliss

@ -1 +1 @@
Subproject commit 39fe24d79f2fa4e253a29e66f5f7c0c33e4b84d5 Subproject commit 4781124ff177fa1eb649b9b74993a0ae28b1e501

View File

@ -21,7 +21,7 @@ def main():
cur = conn.cursor() cur = conn.cursor()
# Get cached distances from db # Get cached distances from db
cur.execute("SELECT song1, song2, distance FROM distances") cur.execute("SELECT song1, song2, distance, similarity FROM distances")
cached_distances = cur.fetchall() cached_distances = cur.fetchall()
# Get all songs # Get all songs
@ -47,20 +47,41 @@ def main():
(song1["frequency"] - song2["frequency"])**2 + (song1["frequency"] - song2["frequency"])**2 +
(song1["attack"] - song2["attack"])**2 (song1["attack"] - song2["attack"])**2
) )
logging.debug("Distance between %s and %s is %f." % similarity = (
(song1["filename"], song2["filename"], distance)) (song1["tempo"] * song2["tempo"] +
song1["amplitude"] * song2["amplitude"] +
song1["frequency"] * song2["frequency"] +
song1["attack"] * song2["attack"]) /
(
math.sqrt(
song1["tempo"]**2 +
song1["amplitude"]**2 +
song1["frequency"]**2 +
song1["attack"]**2) *
math.sqrt(
song2["tempo"]**2 +
song2["amplitude"]**2 +
song2["frequency"]**2 +
song2["attack"]**2)
)
)
logging.debug("Distance between %s and %s is (%f, %f)." %
(song1["filename"], song2["filename"], distance,
similarity))
# Store distance in db cache # Store distance in db cache
try: try:
logging.debug("Storing distance in database.") logging.debug("Storing distance in database.")
conn.execute( conn.execute(
"INSERT INTO distances(song1, song2, distance) VALUES(?, ?, ?)", "INSERT INTO distances(song1, song2, distance, similarity) VALUES(?, ?, ?, ?)",
(song1["id"], song2["id"], distance)) (song1["id"], song2["id"], distance, similarity))
conn.commit() conn.commit()
# Update cached_distances list # Update cached_distances list
cached_distances.append({ cached_distances.append({
"song1": song1["id"], "song1": song1["id"],
"song2": song2["id"], "song2": song2["id"],
"distance": distance "distance": distance,
"similarity": similarity
}) })
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
logging.warning("Unable to insert distance in database.") logging.warning("Unable to insert distance in database.")

View File

@ -8,10 +8,11 @@ import subprocess
import sys import sys
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
# TODO: Replace mpc calls with libmpd2?
_QUEUE_LENGTH = 10 _QUEUE_LENGTH = 10
# TODO: Use cosine similarity as well
_DISTANCE_THRESHOLD = 4.0 _DISTANCE_THRESHOLD = 4.0
_SIMILARITY_THRESHOLD = 0.95
if "XDG_DATA_HOME" in os.environ: if "XDG_DATA_HOME" in os.environ:
_MPDBLISS_DATA_HOME = os.path.expandvars("$XDG_DATA_HOME/mpdbliss") _MPDBLISS_DATA_HOME = os.path.expandvars("$XDG_DATA_HOME/mpdbliss")
@ -28,6 +29,14 @@ def main():
conn.execute('pragma foreign_keys=ON') conn.execute('pragma foreign_keys=ON')
cur = conn.cursor() cur = conn.cursor()
# Warn if random mode is enabled (it is only reported, not turned off)
status = subprocess.check_output(["mpc", "status"]).decode("utf-8")
random = [x.split(":")[1].strip() == "on"
for x in status.split("\n")[-2].split(" ")
if x.startswith("random")][0]
if random:
logging.warning("Random mode is enabled. Are you sure you want it?")
# Take the last song from current playlist and iterate from it # Take the last song from current playlist and iterate from it
current_song = subprocess.check_output( current_song = subprocess.check_output(
["mpc", "playlist", '--format', '"%file%"']) ["mpc", "playlist", '--format', '"%file%"'])
@ -54,7 +63,7 @@ def main():
mpd_queue.append(current_song_coords["filename"]) mpd_queue.append(current_song_coords["filename"])
# Get cached distances from db # Get cached distances from db
cur.execute( cur.execute(
"SELECT id, filename, distance, tempo, amplitude, frequency, attack FROM (SELECT s2.id AS id, s2.filename AS filename, s2.tempo AS tempo, s2.amplitude AS amplitude, s2.frequency AS frequency, s2.attack AS attack, distances.distance AS distance FROM distances INNER JOIN songs AS s1 ON s1.id=distances.song1 INNER JOIN songs AS s2 on s2.id=distances.song2 WHERE s1.filename=? UNION SELECT s1.id as id, s1.filename AS filename, s1.tempo AS tempo, s1.amplitude AS amplitude, s1.frequency AS frequency, s1.attack AS attack, distances.distance as distance FROM distances INNER JOIN songs AS s1 ON s1.id=distances.song1 INNER JOIN songs AS s2 on s2.id=distances.song2 WHERE s2.filename=?) ORDER BY distance ASC", "SELECT id, filename, distance, similarity, tempo, amplitude, frequency, attack FROM (SELECT s2.id AS id, s2.filename AS filename, s2.tempo AS tempo, s2.amplitude AS amplitude, s2.frequency AS frequency, s2.attack AS attack, distances.distance AS distance, distances.similarity AS similarity FROM distances INNER JOIN songs AS s1 ON s1.id=distances.song1 INNER JOIN songs AS s2 on s2.id=distances.song2 WHERE s1.filename=? UNION SELECT s1.id as id, s1.filename AS filename, s1.tempo AS tempo, s1.amplitude AS amplitude, s1.frequency AS frequency, s1.attack AS attack, distances.distance as distance, distances.similarity AS similarity FROM distances INNER JOIN songs AS s1 ON s1.id=distances.song1 INNER JOIN songs AS s2 on s2.id=distances.song2 WHERE s2.filename=?) ORDER BY distance ASC",
(current_song_coords["filename"], current_song_coords["filename"])) (current_song_coords["filename"], current_song_coords["filename"]))
cached_distances = [row cached_distances = [row
for row in cur.fetchall() for row in cur.fetchall()
@ -63,14 +72,16 @@ def main():
# If distance to closest song is ok, just add the song # If distance to closest song is ok, just add the song
if len(cached_distances) > 0: if len(cached_distances) > 0:
if cached_distances[0]["distance"] < _DISTANCE_THRESHOLD: if(cached_distances[0]["distance"] < _DISTANCE_THRESHOLD and
cached_distances[0]["similarity"] > _SIMILARITY_THRESHOLD):
# Push it on the queue # Push it on the queue
subprocess.check_call(["mpc", "add", subprocess.check_call(["mpc", "add",
cached_distances[0]["filename"]]) cached_distances[0]["filename"]])
# Continue using latest pushed song as current song # Continue using latest pushed song as current song
logging.info("Using cached distance. Found %s. Distance is %f." % logging.info("Using cached distance. Found %s. Distance is (%f, %f)." %
(cached_distances[0]["filename"], (cached_distances[0]["filename"],
cached_distances[0]["distance"])) cached_distances[0]["distance"],
cached_distances[0]["similarity"]))
current_song_coords = cached_distances[0] current_song_coords = cached_distances[0]
continue continue
@ -91,41 +102,65 @@ def main():
(current_song_coords["frequency"] - tmp_song_data["frequency"])**2 + (current_song_coords["frequency"] - tmp_song_data["frequency"])**2 +
(current_song_coords["attack"] - tmp_song_data["attack"])**2 (current_song_coords["attack"] - tmp_song_data["attack"])**2
) )
logging.debug("Distance between %s and %s is %f." % similarity = (
(current_song_coords["tempo"] * tmp_song_data["tempo"] +
current_song_coords["amplitude"] * tmp_song_data["amplitude"] +
current_song_coords["frequency"] * tmp_song_data["frequency"] +
current_song_coords["attack"] * tmp_song_data["attack"]) /
(
math.sqrt(
current_song_coords["tempo"]**2 +
current_song_coords["amplitude"]**2 +
current_song_coords["frequency"]**2 +
current_song_coords["attack"]**2) *
math.sqrt(
tmp_song_data["tempo"]**2 +
tmp_song_data["amplitude"]**2 +
tmp_song_data["frequency"]**2 +
tmp_song_data["attack"]**2)
)
)
logging.debug("Distance between %s and %s is (%f, %f)." %
(current_song_coords["filename"], (current_song_coords["filename"],
tmp_song_data["filename"], distance)) tmp_song_data["filename"], distance, similarity))
# Store distance in db cache # Store distance in db cache
try: try:
logging.debug("Storing distance in database.") logging.debug("Storing distance in database.")
conn.execute( conn.execute(
"INSERT INTO distances(song1, song2, distance) VALUES(?, ?, ?)", "INSERT INTO distances(song1, song2, distance, similarity) VALUES(?, ?, ?)",
(current_song_coords["id"], tmp_song_data["id"], distance)) (current_song_coords["id"], tmp_song_data["id"], distance,
similarity))
conn.commit() conn.commit()
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
logging.warning("Unable to insert distance in database.") logging.warning("Unable to insert distance in database.")
conn.rollback() conn.rollback()
# Update the closest song # Update the closest song
if closest_song is None or distance < closest_song[1]: # TODO: Find a better heuristic?
closest_song = (tmp_song_data, distance) if closest_song is None or (distance < closest_song[1] and
similarity > closest_song[2]):
closest_song = (tmp_song_data, distance, similarity)
# If distance is ok, break from the loop # If distance is ok, break from the loop
if distance < _DISTANCE_THRESHOLD: if(distance < _DISTANCE_THRESHOLD and
similarity > _SIMILARITY_THRESHOLD):
break break
# If a close enough song is found # If a close enough song is found
if distance < _DISTANCE_THRESHOLD: if(distance < _DISTANCE_THRESHOLD and
similarity > _SIMILARITY_THRESHOLD):
# Push it on the queue # Push it on the queue
subprocess.check_call(["mpc", "add", tmp_song_data["filename"]]) subprocess.check_call(["mpc", "add", tmp_song_data["filename"]])
# Continue using latest pushed song as current song # Continue using latest pushed song as current song
logging.info("Found a close song: %s. Distance is %f." % logging.info("Found a close song: %s. Distance is (%f, %f)." %
(tmp_song_data["filename"], distance)) (tmp_song_data["filename"], distance, similarity))
current_song_coords = tmp_song_data current_song_coords = tmp_song_data
continue continue
# If no song found, take the closest one # If no song found, take the closest one
else: else:
logging.info("No close enough song found. Using %s. Distance is %f." % logging.info("No close enough song found. Using %s. Distance is (%f, %f)." %
(closest_song[0]["filename"], closest_song[1])) (closest_song[0]["filename"], closest_song[1],
closest_song[2]))
current_song_coords = closest_song[0] current_song_coords = closest_song[0]
subprocess.check_call(["mpc", "add", closest_song[0]["filename"]]) subprocess.check_call(["mpc", "add", closest_song[0]["filename"]])
continue continue

View File

@ -62,6 +62,7 @@ int _init_db(char *data_folder, char* db_path)
song1 INTEGER, \ song1 INTEGER, \
song2 INTEGER, \ song2 INTEGER, \
distance REAL, \ distance REAL, \
similarity REAL, \
FOREIGN KEY(song1) REFERENCES songs(id) ON DELETE CASCADE, \ FOREIGN KEY(song1) REFERENCES songs(id) ON DELETE CASCADE, \
FOREIGN KEY(song2) REFERENCES songs(id) ON DELETE CASCADE, \ FOREIGN KEY(song2) REFERENCES songs(id) ON DELETE CASCADE, \
UNIQUE (song1, song2))", UNIQUE (song1, song2))",