Use cosine similarity as well in client code
This commit is contained in:
parent
40b7df9140
commit
bdf1f64577
2
bliss
2
bliss
@ -1 +1 @@
|
||||
Subproject commit 39fe24d79f2fa4e253a29e66f5f7c0c33e4b84d5
|
||||
Subproject commit 4781124ff177fa1eb649b9b74993a0ae28b1e501
|
@ -21,7 +21,7 @@ def main():
|
||||
cur = conn.cursor()
|
||||
|
||||
# Get cached distances from db
|
||||
cur.execute("SELECT song1, song2, distance FROM distances")
|
||||
cur.execute("SELECT song1, song2, distance, similarity FROM distances")
|
||||
cached_distances = cur.fetchall()
|
||||
|
||||
# Get all songs
|
||||
@ -47,20 +47,41 @@ def main():
|
||||
(song1["frequency"] - song2["frequency"])**2 +
|
||||
(song1["attack"] - song2["attack"])**2
|
||||
)
|
||||
logging.debug("Distance between %s and %s is %f." %
|
||||
(song1["filename"], song2["filename"], distance))
|
||||
similarity = (
|
||||
(song1["tempo"] * song2["tempo"] +
|
||||
song1["amplitude"] * song2["amplitude"] +
|
||||
song1["frequency"] * song2["frequency"] +
|
||||
song1["attack"] * song2["attack"]) /
|
||||
(
|
||||
math.sqrt(
|
||||
song1["tempo"]**2 +
|
||||
song1["amplitude"]**2 +
|
||||
song1["frequency"]**2 +
|
||||
song1["attack"]**2) *
|
||||
math.sqrt(
|
||||
song2["tempo"]**2 +
|
||||
song2["amplitude"]**2 +
|
||||
song2["frequency"]**2 +
|
||||
song2["attack"]**2)
|
||||
)
|
||||
)
|
||||
|
||||
logging.debug("Distance between %s and %s is (%f, %f)." %
|
||||
(song1["filename"], song2["filename"], distance,
|
||||
similarity))
|
||||
# Store distance in db cache
|
||||
try:
|
||||
logging.debug("Storing distance in database.")
|
||||
conn.execute(
|
||||
"INSERT INTO distances(song1, song2, distance) VALUES(?, ?, ?)",
|
||||
(song1["id"], song2["id"], distance))
|
||||
"INSERT INTO distances(song1, song2, distance, similarity) VALUES(?, ?, ?, ?)",
|
||||
(song1["id"], song2["id"], distance, similarity))
|
||||
conn.commit()
|
||||
# Update cached_distances list
|
||||
cached_distances.append({
|
||||
"song1": song1["id"],
|
||||
"song2": song2["id"],
|
||||
"distance": distance
|
||||
"distance": distance,
|
||||
"similarity": similarity
|
||||
})
|
||||
except sqlite3.IntegrityError:
|
||||
logging.warning("Unable to insert distance in database.")
|
||||
|
@ -8,10 +8,11 @@ import subprocess
|
||||
import sys
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
# TODO: Replace mpc calls by libmpd2?
|
||||
|
||||
_QUEUE_LENGTH = 10
|
||||
# TODO: Use cosine similarity as well
|
||||
_DISTANCE_THRESHOLD = 4.0
|
||||
_SIMILARITY_THRESHOLD = 0.95
|
||||
|
||||
if "XDG_DATA_HOME" in os.environ:
|
||||
_MPDBLISS_DATA_HOME = os.path.expandvars("$XDG_DATA_HOME/mpdbliss")
|
||||
@ -28,6 +29,14 @@ def main():
|
||||
conn.execute('pragma foreign_keys=ON')
|
||||
cur = conn.cursor()
|
||||
|
||||
# Ensure random is not enabled
|
||||
status = subprocess.check_output(["mpc", "status"]).decode("utf-8")
|
||||
random = [x.split(":")[1].strip() == "on"
|
||||
for x in status.split("\n")[-2].split(" ")
|
||||
if x.startswith("random")][0]
|
||||
if random:
|
||||
logging.warning("Random mode is enabled. Are you sure you want it?")
|
||||
|
||||
# Take the last song from current playlist and iterate from it
|
||||
current_song = subprocess.check_output(
|
||||
["mpc", "playlist", '--format', '"%file%"'])
|
||||
@ -54,7 +63,7 @@ def main():
|
||||
mpd_queue.append(current_song_coords["filename"])
|
||||
# Get cached distances from db
|
||||
cur.execute(
|
||||
"SELECT id, filename, distance, tempo, amplitude, frequency, attack FROM (SELECT s2.id AS id, s2.filename AS filename, s2.tempo AS tempo, s2.amplitude AS amplitude, s2.frequency AS frequency, s2.attack AS attack, distances.distance AS distance FROM distances INNER JOIN songs AS s1 ON s1.id=distances.song1 INNER JOIN songs AS s2 on s2.id=distances.song2 WHERE s1.filename=? UNION SELECT s1.id as id, s1.filename AS filename, s1.tempo AS tempo, s1.amplitude AS amplitude, s1.frequency AS frequency, s1.attack AS attack, distances.distance as distance FROM distances INNER JOIN songs AS s1 ON s1.id=distances.song1 INNER JOIN songs AS s2 on s2.id=distances.song2 WHERE s2.filename=?) ORDER BY distance ASC",
|
||||
"SELECT id, filename, distance, similarity, tempo, amplitude, frequency, attack FROM (SELECT s2.id AS id, s2.filename AS filename, s2.tempo AS tempo, s2.amplitude AS amplitude, s2.frequency AS frequency, s2.attack AS attack, distances.distance AS distance, distances.similarity AS similarity FROM distances INNER JOIN songs AS s1 ON s1.id=distances.song1 INNER JOIN songs AS s2 on s2.id=distances.song2 WHERE s1.filename=? UNION SELECT s1.id as id, s1.filename AS filename, s1.tempo AS tempo, s1.amplitude AS amplitude, s1.frequency AS frequency, s1.attack AS attack, distances.distance as distance, distances.similarity AS similarity FROM distances INNER JOIN songs AS s1 ON s1.id=distances.song1 INNER JOIN songs AS s2 on s2.id=distances.song2 WHERE s2.filename=?) ORDER BY distance ASC",
|
||||
(current_song_coords["filename"], current_song_coords["filename"]))
|
||||
cached_distances = [row
|
||||
for row in cur.fetchall()
|
||||
@ -63,14 +72,16 @@ def main():
|
||||
|
||||
# If distance to closest song is ok, just add the song
|
||||
if len(cached_distances) > 0:
|
||||
if cached_distances[0]["distance"] < _DISTANCE_THRESHOLD:
|
||||
if(cached_distances[0]["distance"] < _DISTANCE_THRESHOLD and
|
||||
cached_distances[0]["similarity"] > _SIMILARITY_THRESHOLD):
|
||||
# Push it on the queue
|
||||
subprocess.check_call(["mpc", "add",
|
||||
cached_distances[0]["filename"]])
|
||||
# Continue using latest pushed song as current song
|
||||
logging.info("Using cached distance. Found %s. Distance is %f." %
|
||||
logging.info("Using cached distance. Found %s. Distance is (%f, %f)." %
|
||||
(cached_distances[0]["filename"],
|
||||
cached_distances[0]["distance"]))
|
||||
cached_distances[0]["distance"],
|
||||
cached_distances[0]["similarity"]))
|
||||
current_song_coords = cached_distances[0]
|
||||
continue
|
||||
|
||||
@ -91,41 +102,65 @@ def main():
|
||||
(current_song_coords["frequency"] - tmp_song_data["frequency"])**2 +
|
||||
(current_song_coords["attack"] - tmp_song_data["attack"])**2
|
||||
)
|
||||
logging.debug("Distance between %s and %s is %f." %
|
||||
similarity = (
|
||||
(current_song_coords["tempo"] * tmp_song_data["tempo"] +
|
||||
current_song_coords["amplitude"] * tmp_song_data["amplitude"] +
|
||||
current_song_coords["frequency"] * tmp_song_data["frequency"] +
|
||||
current_song_coords["attack"] * tmp_song_data["attack"]) /
|
||||
(
|
||||
math.sqrt(
|
||||
current_song_coords["tempo"]**2 +
|
||||
current_song_coords["amplitude"]**2 +
|
||||
current_song_coords["frequency"]**2 +
|
||||
current_song_coords["attack"]**2) *
|
||||
math.sqrt(
|
||||
tmp_song_data["tempo"]**2 +
|
||||
tmp_song_data["amplitude"]**2 +
|
||||
tmp_song_data["frequency"]**2 +
|
||||
tmp_song_data["attack"]**2)
|
||||
)
|
||||
)
|
||||
logging.debug("Distance between %s and %s is (%f, %f)." %
|
||||
(current_song_coords["filename"],
|
||||
tmp_song_data["filename"], distance))
|
||||
tmp_song_data["filename"], distance, similarity))
|
||||
# Store distance in db cache
|
||||
try:
|
||||
logging.debug("Storing distance in database.")
|
||||
conn.execute(
|
||||
"INSERT INTO distances(song1, song2, distance) VALUES(?, ?, ?)",
|
||||
(current_song_coords["id"], tmp_song_data["id"], distance))
|
||||
"INSERT INTO distances(song1, song2, distance, similarity) VALUES(?, ?, ?)",
|
||||
(current_song_coords["id"], tmp_song_data["id"], distance,
|
||||
similarity))
|
||||
conn.commit()
|
||||
except sqlite3.IntegrityError:
|
||||
logging.warning("Unable to insert distance in database.")
|
||||
conn.rollback()
|
||||
|
||||
# Update the closest song
|
||||
if closest_song is None or distance < closest_song[1]:
|
||||
closest_song = (tmp_song_data, distance)
|
||||
# TODO: Find a better heuristic?
|
||||
if closest_song is None or (distance < closest_song[1] and
|
||||
similarity > closest_song[2]):
|
||||
closest_song = (tmp_song_data, distance, similarity)
|
||||
|
||||
# If distance is ok, break from the loop
|
||||
if distance < _DISTANCE_THRESHOLD:
|
||||
if(distance < _DISTANCE_THRESHOLD and
|
||||
similarity > _SIMILARITY_THRESHOLD):
|
||||
break
|
||||
|
||||
# If a close enough song is found
|
||||
if distance < _DISTANCE_THRESHOLD:
|
||||
if(distance < _DISTANCE_THRESHOLD and
|
||||
similarity > _SIMILARITY_THRESHOLD):
|
||||
# Push it on the queue
|
||||
subprocess.check_call(["mpc", "add", tmp_song_data["filename"]])
|
||||
# Continue using latest pushed song as current song
|
||||
logging.info("Found a close song: %s. Distance is %f." %
|
||||
(tmp_song_data["filename"], distance))
|
||||
logging.info("Found a close song: %s. Distance is (%f, %f)." %
|
||||
(tmp_song_data["filename"], distance, similarity))
|
||||
current_song_coords = tmp_song_data
|
||||
continue
|
||||
# If no song found, take the closest one
|
||||
else:
|
||||
logging.info("No close enough song found. Using %s. Distance is %f." %
|
||||
(closest_song[0]["filename"], closest_song[1]))
|
||||
logging.info("No close enough song found. Using %s. Distance is (%f, %f)." %
|
||||
(closest_song[0]["filename"], closest_song[1],
|
||||
closest_song[2]))
|
||||
current_song_coords = closest_song[0]
|
||||
subprocess.check_call(["mpc", "add", closest_song[0]["filename"]])
|
||||
continue
|
||||
|
@ -62,6 +62,7 @@ int _init_db(char *data_folder, char* db_path)
|
||||
song1 INTEGER, \
|
||||
song2 INTEGER, \
|
||||
distance REAL, \
|
||||
similarity REAL, \
|
||||
FOREIGN KEY(song1) REFERENCES songs(id) ON DELETE CASCADE, \
|
||||
FOREIGN KEY(song2) REFERENCES songs(id) ON DELETE CASCADE, \
|
||||
UNIQUE (song1, song2))",
|
||||
|
Loading…
x
Reference in New Issue
Block a user