import functools
import logging
import os

from sqlalchemy import Column, DateTime, Integer
from sqlalchemy import create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import ColumnProperty, Query, class_mapper, sessionmaker
from sqlalchemy.sql import func
from wrapt import decorator

from . import types
# Module logger, named after this module.
log = logging.getLogger(name=__name__)
class Query(Query):
    """Project ``Query`` subclass, installed as the session's ``query_cls``.

    It currently adds no behavior to SQLAlchemy's ``Query``. It is presumably
    a placeholder for shared query helpers.
    """
# Module-wide switch for SQLAlchemy's SQL echo logging. Every engine built in
# this module reads it, including the one below.
echo = False
# Initial in-memory SQLite engine. reset_engine() / set_sqlite_db_file() can
# later rebind Session and Base.metadata to a different engine.
engine = create_engine('sqlite://', convert_unicode=True, echo=echo)
# Session factory:
# - autoflush=False: pending objects are not flushed implicitly before queries.
# - expire_on_commit=False: loaded attributes stay readable after commit/close.
Session = sessionmaker(
    autoflush=False,
    expire_on_commit=False,
    bind=engine,
    query_cls=Query,
)
class txn(object):
    """Scope a unit of work to one session: commit on success, roll back on error.

    Used as a context manager, ``with txn() as session:`` opens a session from
    ``session_class``. On leaving the block, the session is committed if the
    block finished normally, or rolled back if it raised. Either way it is
    always closed. Any exception raised in the block propagates unchanged
    (same type, args and traceback).

    Used as a decorator, ``@txn()``, each call of the decorated function runs
    in its own transaction and receives the session as its first positional
    argument.

    :param session_class: Zero-argument callable returning a session object.
        Defaults to the module-level ``Session`` factory.
    """

    def __init__(self, session_class=None):
        self.session_class = session_class or Session

    def __call__(self, func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            # Use a fresh txn per call. Sharing ``self`` would let recursive
            # or nested calls overwrite ``self.session``. The outer call would
            # then commit/close the inner session and leak its own.
            with type(self)(self.session_class) as session:
                return func(session, *args, **kwargs)
        return wrapped

    def __enter__(self):
        self.session = self.session_class()
        return self.session

    def __exit__(self, exc_type, exc_val, exc_tb):
        session = self.session
        try:
            if exc_type is None:
                session.commit()
            else:
                session.rollback()
        finally:
            # Close even if commit/rollback itself raised.
            session.close()
        # A false return tells the ``with`` statement not to suppress the
        # in-flight exception, so it propagates exactly as raised.
        return False
@decorator
def with_txn(function, instance, args, kwargs):
    """Call *function* inside a new ``txn``, passing the session first.

    This is a ``wrapt`` decorator. wrapt supplies ``instance``, and it is
    ignored here. The wrapped callable receives ``(session, *args, **kwargs)``.
    The result is returned after the transaction is committed, or the error
    propagates after rollback.
    """
    transaction = txn()
    with transaction as session:
        result = function(session, *args, **kwargs)
    return result
class Base(object):
    """Mixin with shared columns and session helpers for declarative models.

    Provides an integer primary key ``id``, a ``created_at`` column, and
    upsert/find/query classmethods. Each ``*_no_txn`` method takes an
    explicit ``session`` as its first argument. The attribute of the same name
    without the suffix (``upsert``, ``find_all``, ``find``, ``query``) wraps it
    with ``with_txn``, which opens a session, commits or rolls back, and
    passes the session in for you.
    """

    id = Column(Integer, primary_key=True)
    # The default is the SQL expression now(), not a Python-side value.
    created_at = Column(DateTime, nullable=False, default=func.now())

    def upsert_model(self, id_key='id'):
        """Upsert just this instance in its own transaction.

        :param id_key: Name of the attribute used to match existing rows.
        :returns: The ``{key: model}`` dict returned by ``upsert``.
        """
        return self.upsert([self], id_key=id_key)

    @classmethod
    def columns(cls):
        """Return the mapped ``ColumnProperty`` objects of *cls*."""
        return [prop for prop in class_mapper(cls).iterate_properties
                if isinstance(prop, ColumnProperty)]

    @classmethod
    def upsert_no_txn(cls, session, models, id_key='id'):
        """Insert or update *models* in *session*, matching rows on *id_key*.

        A model whose *id_key* value matches an existing row copies that
        row's ``id`` and is merged into the session. The other models are
        added. Nothing is committed here.

        :param session: Open session to work in.
        :param models: Model instances to upsert.
        :param id_key: Name of the attribute to match on.
        :returns: Dict mapping each (type-coerced) key to a model. For a
            matched key this is the instance loaded from the database; for
            an unmatched key it is the newly added instance.
        """
        ids_to_look_for = [getattr(model, id_key) for model in models]
        found_models = cls.find_all_no_txn(session, ids_to_look_for,
                                           id_key=id_key)
        id_to_model = {getattr(model, id_key): model for model in found_models}
        # Coerce incoming keys to the type of the keys already found, so the
        # dict lookups below compare like with like. With no existing rows
        # there is no sample key, so fall back to int.
        try:
            key_type = type(next(iter(id_to_model)))
        except StopIteration:
            key_type = int
        for model in models:
            model_id = key_type(getattr(model, id_key))
            if model_id in id_to_model:
                # TODO: get the primary key in a better way here.
                model.id = id_to_model[model_id].id
                session.merge(model)
            else:
                id_to_model[model_id] = model
                session.add(model)
        log.info(id_to_model)
        return id_to_model

    @classmethod
    def upsert_one_no_txn(cls, session, model, **kwargs):
        """Upsert a single *model* in *session* and return its mapped model.

        Keyword arguments (e.g. ``id_key``) are forwarded to
        ``upsert_no_txn``.
        """
        return next(iter(cls.upsert_no_txn(session, [model], **kwargs).values()))

    # Transactional variant, called as upsert(models, id_key='id').
    upsert = with_txn(upsert_no_txn)

    @classmethod
    def safe_upsert(cls, *args, **kwargs):
        """Call ``upsert``, retrying once if it raises ``IntegrityError``.

        A second ``IntegrityError`` propagates. The retry is presumably meant
        to absorb a concurrent insert of the same key.
        """
        # TODO(imalison): use a retry decorator here.
        try:
            return cls.upsert(*args, **kwargs)
        except IntegrityError:
            return cls.upsert(*args, **kwargs)

    @classmethod
    def upsert_okc(cls, model, **kwargs):
        """Upsert one *model*, matched on ``okc_id`` by default.

        Keyword arguments are forwarded to ``safe_upsert``. ``id_key``
        defaults to ``'okc_id'``.

        :returns: The model mapped to *model*'s key.
        """
        kwargs.setdefault('id_key', 'okc_id')
        return next(iter(cls.safe_upsert([model], **kwargs).values()))

    @classmethod
    def find_all_no_txn(cls, session, identifiers, id_key='id'):
        """Return all rows of *cls* whose *id_key* is in *identifiers*."""
        return cls.find_query(session, identifiers, id_key=id_key).all()

    find_all = with_txn(find_all_no_txn)

    @classmethod
    def find_no_txn(cls, session, identifier, id_key='id'):
        """Return the row whose *id_key* equals *identifier*.

        Uses ``Query.one()``, so it raises unless exactly one row matches.
        """
        return cls.find_query(session, [identifier], id_key=id_key).one()

    find = with_txn(find_no_txn)

    @classmethod
    def query_no_txn(cls, session, *args, **kwargs):
        """Return all rows of *cls* matching the given filters.

        Positional arguments are SQL filter expressions. Keyword arguments
        are ``filter_by`` equality filters. As a shortcut, a lone ``int``
        or ``str`` positional argument (with no keywords) is a primary-key
        lookup: ``query(5)`` means ``query(cls.id == 5)``.
        """
        if not kwargs and len(args) == 1 and isinstance(args[0], (int, str)):
            args = (cls.id == args[0],)
        return cls.build_query(session, *args, **kwargs).all()

    query = with_txn(query_no_txn)

    @classmethod
    def find_query(cls, session, identifiers, id_key='id'):
        """Build a query for rows of *cls* whose *id_key* is in *identifiers*."""
        return session.query(cls).filter(getattr(cls, id_key).in_(identifiers))

    @classmethod
    def build_query(cls, session, *args, **kwargs):
        """Build ``session.query(cls)`` filtered by *args* and *kwargs*.

        *args* go to ``filter`` and *kwargs* go to ``filter_by``.
        """
        return session.query(cls).filter(*args).filter_by(**kwargs)
# Rebind ``Base`` to a declarative base that uses the mixin above as its
# class, bound to the module-level engine. Models (e.g. OKCBase) subclass it.
Base = declarative_base(bind=engine, cls=Base)
class OKCBase(Base):
    """Abstract base for models carrying a required, unique ``okc_id``.

    ``okc_id`` is stored via ``types.StringBackedInteger``. It is presumably
    the identifier assigned on the OkCupid side; confirm before relying on
    this.
    """

    # Abstract: subclasses inherit the columns, but this class maps no table.
    __abstract__ = True

    okc_id = Column(
        types.StringBackedInteger,
        nullable=False,
        unique=True,
    )
def reset_engine(engine):
    """Rebind the ``Session`` factory and ``Base.metadata`` to *engine*.

    The module-level ``engine`` name is not reassigned.

    :param engine: SQLAlchemy engine to bind.
    :returns: *engine* itself, so calls can be chained.
    """
    Session.configure(bind=engine)
    Base.metadata.bind = engine

    return engine
def set_sqlite_db_file(file_path):
    """Bind sessions and metadata to a SQLite database file at *file_path*.

    Builds an engine for ``sqlite:///<file_path>`` using the module-level
    ``echo`` setting and passes it to ``reset_engine``.

    :returns: The newly created engine.
    """
    url = 'sqlite:///{0}'.format(file_path)
    new_engine = create_engine(url, convert_unicode=True, echo=echo)
    return reset_engine(new_engine)
# Default database: ``okcupyd.db`` in the same directory as this module. It is
# bound as an import-time side effect.
# NOTE: despite its name, this is a filesystem path, not a URI;
# set_sqlite_db_file() adds the ``sqlite:///`` scheme.
database_uri = os.path.join(
    os.path.dirname(__file__),
    'okcupyd.db',
)
set_sqlite_db_file(database_uri)