flatisfy/flatisfy/database/__init__.py

65 lines
1.8 KiB
Python

# coding: utf-8
"""
This module contains functions related to the database.
"""
from __future__ import absolute_import, print_function, unicode_literals
import sqlite3
from contextlib import contextmanager
from sqlalchemy import event, create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
import flatisfy.models.flat # noqa: F401
from flatisfy.database.base import BASE
@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, _):
    """
    Auto enable foreign keys for SQLite.

    Registered as a listener for the ``connect`` event on ``Engine``, so it
    runs for every new DBAPI connection.

    :param dbapi_connection: The raw DBAPI connection that was just opened.
    :param _: Second argument passed by the event (unused here).
    """
    # Play well with other DB backends: only SQLite connections get the
    # pragma, anything else is left untouched.
    if not isinstance(dbapi_connection, sqlite3.Connection):
        return

    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        # Release the cursor even if the pragma statement fails.
        cursor.close()
def init_db(database_uri=None):
    """
    Initialize the database, ensuring tables exist etc.

    An engine is created for the given URI, the tables registered on
    ``BASE.metadata`` are created when missing, and a session factory bound
    to that engine is set up.

    :param database_uri: An URI describing an engine to use. Defaults to
        in-memory SQLite database.
    :return: The ``get_session`` callable, which builds a context manager
        yielding a session bound to the created engine.
    """
    uri = "sqlite:///:memory:" if database_uri is None else database_uri

    db_engine = create_engine(uri)
    BASE.metadata.create_all(db_engine, checkfirst=True)
    session_factory = sessionmaker(bind=db_engine)

    @contextmanager
    def get_session():
        """
        Provide a transactional scope around a series of operations.

        The session is committed once the ``with`` body completes, rolled
        back if anything is raised (including by the commit itself), and
        closed in every case.

        From [1].

        [1]: http://docs.sqlalchemy.org/en/latest/orm/session_basics.html#when-do-i-construct-a-session-when-do-i-commit-it-and-when-do-i-close-it
        """
        active_session = session_factory()
        try:
            yield active_session
            active_session.commit()
        except BaseException:
            # Catch everything (same reach as a bare ``except``) so the
            # transaction is rolled back before the error propagates.
            active_session.rollback()
            raise
        finally:
            active_session.close()

    return get_session