import logging
import warnings
import threading

from sqlalchemy.sql import and_, expression
from sqlalchemy.sql.expression import ClauseElement
from sqlalchemy.schema import Column, Index
from sqlalchemy import func, select, false
from sqlalchemy.schema import Table as SQLATable
from sqlalchemy.exc import NoSuchTableError

from dataset.types import Types
from dataset.util import normalize_column_name, index_name, ensure_tuple
from dataset.util import DatasetException, ResultIter, QUERY_STEP
from dataset.util import normalize_table_name, pad_chunk_columns

# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)

class Table(object):
    """Represents a table in a database and exposes common operations."""

    PRIMARY_DEFAULT = 'id'

    def __init__(self, database, table_name, primary_id=None,
                 primary_type=None, auto_create=False):
        """Initialise the table from database schema.

        :param database: the owning database object; this class uses its
            ``lock``, ``metadata``, ``schema``, ``executable``, ``engine``,
            ``op``, ``inspect``, ``types``, ``row_type``, ``ensure_schema``
            and ``in_transaction`` attributes as well as ``query()``.
        :param table_name: table name; stored in normalized form.
        :param primary_id: name of the primary key column used when the
            table has to be created. ``False`` creates no primary key.
        :param primary_type: column type of the primary key; defaults to
            ``Types.integer`` at creation time.
        :param auto_create: if true, a missing table is created lazily
            instead of raising :py:class:`DatasetException`.
        """
        self.db = database
        self.name = normalize_table_name(table_name)
        self._table = None
        # Memo of column-name sets that are known to be covered by an index.
        self._indexes = []
        self._primary_id = primary_id
        self._primary_type = primary_type
        self._auto_create = auto_create

    @property
    def exists(self):
        """Check to see if the table currently exists in the database."""
        if self._table is not None:
            return True
        return self.name in self.db

    @property
    def table(self):
        """Get a reference to the table, which may be reflected or created."""
        if self._table is None:
            self._sync_table(())
        return self._table

    @property
    def columns(self):
        """Get a listing of all columns that exist in the table."""
        if not self.exists:
            return []
        return self.table.columns.keys()

    def has_column(self, column):
        """Check if a column with the given name exists on this table."""
        return normalize_column_name(column) in self.columns

    def insert(self, row, ensure=None, types=None):
        """Add a ``row`` dict by inserting it into the table.

        If ``ensure`` is set, any of the keys of the row are not
        table columns, they will be created automatically.

        During column creation, ``types`` will be checked for a key
        matching the name of a column to be created, and the given
        SQLAlchemy column type will be used. Otherwise, the type is
        guessed from the row value, defaulting to a simple unicode
        field.
        ::

            data = dict(title='I am a banana!')
            table.insert(data)

        Returns the inserted row's primary key, or ``True`` if the
        database reported no primary key for the insert.
        """
        row = self._sync_columns(row, ensure, types=types)
        res = self.db.executable.execute(self.table.insert(row))
        if len(res.inserted_primary_key) > 0:
            return res.inserted_primary_key[0]
        return True

    def insert_ignore(self, row, keys, ensure=None, types=None):
        """Add a ``row`` dict into the table if the row does not exist.

        If a row with matching ``keys`` already exists, nothing is
        inserted and ``False`` is returned. Otherwise the row is inserted
        and the result of :py:meth:`insert() <dataset.Table.insert>` is
        returned.

        Setting ``ensure`` results in automatically creating missing columns,
        i.e., keys of the row are not table columns.

        During column creation, ``types`` will be checked for a key
        matching the name of a column to be created, and the given
        SQLAlchemy column type will be used. Otherwise, the type is
        guessed from the row value, defaulting to a simple unicode
        field.
        ::

            data = dict(id=10, title='I am a banana!')
            table.insert_ignore(data, ['id'])
        """
        row = self._sync_columns(row, ensure, types=types)
        if self._check_ensure(ensure):
            self.create_index(keys)
        args, _ = self._keys_to_args(row, keys)
        if self.count(**args) == 0:
            return self.insert(row, ensure=False)
        return False

    def insert_many(self, rows, chunk_size=1000, ensure=None, types=None):
        """Add many rows at a time.

        This is significantly faster than adding them one by one. Per default
        the rows are processed in chunks of 1000 per commit, unless you specify
        a different ``chunk_size``.

        See :py:meth:`insert() <dataset.Table.insert>` for details on
        the other parameters.
        ::

            rows = [dict(name='Dolly')] * 10000
            table.insert_many(rows)
        """
        chunk = []
        for row in rows:
            chunk.append(self._sync_columns(row, ensure, types=types))
            if len(chunk) == chunk_size:
                self._insert_chunk(chunk)
                chunk = []
        # Flush whatever is left over after the last full chunk.
        if chunk:
            self._insert_chunk(chunk)

    def _insert_chunk(self, chunk):
        """Pad ``chunk`` to a common set of columns and insert it at once."""
        chunk = pad_chunk_columns(chunk)
        # Go through the database's executable (like every other write in
        # this class) rather than the statement's implicit bound engine.
        self.db.executable.execute(self.table.insert(), chunk)

    def update(self, row, keys, ensure=None, types=None, return_count=False):
        """Update a row in the table.

        The update is managed via the set of column names stated in ``keys``:
        they will be used as filters for the data to be updated, using the
        values in ``row``.
        ::

            # update all entries with id matching 10, setting their title
            # columns
            data = dict(id=10, title='I am a banana!')
            table.update(data, ['id'])

        If keys in ``row`` update columns not present in the table, they will
        be created based on the settings of ``ensure`` and ``types``, matching
        the behavior of :py:meth:`insert() <dataset.Table.insert>`.

        Returns the number of matching rows when ``row`` holds nothing but
        key columns, the driver's row count when the result supports a sane
        row count, a counted number of matching rows when ``return_count``
        is set, and ``None`` otherwise.
        """
        row = self._sync_columns(row, ensure, types=types)
        args, row = self._keys_to_args(row, keys)
        clause = self._args_to_clause(args)
        if not len(row):
            # Nothing left to set once the key columns are removed.
            return self.count(clause)
        stmt = self.table.update(whereclause=clause, values=row)
        rp = self.db.executable.execute(stmt)
        if rp.supports_sane_rowcount():
            return rp.rowcount
        if return_count:
            return self.count(clause)
        return None

    def upsert(self, row, keys, ensure=None, types=None):
        """An UPSERT is a smart combination of insert and update.

        If rows with matching ``keys`` exist they will be updated, otherwise a
        new row is inserted in the table.
        ::

            data = dict(id=10, title='I am a banana!')
            table.upsert(data, ['id'])

        Returns ``True`` if existing rows were updated, otherwise the result
        of :py:meth:`insert() <dataset.Table.insert>`.
        """
        row = self._sync_columns(row, ensure, types=types)
        if self._check_ensure(ensure):
            self.create_index(keys)
        row_count = self.update(row, keys, ensure=False, return_count=True)
        if row_count == 0:
            return self.insert(row, ensure=False)
        return True

    def delete(self, *clauses, **filters):
        """Delete rows from the table.

        Keyword arguments can be used to add column-based filters. The filter
        criterion will always be equality:
        ::

            table.delete(place='Berlin')

        If no arguments are given, all records are deleted.

        Returns ``True`` if at least one row was deleted, ``False`` otherwise
        (including when the table does not exist).
        """
        if not self.exists:
            return False
        clause = self._args_to_clause(filters, clauses=clauses)
        stmt = self.table.delete(whereclause=clause)
        rp = self.db.executable.execute(stmt)
        return rp.rowcount > 0

    def _reflect_table(self):
        """Load the table's definition from the database."""
        with self.db.lock:
            try:
                self._table = SQLATable(self.name,
                                        self.db.metadata,
                                        schema=self.db.schema,
                                        autoload=True)
            except NoSuchTableError:
                # Deliberate: a missing table leaves ``self._table`` as is
                # and the caller decides whether to create it.
                pass

    def _threading_warn(self):
        """Warn about schema changes in a transaction with several threads."""
        if self.db.in_transaction and threading.active_count() > 1:
            warnings.warn("Changing the database schema inside a transaction "
                          "in a multi-threaded environment is likely to lead "
                          "to race conditions and synchronization issues.",
                          RuntimeWarning)

    def _sync_table(self, columns):
        """Lazy load, create or adapt the table structure in the database.

        :param columns: iterable of SQLAlchemy ``Column`` objects that must
            exist on the table afterwards; may be empty.
        :raises DatasetException: if the table does not exist and
            ``auto_create`` was not enabled.
        """
        if self._table is None:
            # Load an existing table from the database.
            self._reflect_table()
        if self._table is None:
            # Create the table with an initial set of columns.
            if not self._auto_create:
                raise DatasetException("Table does not exist: %s" % self.name)
            # Keep the lock scope small because this is run very often.
            with self.db.lock:
                self._threading_warn()
                self._table = SQLATable(self.name,
                                        self.db.metadata,
                                        schema=self.db.schema)
                if self._primary_id is not False:
                    # This can go wrong on DBMS like MySQL and SQLite where
                    # tables cannot have no columns.
                    primary_id = self._primary_id or self.PRIMARY_DEFAULT
                    primary_type = self._primary_type or Types.integer
                    increment = primary_type in [Types.integer, Types.bigint]
                    column = Column(primary_id, primary_type,
                                    primary_key=True,
                                    autoincrement=increment)
                    self._table.append_column(column)
                for column in columns:
                    # The primary key column has already been added above.
                    if column.name != self._primary_id:
                        self._table.append_column(column)
                self._table.create(self.db.executable, checkfirst=True)
        elif len(columns):
            with self.db.lock:
                # Refresh first so that only genuinely missing columns
                # are added.
                self._reflect_table()
                self._threading_warn()
                for column in columns:
                    if not self.has_column(column.name):
                        self.db.op.add_column(self.name, column,
                                              self.db.schema)
                self._reflect_table()

    def _sync_columns(self, row, ensure, types=None):
        """Create missing columns (or the table) prior to writes.

        If automatic schema generation is disabled (``ensure`` is ``False``),
        this will remove any keys from the ``row`` for which there is no
        matching column.

        Returns a new dict keyed by normalized column names.
        """
        # Work on a private copy; names of new columns are appended below.
        columns = list(self.columns)
        ensure = self._check_ensure(ensure)
        types = types or {}
        types = {normalize_column_name(k): v for (k, v) in types.items()}
        out = {}
        sync_columns = []
        for name, value in row.items():
            name = normalize_column_name(name)
            if ensure and name not in columns:
                _type = types.get(name)
                if _type is None:
                    _type = self.db.types.guess(value)
                sync_columns.append(Column(name, _type))
                columns.append(name)
            if name in columns:
                out[name] = value
        self._sync_table(sync_columns)
        return out

    def _check_ensure(self, ensure):
        """Resolve ``ensure``, using the database default for ``None``."""
        if ensure is None:
            return self.db.ensure_schema
        return ensure

    def _args_to_clause(self, args, clauses=()):
        """Translate keyword filters into a single SQL ``AND`` clause.

        :param args: mapping of column name to filter value. A list or
            tuple is matched with ``IN``; a dict is read as
            ``{operator: operand}``, of which only the first entry is
            used; any other value is compared for equality.
        :param clauses: extra, pre-built clauses to combine with the filters.

        Unknown columns, unknown operators and empty operator dicts produce
        a clause that is always false, so the overall filter matches nothing.
        """
        clauses = list(clauses)
        for column, value in args.items():
            if not self.has_column(column):
                clauses.append(false())
                continue
            col = self.table.c[column]
            if isinstance(value, (list, tuple)):
                clauses.append(col.in_(value))
            elif isinstance(value, dict):
                # ``None`` (empty dict) matches no operator below and thus
                # falls through to the always-false clause.
                key = next(iter(value), None)
                if key in ('like',):
                    clauses.append(col.like(value[key]))
                elif key in ('>', 'gt'):
                    clauses.append(col > value[key])
                elif key in ('<', 'lt'):
                    clauses.append(col < value[key])
                elif key in ('>=', 'gte'):
                    clauses.append(col >= value[key])
                elif key in ('<=', 'lte'):
                    clauses.append(col <= value[key])
                elif key in ('!=', '<>', 'not'):
                    clauses.append(col != value[key])
                elif key in ('between', '..'):
                    lower, upper = value[key][0], value[key][1]
                    clauses.append(col.between(lower, upper))
                else:
                    clauses.append(false())
            else:
                clauses.append(col == value)
        return and_(*clauses)

    def _args_to_order_by(self, order_by):
        """Turn column names into a list of ordering expressions.

        A leading ``-`` requests descending order. ``None`` entries and
        names that are not table columns are skipped.
        """
        orderings = []
        for ordering in ensure_tuple(order_by):
            if ordering is None:
                continue
            column = ordering.lstrip('-')
            if column not in self.table.columns:
                continue
            if ordering.startswith('-'):
                orderings.append(self.table.c[column].desc())
            else:
                orderings.append(self.table.c[column].asc())
        return orderings

    def _keys_to_args(self, row, keys):
        """Split ``row`` into key filters and the remaining values.

        Returns ``(args, rest)``: ``args`` maps each normalized key that is
        present in ``row`` to its value, ``rest`` is a copy of ``row``
        without those entries. ``row`` itself is not modified.
        """
        keys = ensure_tuple(keys)
        keys = [normalize_column_name(k) for k in keys]
        row = row.copy()
        args = {k: row.pop(k) for k in keys if k in row}
        return args, row

    def create_column(self, name, type):
        """Create a new column ``name`` of a specified type.
        ::

            table.create_column('created_at', db.types.datetime)

        `type` corresponds to an SQLAlchemy type as described by
        `dataset.db.Types`. If the column already exists, nothing is done.
        """
        name = normalize_column_name(name)
        if self.has_column(name):
            log.debug("Column exists: %s", name)
            return
        self._sync_table((Column(name, type),))

    def create_column_by_example(self, name, value):
        """
        Explicitly create a new column ``name`` with a type that is appropriate
        to store the given example ``value``.  The type is guessed in the same
        way as for the insert method with ``ensure=True``.
        ::

            table.create_column_by_example('length', 4.2)

        If a column of the same name already exists, no action is taken, even
        if it is not of the type we would have created.
        """
        type_ = self.db.types.guess(value)
        self.create_column(name, type_)

    def drop_column(self, name):
        """Drop the column ``name``.
        ::

            table.drop_column('created_at')

        :raises RuntimeError: if the database dialect is SQLite.
        """
        if self.db.engine.dialect.name == 'sqlite':
            raise RuntimeError("SQLite does not support dropping columns.")
        name = normalize_column_name(name)
        with self.db.lock:
            if not self.exists or not self.has_column(name):
                log.debug("Column does not exist: %s", name)
                return

            self._threading_warn()
            self.db.op.drop_column(
                self.name,
                name,
                self.table.schema
            )
            self._reflect_table()

    def drop(self):
        """Drop the table from the database.

        Deletes both the schema and all the contents within it.
        """
        with self.db.lock:
            if self.exists:
                self._threading_warn()
                self.table.drop(self.db.executable, checkfirst=True)
                self._table = None
                # Memoized indexes belonged to the dropped table.
                self._indexes = []

    def has_index(self, columns):
        """Check if an index exists to cover the given ``columns``."""
        if not self.exists:
            return False
        columns = {normalize_column_name(c) for c in columns}
        if columns in self._indexes:
            return True
        for column in columns:
            if not self.has_column(column):
                return False
        indexes = self.db.inspect.get_indexes(self.name,
                                              schema=self.db.schema)
        for index in indexes:
            if columns == set(index.get('column_names', [])):
                # Remember the hit to skip the inspection next time.
                self._indexes.append(columns)
                return True
        return False

    def create_index(self, columns, name=None, **kw):
        """Create an index to speed up queries on a table.

        If no ``name`` is given, one is generated with ``index_name()``
        from the table name and the columns. Nothing is created if any of
        the columns is missing or a matching index already exists.
        ::

            table.create_index(['name', 'country'])

        :raises DatasetException: if the table does not exist yet.
        """
        columns = [normalize_column_name(c) for c in ensure_tuple(columns)]
        with self.db.lock:
            if not self.exists:
                raise DatasetException("Table has not been created yet.")

            for column in columns:
                if not self.has_column(column):
                    return

            if not self.has_index(columns):
                self._threading_warn()
                name = name or index_name(self.name, columns)
                columns = [self.table.c[c] for c in columns]
                idx = Index(name, *columns, **kw)
                idx.create(self.db.executable)

    def find(self, *_clauses, **kwargs):
        """Perform a simple search on the table.

        Simply pass keyword arguments as ``filter``.
        ::

            results = table.find(country='France')
            results = table.find(country='France', year=1980)

        Using ``_limit``::

            # just return the first 10 rows
            results = table.find(country='France', _limit=10)

        You can sort the results by single or multiple columns. Append a minus
        sign to the column name for descending order::

            # sort results by a column 'year'
            results = table.find(country='France', order_by='year')
            # return all rows sorted by multiple columns (descending by year)
            results = table.find(order_by=['country', '-year'])

        Further reserved keyword arguments are ``_offset`` (default ``0``),
        ``_streamed`` (run the query on a separate connection with
        ``stream_results`` enabled) and ``_step`` (handed to the result
        iterator; ``False`` or ``0`` is passed on as ``None``).

        To perform complex queries with advanced filters or to perform
        aggregation, use :py:meth:`db.query() <dataset.Database.query>`
        instead.
        """
        if not self.exists:
            return iter([])

        _limit = kwargs.pop('_limit', None)
        _offset = kwargs.pop('_offset', 0)
        order_by = kwargs.pop('order_by', None)
        _streamed = kwargs.pop('_streamed', False)
        _step = kwargs.pop('_step', QUERY_STEP)
        if _step is False or _step == 0:
            _step = None

        order_by = self._args_to_order_by(order_by)
        # Everything still left in kwargs is treated as a column filter.
        args = self._args_to_clause(kwargs, clauses=_clauses)
        query = self.table.select(whereclause=args,
                                  limit=_limit,
                                  offset=_offset)
        if len(order_by):
            query = query.order_by(*order_by)

        conn = self.db.executable
        if _streamed:
            conn = self.db.engine.connect()
            conn = conn.execution_options(stream_results=True)

        return ResultIter(conn.execute(query),
                          row_type=self.db.row_type,
                          step=_step)

    def find_one(self, *args, **kwargs):
        """Get a single result from the table.

        Works just like :py:meth:`find() <dataset.Table.find>` but returns one
        result, or ``None``.
        ::

            row = table.find_one(country='United States')
        """
        if not self.exists:
            return None

        kwargs['_limit'] = 1
        kwargs['_step'] = None
        resiter = self.find(*args, **kwargs)
        try:
            for row in resiter:
                return row
        finally:
            # Always release the result, whether or not a row was found.
            resiter.close()
        return None

    def count(self, *_clauses, **kwargs):
        """Return the count of results for the given filter set."""
        # NOTE: this does not have support for limit and offset since I can't
        # see how this is useful. Still, there might be compatibility issues
        # with people using these flags. Let's see how it goes.
        if not self.exists:
            return 0

        args = self._args_to_clause(kwargs, clauses=_clauses)
        query = select([func.count()], whereclause=args)
        query = query.select_from(self.table)
        rp = self.db.executable.execute(query)
        return rp.fetchone()[0]

    def __len__(self):
        """Return the number of rows in the table."""
        return self.count()

    def distinct(self, *args, **_filter):
        """Return all the unique (distinct) values for the given ``columns``.
        ::

            # returns only one row per year, ignoring the rest
            table.distinct('year')
            # works with multiple columns, too
            table.distinct('year', 'country')
            # you can also combine this with a filter
            table.distinct('year', country='China')

        Positional arguments that are SQLAlchemy clause elements are used
        as additional filters rather than as columns.

        :raises DatasetException: if a named column does not exist.
        """
        if not self.exists:
            return iter([])

        columns = []
        clauses = []
        for column in args:
            if isinstance(column, ClauseElement):
                clauses.append(column)
            else:
                if not self.has_column(column):
                    raise DatasetException("No such column: %s" % column)
                columns.append(self.table.c[column])

        clause = self._args_to_clause(_filter, clauses=clauses)
        if not len(columns):
            return iter([])

        q = expression.select(columns,
                              distinct=True,
                              whereclause=clause,
                              order_by=[c.asc() for c in columns])
        return self.db.query(q)

    # Legacy methods for running find queries.
    all = find

    def __iter__(self):
        """Return all rows of the table as simple dictionaries.

        Allows for iterating over all rows in the table without explicitly
        calling :py:meth:`find() <dataset.Table.find>`.
        ::

            for row in table:
                print(row)
        """
        return self.find()

    def __repr__(self):
        """Get table representation."""
        # Use the stored name so that repr() never touches the database.
        return '<Table(%s)>' % self.name