"""
Classes for query objects.
"""
import abc
import collections.abc
import io
import itertools
import typing
from asyncqlio.backends.base import BaseResultSet
from asyncqlio.meta import AsyncABC
from asyncqlio.orm import inspection as md_inspection, operators as md_operators, \
session as md_session
from asyncqlio.orm.schema import column as md_column, relationship as md_relationship, \
table as md_table
from asyncqlio.sentinels import NO_VALUE
class BaseQuery(AsyncABC):
"""
A base query object.
"""
def __init__(self, sess: 'md_session.Session'):
"""
:param sess: The :class:`.Session` associated with this query.
"""
self.session = sess
@abc.abstractmethod
def generate_sql(self) -> typing.Tuple[str, typing.Mapping[str, typing.Any]]:
"""
Generates the SQL for this query.
:return: A two-item tuple of the SQL to use and a mapping of params to pass.
"""
@abc.abstractmethod
async def run(self):
"""
Runs this query.
"""
class ResultGenerator(collections.abc.AsyncIterator):
"""
A helper class that will generate new results from a query when iterated over.
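Instances are normally obtained by iterating over a :class:`.SelectQuery` directly.
A minimal sketch, assuming a ``User`` table and a session ``sess``:
.. code-block:: python3
    query = sess.select.from_(User).where(User.xp > 100)
    async for user in query:
        # each iteration yields one mapped User row
        print(user)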
"""
def __init__(self, q: 'SelectQuery'):
"""
:param q: The :class:`.SelectQuery` to use.
"""
self.query = q
self._results = None # type: BaseResultSet
self._result_deque = collections.deque()
async def _fill(self):
# peek at the first stored item
try:
first = self._result_deque[0]
except IndexError:
# no row was stored from last time
last_pkey = None
rows_filled = 0
else:
# there was a row left over from last time, so check for any run-on rows for it
last_pkey = tuple(first[col.alias_name(quoted=False)]
for col in self.query.table.primary_key.columns)
# also, since there was technically one row before, we start at 1 here
rows_filled = 1
while True:
# fetch a new row, making sure it's not None
row = await self._results.fetch_row()
if row is None:
break
# add it to the results
self._result_deque.append(row)
# load the primary key via getting every column through alias name
pkey = tuple(row[col.alias_name(quoted=False)] for col in
self.query.table.primary_key.columns)
# if there were no results before, we just set the primary key and continue
if last_pkey is None:
last_pkey = pkey
rows_filled += 1
continue
else:
# otherwise, check if the primary key matches
if pkey == last_pkey:
# it does, so it's a run-on row that has been joined
rows_filled += 1
continue
else:
# it does not, so it's a new row
break
# return the number of rows filled
return rows_filled
async def __anext__(self):
# ensure we have a BaseResultSet
if self._results is None:
self._results = await self.query.session.cursor(*self.query.generate_sql())
# get the number of rows filled off of the end
filled = await self._fill()
if filled == 0:
raise StopAsyncIteration
rows = [self._result_deque.popleft() for x in range(0, filled)]
if len(rows) == 1:
return self.query.map_columns(rows[0])
return self.query.map_many(*rows)
async def next(self):
try:
return await self.__anext__()
except StopAsyncIteration:
return None
async def flatten(self) -> 'typing.List[md_table.Table]':
"""
Flattens this query into a single list.
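A short sketch, assuming a ``User`` table and a session ``sess``:
.. code-block:: python3
    results = await sess.select.from_(User).all()  # a ResultGenerator
    users = await results.flatten()                # a list of User rows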
"""
l = []
async for result in self:
l.append(result)
return l
class SelectQuery(BaseQuery):
"""
Represents a SELECT query, which fetches data from the database.
This is not normally created by user code directly, but rather as a result of a
:meth:`.Session.select` call.
.. code-block:: python3
sess = db.get_session()
async with sess:
query = sess.select.from_(User)  # query is an instance of SelectQuery
# alternatively, but not recommended
query = sess.select(User)
However, it is possible to create this class manually:
.. code-block:: python3
query = SelectQuery(db.get_session())
query.set_table(User)
query.add_condition(User.id == 2)
user = await query.first()
"""
def __init__(self, session: 'md_session.Session'):
super().__init__(session)
#: The table being queried.
self.table = None
#: A list of conditions to fulfil.
self.conditions = []
#: The limit on the number of rows returned from this query.
self.row_limit = None
#: The offset to start fetching rows from.
self.row_offset = None
#: The column to order by.
self.orderer = None
def __call__(self, table):
return self.from_(table)
# used so you can async iterate over a query directly
def __aiter__(self):
return ResultGenerator(q=self)
def _get_joins_for_table(self, parent: 'md_relationship.Relationship',
table: 'md_table.Table', seen: list = None):
"""
Gets the foreign joins for a table.
"""
if seen is None:
seen = [table]
foreign_tables = []
joins = []
for relationship in table.iter_relationships():
# ignore non-join relationships
if relationship.load_type != "joined":
continue
if relationship.foreign_table in seen:
continue
foreign_table = relationship.foreign_table
foreign_tables.append(foreign_table)
joins.append(relationship._get_join_query(parent))
return foreign_tables, joins
def _recursive_get_table_joins(self, parent: 'md_relationship.Relationship',
table: 'md_table.Table', seen: list = None):
"""
Recursively loads the joins for a table.
:param parent: The parent relationship this table is being loaded from, or None if it was \
loaded directly.
:param table: The table to get joins for.
:param seen: A list of tables that have already been seen and should not be re-joined.
"""
# this scans the tree of relationships
# and determines how to join them properly
if seen is None:
seen = [table]
elif table in seen:
return [], []
foreign_tables, joins = self._get_joins_for_table(parent, table, seen=seen)
for relationship in table.iter_relationships():
if relationship.load_type != "joined":
continue
if relationship.foreign_table in seen:
continue
# get the table joins for the foreign table
f, j = self._recursive_get_table_joins(relationship,
relationship.foreign_table, seen=seen)
seen.append(relationship.foreign_table)
foreign_tables.extend(f)
joins.extend(j)
return foreign_tables, joins
def get_required_join_paths(self):
"""
Gets the required join paths for this query.
"""
# we can just pass None since it's the first in the chain
return self._recursive_get_table_joins(None, self.table, seen=None)
def generate_sql(self) -> typing.Tuple[str, dict]:
"""
Generates the SQL for this query.
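A sketch of typical usage; the exact SQL text and parameter placeholders depend on
the backend's ``emit_param`` implementation:
.. code-block:: python3
    query = sess.select.from_(User).where(User.id == 1)
    sql, params = query.generate_sql()
    # sql    -> a SELECT statement with aliased columns, any joins, and a WHERE clause
    # params -> a mapping of emitted parameter names to their values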
"""
# calculate the column names
foreign_tables, joins = self.get_required_join_paths()
selected_columns = self.table.iter_columns()
column_names = []
for table in itertools.chain([self.table], foreign_tables):
for column in table.iter_columns():
a = column.alias_name(table=table, quoted=True)
column_names.append(r'{} AS {}'.format(column.quoted_fullname_with_table(table), a))
# BEGIN THE GENERATION
fmt = io.StringIO()
fmt.write("SELECT {} FROM {} ".format(", ".join(column_names), self.table.__quoted_name__))
# cleanup after ourselves for a bit
del selected_columns
# format conditions
params = {}
c_sql = []
for condition in self.conditions:
response = condition.generate_sql(self.session.bind.emit_param)
params.update(response.parameters)
c_sql.append(response.sql)
# append joins
fmt.write(" ".join(joins))
# append the fmt with the conditions
# multiple conditions are ANDed together
if c_sql:
fmt.write(" WHERE {}".format(" AND ".join(c_sql)))
if self.orderer is not None:
res = self.orderer.generate_sql(self.session.bind.emit_param)
fmt.write(" ORDER BY {}".format(res.sql))
if self.row_limit is not None:
fmt.write(" LIMIT {}".format(self.row_limit))
if self.row_offset is not None:
fmt.write(" OFFSET {}".format(self.row_offset))
return fmt.getvalue(), params
# "fetch" methods
async def first(self) -> 'md_table.Table':
"""
Gets the first result that matches from this query.
:return: A :class:`.Table` instance representing the first item, or None if no item matched.
"""
gen = await self.session.run_select_query(self)
row = await gen.next()
if row is not None:
return row
async def all(self) -> 'ResultGenerator':
"""
Gets all results that match from this query.
:return: A :class:`.ResultGenerator` that can be iterated over.
"""
return await self.session.run_select_query(self)
# ORM methods
def map_columns(self, results: typing.Mapping[str, typing.Any]) -> 'md_table.Table':
"""
Maps columns in a result row to a :class:`.Table` instance object.
:param results: A single row of results from the query cursor.
:return: A new :class:`.Table` instance that represents the row returned.
"""
# try and map columns to our Table
mapping = {column.alias_name(self.table, quoted=False): column
for column in self.table.iter_columns()}
row_expando = {}
relation_data = {}
for colname in results.keys():
if colname in mapping:
column = mapping[colname]
row_expando[column.name] = results[colname]
else:
relation_data[colname] = results[colname]
# create a new Table
row = self.table._internal_from_row(row_expando, existed=True)
# update values
for column in self.table.iter_columns():
val = row.get_column_value(column, return_default=False)
if val is not NO_VALUE:
row._values[column] = val
# update the existed
md_inspection._set_mangled(row, "existed", True)
# give the row a session
row._session = self.session
# ensure relationships are cascaded
row._update_relationships(relation_data)
return row
def map_many(self, *rows: typing.Mapping[str, typing.Any]):
"""
Maps multiple result records onto a single :class:`.Table` row.
The records are assumed to share the primary key of the main query table; the extra
records are used to fill in joined relationship data on the resulting row.
"""
# this assumes that the rows come in grouped by PK on the left
# right joins are not handled here
# get the first row and construct the first table row using map_columns
# this will also map any extra relationship data there
first_row = rows[0]
tbl_row = self.map_columns(first_row)
# loop over every "extra" row
# and update the relationship data in the table
for runon_row in rows[1:]:
tbl_row._update_relationships(runon_row)
return tbl_row
# Helper methods for natural builder-style queries
def from_(self, tbl) -> 'SelectQuery':
"""
Sets the table this query is selecting from.
:param tbl: The :class:`.Table` object to select.
:return: This query.
"""
self.set_table(tbl)
return self
def where(self, *conditions: 'md_operators.BaseOperator') -> 'SelectQuery':
"""
Adds a WHERE clause to the query. This is a shortcut for :meth:`.SelectQuery.add_condition`.
.. code-block:: python3
sess.select.from_(User).where(User.id == 1)
:param conditions: The conditions to use for this WHERE clause.
:return: This query.
"""
for condition in conditions:
self.add_condition(condition)
return self
def limit(self, row_limit: int) -> 'SelectQuery':
"""
Sets a limit on the number of rows that can be returned from this query.
:param row_limit: The maximum number of rows to return.
:return: This query.
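A builder-style sketch, assuming a ``User`` table:
.. code-block:: python3
    # fetch at most 10 users, skipping the first 20
    query = sess.select.from_(User).limit(10).offset(20)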
"""
self.row_limit = row_limit
return self
def offset(self, offset: int) -> 'SelectQuery':
"""
Sets the offset of rows to start returning results from.
:param offset: The row offset.
:return: This query.
"""
self.row_offset = offset
return self
def order_by(self, *col: 'typing.Union[md_column.Column, md_operators.Sorter]',
sort_order: str = "asc"):
"""
Sets the order by clause for this query.
The argument provided can either be a :class:`.Column`, or a :class:`.Sorter` which is
provided by :meth:`.Column.asc` / :meth:`.Column.desc`. By default, ``asc`` is used when
passing a column.
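A few sketches, assuming a ``User`` table with an ``xp`` column:
.. code-block:: python3
    sess.select.from_(User).order_by(User.xp)                     # ascending by default
    sess.select.from_(User).order_by(User.xp, sort_order="desc")  # explicit direction
    sess.select.from_(User).order_by(User.xp.desc())              # using a Sorter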
"""
if not col:
raise TypeError("Must provide at least one item to order with")
if len(col) == 1 and isinstance(col[0], md_operators.Sorter):
self.orderer = col[0]
else:
if sort_order == "asc":
self.orderer = md_operators.AscSorter(*col)
elif sort_order == "desc":
self.orderer = md_operators.DescSorter(*col)
else:
raise TypeError("Unknown sort order {}".format(sort_order))
return self
# "manual" methods
def set_table(self, tbl) -> 'SelectQuery':
"""
Sets the table to query on.
:param tbl: The :class:`.Table` object to set.
:return: This query.
"""
self.table = tbl
return self
def add_condition(self, condition: 'md_operators.BaseOperator') -> 'SelectQuery':
"""
Adds a condition to the query.
:param condition: The :class:`.BaseOperator` to add.
:return: This query.
"""
self.conditions.append(condition)
return self
class InsertQuery(BaseQuery):
"""
Represents an INSERT query.
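A minimal usage sketch, assuming ``User`` is a :class:`.Table` subclass and ``user``
is an instance of it:
.. code-block:: python3
    query = InsertQuery(sess)
    query.add_row(user)
    inserted = await query.run()  # or simply: await query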
"""
def __init__(self, sess: 'md_session.Session'):
super().__init__(sess)
#: A list of rows to generate the insert statements for.
self.rows_to_insert = []
def __await__(self):
return self.run().__await__()
async def run(self) -> 'typing.List[md_table.Table]':
"""
Runs this query.
:return: A list of the inserted :class:`.Table` rows.
"""
return await self.session.run_insert_query(self)
def rows(self, *rows: 'md_table.Table') -> 'InsertQuery':
"""
Adds a set of rows to the query.
:param rows: The rows to insert.
:return: This query.
"""
for row in rows:
self.add_row(row)
return self
def add_row(self, row: 'md_table.Table') -> 'InsertQuery':
"""
Adds a row to this query, allowing it to be executed later.
:param row: The :class:`.Table` instance to use for this query.
:return: This query.
"""
self.rows_to_insert.append(row)
return self
def on_conflict(self, *columns: 'md_column.Column') -> 'UpsertQuery':
"""
Get an :class:`.UpsertQuery` to react upon a conflict.
:param columns: The :class:`.Column` objects upon which to check for a conflict.
"""
return UpsertQuery(self.session, *columns, rows=self.rows_to_insert)
def generate_sql(self) -> typing.List[typing.Tuple[str, tuple]]:
"""
Generates the SQL statements for this insert query.
:returns: A list of two-item tuples to execute, each containing the SQL statement and the parameters needed to insert one row.
"""
queries = []
for row in self.rows_to_insert:
query, params = row._get_insert_sql(self.session.bind.emit_param, self.session)
queries.append((query, params))
return queries
class UpsertQuery(InsertQuery):
"""
Represents an UPSERT query.
.. versionadded:: 0.2.0
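This is usually obtained from :meth:`.InsertQuery.on_conflict`. A sketch, assuming a
``User`` table with an ``id`` primary key and a ``name`` column:
.. code-block:: python3
    query = InsertQuery(sess).rows(user)
    # on a conflicting id, update the name instead of failing
    await query.on_conflict(User.id).update(User.name).run()
    # or do nothing on a conflict
    await query.on_conflict(User.id).nothing().run()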
"""
def __init__(self, sess: 'md_session.Session', *columns: 'md_column.Column',
rows: 'typing.Iterable[md_table.Table]'):
"""
:param sess: The :class:`.Session` this query is attached to.
:param columns: The :class:`.Column` objects on which the conflict might happen.
:param rows: The :class:`.Table` objects that are to be added.
"""
super().__init__(sess)
self._on_conflict_update = False
self._conflict_cols = list(columns)
self._update_cols = []
self.rows_to_insert = list(rows)
def on_conflict(self, *columns: 'md_column.Column') -> 'UpsertQuery':
"""
Add more conflict columns to this query.
:param columns: The :class:`.Column` objects upon which to check for a conflict.
"""
self._conflict_cols.extend(columns)
return self
def update(self, *cols: 'md_column.Column') -> 'UpsertQuery':
"""
Used to specify which :class:`.Column` objects to update on a conflict.
:param cols: The :class:`.Column` objects to update.
"""
self._on_conflict_update = True
self._update_cols.extend(cols)
return self
def nothing(self) -> 'UpsertQuery':
"""
Specify that this query should do nothing if there's a conflict.
This is the default behavior.
"""
self._on_conflict_update = False
self._update_cols = []  # just in case
return self
def generate_sql(self) -> typing.List[typing.Tuple[str, tuple]]:
"""
Generates the SQL statements for this upsert query.
:returns: A list of two-item tuples:
- The SQL query to use
- The params to use with the query
"""
queries = []
counter = itertools.count()
def emit():
return "param_{}".format(next(counter))
for row in self.rows_to_insert:
query, params = row._get_upsert_sql(
emit,
self.session,
update_columns=self._update_cols,
on_conflict_columns=self._conflict_cols,
on_conflict_update=self._on_conflict_update,
)
queries.append((query, params))
return queries
class BulkQuery(BaseQuery, metaclass=abc.ABCMeta):
"""
Represents a **bulk query**.
This allows adding conditions to the query.
"""
def __init__(self, sess: 'md_session.Session'):
super().__init__(sess)
#: The table to update.
self._table = None # type: md_table.TableMeta
#: The list of conditions to query by.
self.conditions = []
def __call__(self, *args, **kwargs):
return self.table(*args, **kwargs)
def __await__(self):
return self.run().__await__()
# Builder methods
def table(self, table: 'typing.Type[md_table.Table]'):
"""
Sets the table for this query.
"""
self._table = table
return self
def where(self, *conditions: 'md_operators.ComparisonOp'):
"""
Sets the conditions for this query.
"""
self.conditions.extend(conditions)
return self
# Manual-style methods
def set_table(self, table: 'typing.Type[md_table.Table]'):
"""
Sets a table on this query.
"""
self._table = table
def add_condition(self, condition: 'md_operators.BaseOperator'):
"""
Adds a condition to this query.
"""
self.conditions.append(condition)
class BulkUpdateQuery(BulkQuery):
"""
Represents a **bulk update query**. This updates many rows based on certain criteria.
.. code-block:: python3
query = BulkUpdateQuery(session)
# style 1: manual
query.set_table(User)
query.add_condition(User.xp < 300)
# add on a value
query.set_update(User.xp + 100)
# or set a value
query.set_update(User.xp.set(300))
await query.run()
# style 2: builder
await query.table(User).where(User.xp < 300).set(User.xp + 100).run()
await query.table(User).where(User.xp < 300).set(User.xp, 300).run()
"""
def __init__(self, sess: 'md_session.Session'):
super().__init__(sess)
#: The setter operator to apply to the updated rows.
self.setting = None
def set(self, setter, value: typing.Any = None):
"""
Sets the value to update a column to in this query. Pass either a setter operator
directly, or a column together with a ``value`` (which is wrapped in a :class:`.ValueSetter`).
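Two sketched forms, assuming a ``User`` table with an ``xp`` column:
.. code-block:: python3
    query.set(User.xp + 100)  # pass a setter operator directly
    query.set(User.xp, 300)   # column and value, wrapped in a ValueSetter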
"""
if value is not None:
setter = md_operators.ValueSetter(setter, value)
self.setting = setter
return self
def set_update(self, update):
"""
Sets the update for this query.
"""
self.setting = update
def generate_sql(self):
"""
Generates the SQL for this query.
"""
# base query is update table
query = io.StringIO()
query.write('UPDATE {} SET '.format(self._table.__quoted_name__))
# define params used in generating sql
params = {}
# get the sql and params from the generate_sql call
response = self.setting.generate_sql(self.session.bind.emit_param)
# update params
params.update(response.parameters)
query.write(response.sql)
# format conditions
c_sql = []
for condition in self.conditions:
# pass the condition offset
res = condition.generate_sql(self.session.bind.emit_param)
params.update(res.parameters)
c_sql.append(res.sql)
if c_sql:
    query.write(' WHERE ')
    query.write(' AND '.join(c_sql))
# all generated
return query.getvalue(), params
async def run(self):
    """
    Runs this query.
    """
    return await self.session.run_update_query(self)
class BulkDeleteQuery(BulkQuery):
"""
Represents a **bulk delete query**. This deletes many rows based on criteria.
.. code-block:: python3
query = BulkDeleteQuery(session)
# style 1: manual
query.set_table(User)
query.add_condition(User.xp < 300)
await query.run()
# style 2: builder
await query.table(User).where(User.xp < 300).run()
await query.table(User).where(User.xp < 300).run()
"""
def generate_sql(self):
    """
    Generates the SQL for this query.
    """
query = io.StringIO()
query.write("DELETE FROM {} ".format(self._table.__quoted_name__))
# define params used in generating sql
params = {}
# format conditions
c_sql = []
for condition in self.conditions:
# pass the condition offset
res = condition.generate_sql(self.session.bind.emit_param)
params.update(res.parameters)
c_sql.append(res.sql)
if c_sql:
    query.write(' WHERE ')
    query.write(' AND '.join(c_sql))
return query.getvalue(), params
async def run(self):
    """
    Runs this query.
    """
    return await self.session.run_delete_query(self)
class RowUpdateQuery(BaseQuery):
"""
Represents a **row update query**. This is **NOT** a bulk update query - it is used for updating
specific rows.
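A minimal sketch, assuming ``user`` is a row previously fetched in this session and
then modified:
.. code-block:: python3
    query = RowUpdateQuery(sess)
    query.add_row(user)
    await query.run()  # or simply: await query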
"""
def __init__(self, sess: 'md_session.Session'):
super().__init__(sess)
#: The list of rows to update.
self.rows_to_update = []
def __await__(self):
return self.run().__await__()
async def run(self):
"""
Executes this query.
"""
return await self.session.run_update_query(self)
def rows(self, *rows: 'md_table.Table') -> 'RowUpdateQuery':
"""
Adds a set of rows to the query.
:param rows: The rows to update.
:return: This query.
"""
for row in rows:
self.add_row(row)
return self
def add_row(self, row: 'md_table.Table') -> 'RowUpdateQuery':
"""
Adds a row to this query, allowing it to be executed later.
:param row: The :class:`.Table` instance to use for this query.
:return: This query.
"""
self.rows_to_update.append(row)
return self
def generate_sql(self) -> typing.List[typing.Tuple[str, tuple]]:
"""
Generates the SQL statements for this row update query.
This will return a list of two-item tuples to execute, each containing the SQL
statement and the parameters needed to update one row.
"""
queries = []
for row in self.rows_to_update:
queries.append(row._get_update_sql(self.session.bind.emit_param, self.session))
return queries
class RowDeleteQuery(BaseQuery):
"""
Represents a row deletion query. This is **NOT** a bulk delete query - it is used for deleting
specific rows.
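A minimal sketch, assuming ``user`` is a row previously fetched in this session:
.. code-block:: python3
    query = RowDeleteQuery(sess)
    query.add_row(user)
    await query.run()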
"""
def __init__(self, sess: 'md_session.Session'):
super().__init__(sess)
#: The list of rows to delete.
self.rows_to_delete = []
def rows(self, *rows: 'md_table.Table') -> 'RowDeleteQuery':
"""
Adds a set of rows to the query.
:param rows: The rows to delete.
:return: This query.
"""
for row in rows:
self.add_row(row)
return self
def add_row(self, row: 'md_table.Table'):
"""
Adds a row to this query.
:param row: The :class:`.Table` instance to delete.
"""
self.rows_to_delete.append(row)
def generate_sql(self) -> typing.List[typing.Tuple[str, tuple]]:
"""
Generates the SQL statements for this row delete query.
This will return a list of two-item tuples to execute, each containing the SQL
statement and the parameters needed to delete one row.
"""
queries = []
for row in self.rows_to_delete:
queries.append(row._get_delete_sql(self.session.bind.emit_param, self.session))
return queries
async def run(self):
"""
Runs this query.
"""
return await self.session.run_delete_query(self)