Bind Bliss to MPD.

client.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. #!/usr/bin/env python3
  2. """
  3. This is a client for MPD that generates a random playlist, starting from the last
  4. song of the current playlist and iterating over song features computed by Bliss.
  5. MPD connection settings are taken from environment variables, following the MPD_HOST
  6. and MPD_PORT scheme described in the `mpc` man page.
  7. You can pass an integer argument to the script to change the length of the
  8. generated playlist (by default, 20 songs are added).
  9. """
  10. import logging
  11. import math
  12. import os
  13. import random
  14. import sqlite3
  15. import socket
  16. import sys
  17. import mpd
  18. class PersistentMPDClient(mpd.MPDClient):
  19. """
  20. From
  21. https://github.com/schamp/PersistentMPDClient/blob/master/PersistentMPDClient.py
  22. """
  23. def __init__(self, socket=None, host=None, port=None):
  24. super().__init__()
  25. self.socket = socket
  26. self.host = host
  27. self.port = port
  28. self.do_connect()
  29. # get list of available commands from client
  30. self.command_list = self.commands()
  31. # commands not to intercept
  32. self.command_blacklist = ['ping']
  33. # wrap all valid MPDClient functions
  34. # in a ping-connection-retry wrapper
  35. for cmd in self.command_list:
  36. if cmd not in self.command_blacklist:
  37. if hasattr(super(PersistentMPDClient, self), cmd):
  38. super_fun = super(PersistentMPDClient, self).__getattribute__(cmd)
  39. new_fun = self.try_cmd(super_fun)
  40. setattr(self, cmd, new_fun)
  41. # create a wrapper for a function (such as an MPDClient
  42. # member function) that will verify a connection (and
  43. # reconnect if necessary) before executing that function.
  44. # functions wrapped in this way should always succeed
  45. # (if the server is up)
  46. # we ping first because we don't want to retry the same
  47. # function if there's a failure, we want to use the noop
  48. # to check connectivity
  49. def try_cmd(self, cmd_fun):
  50. def fun(*pargs, **kwargs):
  51. try:
  52. self.ping()
  53. except (mpd.ConnectionError, OSError):
  54. self.do_connect()
  55. return cmd_fun(*pargs, **kwargs)
  56. return fun
  57. # needs a name that does not collide with parent connect() function
  58. def do_connect(self):
  59. try:
  60. try:
  61. self.disconnect()
  62. # if it's a TCP connection, we'll get a socket error
  63. # if we try to disconnect when the connection is lost
  64. except mpd.ConnectionError:
  65. pass
  66. # if it's a socket connection, we'll get a BrokenPipeError
  67. # if we try to disconnect when the connection is lost
  68. # but we have to retry the disconnect, because we'll get
  69. # an "Already connected" error if we don't.
  70. # the second one should succeed.
  71. except BrokenPipeError:
  72. try:
  73. self.disconnect()
  74. except:
  75. print("Second disconnect failed, yikes.")
  76. if self.socket:
  77. self.connect(self.socket, None)
  78. else:
  79. self.connect(self.host, self.port)
  80. except socket.error:
  81. print("Connection refused.")
  82. logging.basicConfig(level=logging.INFO)
  83. _QUEUE_LENGTH = 20
  84. _DISTANCE_THRESHOLD = 4.0
  85. _SIMILARITY_THRESHOLD = 0.95
  86. if "XDG_DATA_HOME" in os.environ:
  87. _BLISSIFY_DATA_HOME = os.path.expandvars("$XDG_DATA_HOME/blissify")
  88. else:
  89. _BLISSIFY_DATA_HOME = os.path.expanduser("~/.local/share/blissify")
  90. def main(queue_length):
  91. # Get MPD connection settings
  92. try:
  93. mpd_host = os.environ["MPD_HOST"]
  94. try:
  95. mpd_password, mpd_host = mpd_host.split("@")
  96. except ValueError:
  97. mpd_password = None
  98. except KeyError:
  99. mpd_host = "localhost"
  100. mpd_password = None
  101. try:
  102. mpd_port = os.environ["MPD_PORT"]
  103. except KeyError:
  104. mpd_port = 6600
  105. # Connect to MPD
  106. client = PersistentMPDClient(host=mpd_host, port=mpd_port)
  107. if mpd_password is not None:
  108. client.password(mpd_password)
  109. # Connect to db
  110. db_path = os.path.join(_BLISSIFY_DATA_HOME, "db.sqlite3")
  111. logging.debug("Using DB path: %s." % (db_path,))
  112. conn = sqlite3.connect(db_path)
  113. conn.row_factory = sqlite3.Row
  114. conn.execute('pragma foreign_keys=ON')
  115. cur = conn.cursor()
  116. # Ensure random is not enabled
  117. status = client.status()
  118. if int(status["random"]) != 0:
  119. logging.warning("Random mode is enabled. Are you sure you want it?")
  120. # Take the last song from current playlist and iterate from it
  121. playlist = client.playlist()
  122. if len(playlist) > 0:
  123. current_song = playlist[-1].replace("file:", "").strip()
  124. # If current playlist is empty
  125. else:
  126. # Add a random song to start with
  127. all_songs = [x["file"] for x in client.listall() if "file" in x]
  128. current_song = random.choice(all_songs)
  129. client.add(current_song)
  130. logging.info("Currently played song is %s." % (current_song,))
  131. # Get current song coordinates
  132. cur.execute("SELECT id, tempo1, tempo2, tempo3, amplitude, frequency, attack, filename FROM songs WHERE filename=?", (current_song,))
  133. current_song_coords = cur.fetchone()
  134. if current_song_coords is None:
  135. logging.error("Current song %s is not in db. You should update the db." %
  136. (current_song,))
  137. client.close()
  138. client.disconnect()
  139. sys.exit(1)
  140. for i in range(queue_length):
  141. # Get cached distances from db
  142. cur.execute(
  143. "SELECT id, filename, distance, similarity, tempo1, tempo2, tempo3, amplitude, frequency, attack FROM (SELECT s2.id AS id, s2.filename AS filename, s2.tempo1 AS tempo1, s2.tempo2 AS tempo2, s2.tempo3 AS tempo3, s2.amplitude AS amplitude, s2.frequency AS frequency, s2.attack AS attack, distances.distance AS distance, distances.similarity AS similarity FROM distances INNER JOIN songs AS s1 ON s1.id=distances.song1 INNER JOIN songs AS s2 on s2.id=distances.song2 WHERE s1.filename=? UNION SELECT s1.id as id, s1.filename AS filename, s1.tempo1 AS tempo1, s1.tempo2 AS tempo2, s1.tempo3 AS tempo3, s1.amplitude AS amplitude, s1.frequency AS frequency, s1.attack AS attack, distances.distance as distance, distances.similarity AS similarity FROM distances INNER JOIN songs AS s1 ON s1.id=distances.song1 INNER JOIN songs AS s2 on s2.id=distances.song2 WHERE s2.filename=?) ORDER BY distance ASC",
  144. (current_song_coords["filename"], current_song_coords["filename"]))
  145. cached_distances = [row
  146. for row in cur.fetchall()
  147. if ("file: %s" % (row["filename"],)) not in client.playlist()]
  148. cached_distances_songs = [i["filename"] for i in cached_distances]
  149. # Keep track of closest song
  150. if cached_distances:
  151. closest_song = (cached_distances[0],
  152. cached_distances[0]["distance"],
  153. cached_distances[1]["similarity"])
  154. else:
  155. closest_song = None
  156. # Get the songs close enough
  157. cached_distances_close_enough = [
  158. row for row in cached_distances
  159. if row["distance"] < _DISTANCE_THRESHOLD and row["similarity"] > _SIMILARITY_THRESHOLD ]
  160. if len(cached_distances_close_enough) > 0:
  161. # If there are some close enough songs in the cache
  162. random_close_enough = random.choice(cached_distances_close_enough)
  163. # Push it on the queue
  164. client.add(random_close_enough["filename"])
  165. # Continue using latest pushed song as current song
  166. logging.info("Using cached distance. Found %s. Distance is (%f, %f)." %
  167. (random_close_enough["filename"],
  168. random_close_enough["distance"],
  169. random_close_enough["similarity"]))
  170. current_song_coords = random_close_enough
  171. continue
  172. # Get all other songs coordinates and iterate randomly on them
  173. cur.execute("SELECT id, tempo1, tempo2, tempo3, amplitude, frequency, attack, filename FROM songs ORDER BY RANDOM()")
  174. for tmp_song_data in cur.fetchall():
  175. if(tmp_song_data["filename"] == current_song_coords["filename"] or
  176. tmp_song_data["filename"] in cached_distances_songs or
  177. ("file: %s" % (tmp_song_data["filename"],)) in client.playlist()):
  178. # Skip current song and already processed songs
  179. logging.debug("Skipping %s." % (tmp_song_data["filename"]))
  180. continue
  181. # Compute distance
  182. distance = math.sqrt(
  183. (current_song_coords["tempo1"] - tmp_song_data["tempo1"])**2 +
  184. (current_song_coords["tempo2"] - tmp_song_data["tempo2"])**2 +
  185. (current_song_coords["tempo3"] - tmp_song_data["tempo3"])**2 +
  186. (current_song_coords["amplitude"] - tmp_song_data["amplitude"])**2 +
  187. (current_song_coords["frequency"] - tmp_song_data["frequency"])**2 +
  188. (current_song_coords["attack"] - tmp_song_data["attack"])**2
  189. )
  190. similarity = (
  191. (current_song_coords["tempo1"] * tmp_song_data["tempo1"] +
  192. current_song_coords["tempo2"] * tmp_song_data["tempo2"] +
  193. current_song_coords["tempo3"] * tmp_song_data["tempo3"] +
  194. current_song_coords["amplitude"] * tmp_song_data["amplitude"] +
  195. current_song_coords["frequency"] * tmp_song_data["frequency"] +
  196. current_song_coords["attack"] * tmp_song_data["attack"]) /
  197. (
  198. math.sqrt(
  199. current_song_coords["tempo1"]**2 +
  200. current_song_coords["tempo2"]**2 +
  201. current_song_coords["tempo3"]**2 +
  202. current_song_coords["amplitude"]**2 +
  203. current_song_coords["frequency"]**2 +
  204. current_song_coords["attack"]**2) *
  205. math.sqrt(
  206. tmp_song_data["tempo1"]**2 +
  207. tmp_song_data["tempo2"]**2 +
  208. tmp_song_data["tempo3"]**2 +
  209. tmp_song_data["amplitude"]**2 +
  210. tmp_song_data["frequency"]**2 +
  211. tmp_song_data["attack"]**2)
  212. )
  213. )
  214. logging.debug("Distance between %s and %s is (%f, %f)." %
  215. (current_song_coords["filename"],
  216. tmp_song_data["filename"], distance, similarity))
  217. # Store distance in db cache
  218. try:
  219. logging.debug("Storing distance in database.")
  220. conn.execute(
  221. "INSERT INTO distances(song1, song2, distance, similarity) VALUES(?, ?, ?, ?)",
  222. (current_song_coords["id"], tmp_song_data["id"], distance,
  223. similarity))
  224. conn.commit()
  225. except sqlite3.IntegrityError:
  226. logging.warning("Unable to insert distance in database.")
  227. conn.rollback()
  228. # Update the closest song
  229. if closest_song is None or distance < closest_song[1]:
  230. closest_song = (tmp_song_data, distance, similarity)
  231. # If distance is ok, break from the loop
  232. if(distance < _DISTANCE_THRESHOLD and
  233. similarity > _SIMILARITY_THRESHOLD):
  234. break
  235. # If a close enough song is found
  236. if(distance < _DISTANCE_THRESHOLD and
  237. similarity > _SIMILARITY_THRESHOLD):
  238. # Push it on the queue
  239. client.add(tmp_song_data["filename"])
  240. # Continue using latest pushed song as current song
  241. logging.info("Found a close song: %s. Distance is (%f, %f)." %
  242. (tmp_song_data["filename"], distance, similarity))
  243. current_song_coords = tmp_song_data
  244. continue
  245. # If no song found, take the closest one
  246. else:
  247. logging.info("No close enough song found. Using %s. Distance is (%f, %f)." %
  248. (closest_song[0]["filename"], closest_song[1],
  249. closest_song[2]))
  250. current_song_coords = closest_song[0]
  251. client.add(closest_song[0]["filename"])
  252. continue
  253. conn.close()
  254. client.close()
  255. client.disconnect()
  256. if __name__ == "__main__":
  257. queue_length = _QUEUE_LENGTH
  258. if len(sys.argv) > 1:
  259. try:
  260. queue_length = int(sys.argv[1])
  261. except ValueError:
  262. sys.exit("Usage: %s [PLAYLIST_LENGTH]" % (sys.argv[0],))
  263. main(queue_length)