Bind Bliss to MPD.

client.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. #!/usr/bin/env python3
  2. """
  3. This is a client for MPD to generate a random playlist starting from the last
  4. song of the current playlist and iterating using values computed using Bliss.
  5. MPD connection settings are taken from environment variables, following MPD_HOST
  6. and MPD_PORT scheme described in `mpc` man.
  7. You can pass an integer argument to the script to change the length of the
  8. generated playlist (default is to add 20 songs).
  9. """
  10. import logging
  11. import math
  12. import os
  13. import random
  14. import sqlite3
  15. import socket
  16. import sys
  17. import mpd
  18. # logging.basicConfig(level='DEBUG')
  19. class PersistentMPDClient(mpd.MPDClient):
  20. """
  21. From
  22. https://github.com/schamp/PersistentMPDClient/blob/master/PersistentMPDClient.py
  23. """
  24. def __init__(self, socket=None, host=None, port=None):
  25. super().__init__()
  26. self.socket = socket
  27. self.host = host
  28. self.port = port
  29. self.do_connect()
  30. # get list of available commands from client
  31. self.command_list = self.commands()
  32. # commands not to intercept
  33. self.command_blacklist = ['ping']
  34. # wrap all valid MPDClient functions
  35. # in a ping-connection-retry wrapper
  36. for cmd in self.command_list:
  37. if cmd not in self.command_blacklist:
  38. if hasattr(super(PersistentMPDClient, self), cmd):
  39. super_fun = super(PersistentMPDClient, self).__getattribute__(cmd)
  40. new_fun = self.try_cmd(super_fun)
  41. setattr(self, cmd, new_fun)
  42. # create a wrapper for a function (such as an MPDClient
  43. # member function) that will verify a connection (and
  44. # reconnect if necessary) before executing that function.
  45. # functions wrapped in this way should always succeed
  46. # (if the server is up)
  47. # we ping first because we don't want to retry the same
  48. # function if there's a failure, we want to use the noop
  49. # to check connectivity
  50. def try_cmd(self, cmd_fun):
  51. def fun(*pargs, **kwargs):
  52. try:
  53. self.ping()
  54. except (mpd.ConnectionError, OSError):
  55. self.do_connect()
  56. return cmd_fun(*pargs, **kwargs)
  57. return fun
  58. # needs a name that does not collide with parent connect() function
  59. def do_connect(self):
  60. try:
  61. try:
  62. self.disconnect()
  63. # if it's a TCP connection, we'll get a socket error
  64. # if we try to disconnect when the connection is lost
  65. except mpd.ConnectionError:
  66. pass
  67. # if it's a socket connection, we'll get a BrokenPipeError
  68. # if we try to disconnect when the connection is lost
  69. # but we have to retry the disconnect, because we'll get
  70. # an "Already connected" error if we don't.
  71. # the second one should succeed.
  72. except BrokenPipeError:
  73. try:
  74. self.disconnect()
  75. except:
  76. print("Second disconnect failed, yikes.")
  77. if self.socket:
  78. self.connect(self.socket, None)
  79. else:
  80. self.connect(self.host, self.port)
  81. except socket.error:
  82. print("Connection refused.")
  83. logging.basicConfig(level=logging.INFO)
  84. _QUEUE_LENGTH = 20
  85. _DISTANCE_THRESHOLD = 4.0
  86. _SIMILARITY_THRESHOLD = 0.95
  87. if "XDG_DATA_HOME" in os.environ:
  88. _BLISSIFY_DATA_HOME = os.path.expandvars("$XDG_DATA_HOME/blissify")
  89. else:
  90. _BLISSIFY_DATA_HOME = os.path.expanduser("~/.local/share/blissify")
  91. def main(queue_length):
  92. # Get MPD connection settings
  93. try:
  94. mpd_host = os.environ["MPD_HOST"]
  95. try:
  96. mpd_password, mpd_host = mpd_host.split("@")
  97. except ValueError:
  98. mpd_password = None
  99. except KeyError:
  100. mpd_host = "localhost"
  101. mpd_password = None
  102. try:
  103. mpd_port = os.environ["MPD_PORT"]
  104. except KeyError:
  105. mpd_port = 6600
  106. # Connect to MPD
  107. client = PersistentMPDClient(host=mpd_host, port=mpd_port)
  108. if mpd_password is not None:
  109. client.password(mpd_password)
  110. # Connect to db
  111. db_path = os.path.join(_BLISSIFY_DATA_HOME, "db.sqlite3")
  112. logging.debug("Using DB path: %s." % (db_path,))
  113. conn = sqlite3.connect(db_path)
  114. conn.row_factory = sqlite3.Row
  115. conn.execute('pragma foreign_keys=ON')
  116. cur = conn.cursor()
  117. # Ensure random is not enabled
  118. status = client.status()
  119. if int(status["random"]) != 0:
  120. logging.warning("Random mode is enabled. Are you sure you want it?")
  121. # Take the last song from current playlist and iterate from it
  122. playlist = client.playlist()
  123. if len(playlist) > 0:
  124. current_song = playlist[-1].replace("file:", "").strip()
  125. # If current playlist is empty
  126. else:
  127. # Add a random song to start with
  128. all_songs = [x["file"] for x in client.listall() if "file" in x]
  129. current_song = random.choice(all_songs)
  130. client.add(current_song)
  131. logging.info("Currently played song is %s." % (current_song,))
  132. # Get current song coordinates
  133. cur.execute("SELECT id, tempo, amplitude, frequency, attack, filename FROM songs WHERE filename=?", (current_song,))
  134. current_song_coords = cur.fetchone()
  135. if current_song_coords is None:
  136. logging.error("Current song %s is not in db. You should update the db." %
  137. (current_song,))
  138. client.close()
  139. client.disconnect()
  140. sys.exit(1)
  141. for i in range(queue_length):
  142. # Get cached distances from db
  143. cur.execute(
  144. "SELECT id, filename, distance, similarity, tempo, amplitude, frequency, attack FROM (SELECT s2.id AS id, s2.filename AS filename, s2.tempo AS tempo, s2.amplitude AS amplitude, s2.frequency AS frequency, s2.attack AS attack, distances.distance AS distance, distances.similarity AS similarity FROM distances INNER JOIN songs AS s1 ON s1.id=distances.song1 INNER JOIN songs AS s2 on s2.id=distances.song2 WHERE s1.filename=? UNION SELECT s1.id as id, s1.filename AS filename, s1.tempo AS tempo, s1.amplitude AS amplitude, s1.frequency AS frequency, s1.attack AS attack, distances.distance as distance, distances.similarity AS similarity FROM distances INNER JOIN songs AS s1 ON s1.id=distances.song1 INNER JOIN songs AS s2 on s2.id=distances.song2 WHERE s2.filename=?) ORDER BY distance ASC",
  145. (current_song_coords["filename"], current_song_coords["filename"]))
  146. cached_distances = [row
  147. for row in cur.fetchall()
  148. if ("file: %s" % (row["filename"],)) not in client.playlist()]
  149. cached_distances_songs = [i["filename"] for i in cached_distances]
  150. # Keep track of closest song
  151. if cached_distances:
  152. closest_song = (cached_distances[0],
  153. cached_distances[0]["distance"],
  154. cached_distances[1]["similarity"])
  155. else:
  156. closest_song = None
  157. # Get the songs close enough
  158. cached_distances_close_enough = [
  159. row for row in cached_distances
  160. if row["distance"] < _DISTANCE_THRESHOLD and row["similarity"] > _SIMILARITY_THRESHOLD ]
  161. if len(cached_distances_close_enough) > 0:
  162. # If there are some close enough songs in the cache
  163. random_close_enough = random.choice(cached_distances_close_enough)
  164. # Push it on the queue
  165. client.add(random_close_enough["filename"])
  166. # Continue using latest pushed song as current song
  167. logging.info("Using cached distance. Found %s. Distance is (%f, %f)." %
  168. (random_close_enough["filename"],
  169. random_close_enough["distance"],
  170. random_close_enough["similarity"]))
  171. current_song_coords = random_close_enough
  172. continue
  173. # Get all other songs coordinates and iterate randomly on them
  174. cur.execute("SELECT id, tempo, amplitude, frequency, attack, filename FROM songs ORDER BY RANDOM()")
  175. for tmp_song_data in cur.fetchall():
  176. if(tmp_song_data["filename"] == current_song_coords["filename"] or
  177. tmp_song_data["filename"] in cached_distances_songs or
  178. ("file: %s" % (tmp_song_data["filename"],)) in client.playlist()):
  179. # Skip current song and already processed songs
  180. logging.debug("Skipping %s." % (tmp_song_data["filename"]))
  181. continue
  182. # Compute distance
  183. distance = math.sqrt(
  184. (current_song_coords["tempo"] - tmp_song_data["tempo"])**2 +
  185. (current_song_coords["amplitude"] - tmp_song_data["amplitude"])**2 +
  186. (current_song_coords["frequency"] - tmp_song_data["frequency"])**2 +
  187. (current_song_coords["attack"] - tmp_song_data["attack"])**2
  188. )
  189. similarity = (
  190. (current_song_coords["tempo"] * tmp_song_data["tempo"] +
  191. current_song_coords["amplitude"] * tmp_song_data["amplitude"] +
  192. current_song_coords["frequency"] * tmp_song_data["frequency"] +
  193. current_song_coords["attack"] * tmp_song_data["attack"]) /
  194. (
  195. math.sqrt(
  196. current_song_coords["tempo"]**2 +
  197. current_song_coords["amplitude"]**2 +
  198. current_song_coords["frequency"]**2 +
  199. current_song_coords["attack"]**2) *
  200. math.sqrt(
  201. tmp_song_data["tempo"]**2 +
  202. tmp_song_data["amplitude"]**2 +
  203. tmp_song_data["frequency"]**2 +
  204. tmp_song_data["attack"]**2)
  205. )
  206. )
  207. logging.debug("Distance between %s and %s is (%f, %f)." %
  208. (current_song_coords["filename"],
  209. tmp_song_data["filename"], distance, similarity))
  210. # Store distance in db cache
  211. try:
  212. logging.debug("Storing distance in database.")
  213. conn.execute(
  214. "INSERT INTO distances(song1, song2, distance, similarity) VALUES(?, ?, ?, ?)",
  215. (current_song_coords["id"], tmp_song_data["id"], distance,
  216. similarity))
  217. conn.commit()
  218. except sqlite3.IntegrityError:
  219. logging.warning("Unable to insert distance in database.")
  220. conn.rollback()
  221. # Update the closest song
  222. if closest_song is None or distance < closest_song[1]:
  223. closest_song = (tmp_song_data, distance, similarity)
  224. # If distance is ok, break from the loop
  225. if(distance < _DISTANCE_THRESHOLD and
  226. similarity > _SIMILARITY_THRESHOLD):
  227. break
  228. # If a close enough song is found
  229. if(distance < _DISTANCE_THRESHOLD and
  230. similarity > _SIMILARITY_THRESHOLD):
  231. # Push it on the queue
  232. client.add(tmp_song_data["filename"])
  233. # Continue using latest pushed song as current song
  234. logging.info("Found a close song: %s. Distance is (%f, %f)." %
  235. (tmp_song_data["filename"], distance, similarity))
  236. current_song_coords = tmp_song_data
  237. continue
  238. # If no song found, take the closest one
  239. else:
  240. logging.info("No close enough song found. Using %s. Distance is (%f, %f)." %
  241. (closest_song[0]["filename"], closest_song[1],
  242. closest_song[2]))
  243. current_song_coords = closest_song[0]
  244. client.add(closest_song[0]["filename"])
  245. continue
  246. conn.close()
  247. client.close()
  248. client.disconnect()
  249. if __name__ == "__main__":
  250. queue_length = _QUEUE_LENGTH
  251. if len(sys.argv) > 1:
  252. try:
  253. queue_length = int(sys.argv[1])
  254. except ValueError:
  255. sys.exit("Usage: %s [PLAYLIST_LENGTH]" % (sys.argv[0],))
  256. main(queue_length)