# coding: utf-8 """ This module contains functions related to the database. """ from __future__ import absolute_import, print_function, unicode_literals import sqlite3 from contextlib import contextmanager from sqlalchemy import event, create_engine from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker import flatisfy.models.flat # noqa: F401 from flatisfy.database.base import BASE @event.listens_for(Engine, "connect") def set_sqlite_pragma(dbapi_connection, _): """ Auto enable foreign keys for SQLite. """ # Play well with other DB backends if isinstance(dbapi_connection, sqlite3.Connection): cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() def init_db(database_uri=None): """ Initialize the database, ensuring tables exist etc. :param database_uri: An URI describing an engine to use. Defaults to in-memory SQLite database. :return: A tuple of an SQLAlchemy session maker and the created engine. """ if database_uri is None: database_uri = "sqlite:///:memory:" engine = create_engine(database_uri) BASE.metadata.create_all(engine, checkfirst=True) Session = sessionmaker(bind=engine) # pylint: disable=invalid-name @contextmanager def get_session(): """ Provide a transactional scope around a series of operations. From [1]. [1]: http://docs.sqlalchemy.org/en/latest/orm/session_basics.html#when-do-i-construct-a-session-when-do-i-commit-it-and-when-do-i-close-it. """ session = Session() try: yield session session.commit() except: session.rollback() raise finally: session.close() return get_session