Source code for asyncqlio.backends.base

"""
The base implementation of a backend. This provides some ABC classes.
"""
import collections
import collections.abc
import typing
from abc import abstractmethod
from collections import OrderedDict
from urllib.parse import ParseResult, parse_qs

from asyncqlio.meta import AsyncABC


class BaseDialect:
    """
    Describes the capabilities of a single SQL dialect.

    Query-building code can consult an instance of this class to decide which SQL features it
    is allowed to emit for the server it is talking to.

    Every ``has_`` property defined here returns ``False``, so a subclass only needs to override
    the ones it supports. Everything else on this class raises :class:`NotImplementedError`
    until it is overridden.
    """

    @property
    def has_checkpoints(self) -> bool:
        """
        Whether transaction checkpoints are usable in this dialect. ``False`` unless overridden.
        """
        return False

    @property
    def has_serial(self) -> bool:
        """
        Whether the SERIAL datatype is usable in this dialect. ``False`` unless overridden.
        """
        return False

    @property
    def has_returns(self) -> bool:
        """
        Whether this dialect has RETURNS. ``False`` unless overridden.
        """
        return False

    @property
    def has_ilike(self) -> bool:
        """
        Whether this dialect has ILIKE. ``False`` unless overridden.
        """
        return False

    @property
    def has_default(self) -> bool:
        """
        Whether this dialect has DEFAULT. ``False`` unless overridden.
        """
        return False

    @property
    def has_truncate(self) -> bool:
        """
        Whether this dialect has TRUNCATE. ``False`` unless overridden.
        """
        return False

    @property
    def has_cascade(self) -> bool:
        """
        Whether this dialect has DROP TABLE ... CASCADE. ``False`` unless overridden.
        """
        return False

    @property
    def lastval_method(self):
        """
        The dialect's "last value" method; in PostgreSQL, for example, this is LASTVAL().

        :raises NotImplementedError: Always, unless a subclass overrides this property.
        """
        raise NotImplementedError()

    def get_primary_key_index_name(self, table_name: str) -> str:
        """
        Get the name this dialect assigns to the primary key index of a table.

        :param table_name: The name of the table.
        :raises NotImplementedError: Always, unless a subclass overrides this method.
        """
        raise NotImplementedError()

    def get_unique_column_index_name(self, table_name: str, column_name: str) -> str:
        """
        Get the name this dialect assigns to the index of a unique column.

        :param table_name: The name of the table the column belongs to.
        :param column_name: The name of the unique column.
        :raises NotImplementedError: Always, unless a subclass overrides this method.
        """
        raise NotImplementedError()

    def get_column_sql(
        self,
        table_name: str = None,
        *,
        emitter: 'typing.Callable[[str], str]'
    ) -> str:
        """
        Get a query that finds information on all columns, optionally limited to one table.

        :param table_name: The name of the table to limit the query to, if any.
        :param emitter: The emitter to use (keyword-only).
        :raises NotImplementedError: Always, unless a subclass overrides this method.
        """
        raise NotImplementedError()

    def get_index_sql(
        self,
        table_name: str = None,
        *,
        emitter: 'typing.Callable[[str], str]'
    ) -> str:
        """
        Get a query that finds information on all indexes, optionally limited to one table.

        :param table_name: The name of the table to limit the query to, if any.
        :param emitter: The emitter to use (keyword-only).
        :raises NotImplementedError: Always, unless a subclass overrides this method.
        """
        raise NotImplementedError()

    def get_upsert_sql(
        self,
        table_name: str,
        *,
        on_conflict_update: bool = True
    ) -> 'typing.Tuple[str, set]':
        """
        Get a formattable query, plus the set of params it requires, for upsert-like behaviour.

        :param table_name: The name of the table to upsert into.
        :param on_conflict_update: If this is to update on conflict (keyword-only).
        :raises NotImplementedError: Always, unless a subclass overrides this method.
        """
        raise NotImplementedError()

    def transform_columns_to_indexes(self, *rows: 'DictRow', table_name: str):
        """
        Transform appropriate database rows to Column objects.

        NOTE(review): the method name mentions indexes while the documented result is Column
        objects - confirm which of the two is intended.

        :param rows: The :class:`.DictRow` objects returned from the database.
        :param table_name: The name of the table being transformed (keyword-only).
        :raises NotImplementedError: Always, unless a subclass overrides this method.
        """
        raise NotImplementedError()

    def transform_rows_to_indexes(self, *rows: 'DictRow'):
        """
        Transform appropriate database rows to Index objects.

        :param rows: The :class:`.DictRow` objects returned from the database.
        :raises NotImplementedError: Always, unless a subclass overrides this method.
        """
        raise NotImplementedError()


class BaseResultSet(collections.abc.AsyncIterator, AsyncABC):
    """
    The base class for a result set.

    This represents the results from a database query, as an async iterable. It can also be used
    as an async context manager, in which case the result set is closed on exit.

    Children classes must implement:

        - :attr:`.BaseResultSet.keys`
        - :meth:`.BaseResultSet.fetch_row`
        - :meth:`.BaseResultSet.fetch_many`
        - :meth:`.BaseResultSet.close`
    """

    # NOTE: the ABC lives in ``collections.abc``; the ``collections.AsyncIterator`` alias was
    # removed in Python 3.10, so the base class above must be spelled with the ``abc`` module.

    @property
    @abstractmethod
    def keys(self) -> typing.Iterable[str]:
        """
        :return: An iterable of keys that this query contained.
        """

    @abstractmethod
    async def fetch_row(self) -> 'DictRow':
        """
        Fetches the **next row** in this query.

        This should return None if the row could not be fetched.
        """

    @abstractmethod
    async def fetch_many(self, n: int) -> 'typing.List[DictRow]':
        """
        Fetches the **next N rows** in this query.

        :param n: The number of rows to fetch.
        """

    @abstractmethod
    async def close(self):
        """
        Closes this result set.
        """

    async def __anext__(self) -> 'DictRow':
        """
        Fetches the next row for ``async for`` iteration.

        :raises StopAsyncIteration: When :meth:`fetch_row` returns a falsy value.
        """
        res = await self.fetch_row()
        # Any falsy result (not only None) is treated as the end of the rows.
        if not res:
            raise StopAsyncIteration

        return res

    async def flatten(self) -> 'typing.List[DictRow]':
        """
        Flattens this ResultSet by iterating it until it is exhausted.

        :return: A list of :class:`.DictRow` objects.
        """
        rows = []
        async for row in self:
            rows.append(row)

        return rows

    async def __aenter__(self):
        """
        Enters the async context; returns this result set unchanged.
        """
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """
        Closes this result set on leaving the async context.

        :return: False, so an exception raised inside the context is never suppressed.
        """
        await self.close()
        return False
class BaseTransaction(AsyncABC):
    """
    The base class for a transaction.

    A transaction here means SQL statements guarded with a BEGIN and a COMMIT/ROLLBACK.
    Instances are usable as async context managers: entering begins the transaction, and leaving
    commits it (or rolls it back if an exception was raised) before closing it.

    Children classes must implement:

        - :meth:`.BaseTransaction.begin`
        - :meth:`.BaseTransaction.rollback`
        - :meth:`.BaseTransaction.commit`
        - :meth:`.BaseTransaction.execute`
        - :meth:`.BaseTransaction.cursor`
        - :meth:`.BaseTransaction.close`

    Additionally, some extra methods can be implemented:

        - :meth:`.BaseTransaction.create_savepoint`
        - :meth:`.BaseTransaction.release_savepoint`

    These are optional, and raise :class:`NotImplementedError` when they are not overridden.

    The constructor takes one parameter: the :class:`.BaseConnector` used to connect to the DB
    server.
    """

    def __init__(self, connector: 'BaseConnector'):
        """
        :param connector: The :class:`.BaseConnector` this transaction belongs to.
        """
        self.connector = connector

    async def __aenter__(self) -> 'BaseTransaction':
        """
        Begins the transaction and returns it.
        """
        await self.begin()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """
        Commits when the context body succeeded, rolls back when it raised, then closes.

        :return: False, so an exception raised inside the context is never suppressed.
        """
        try:
            if exc_type is None:
                await self.commit()
            else:
                await self.rollback()
        finally:
            # close() runs even when commit()/rollback() raise; it is called without has_error.
            await self.close()

        return False

    @abstractmethod
    async def begin(self):
        """
        Begins the transaction, emitting a BEGIN instruction.
        """

    @abstractmethod
    async def rollback(self, checkpoint: str = None):
        """
        Rolls back the transaction.

        :param checkpoint: If provided, the checkpoint to rollback to. Otherwise, the entire \
            transaction will be rolled back.
        """

    @abstractmethod
    async def commit(self):
        """
        Commits the current transaction, emitting a COMMIT instruction.
        """

    @abstractmethod
    async def execute(self, sql: str, params: typing.Union[typing.Mapping, typing.Iterable] = None):
        """
        Executes SQL in the current transaction.

        :param sql: The SQL statement to execute.
        :param params: Any parameters to pass to the query.
        """

    @abstractmethod
    async def close(self, *, has_error: bool = False):
        """
        Called at the end of a transaction to cleanup.

        The connection will be released if there's no error; otherwise it will be closed.

        :param has_error: If the transaction has an error.
        """

    @abstractmethod
    async def cursor(self, sql: str, params: typing.Union[typing.Mapping, typing.Iterable] = None) \
            -> 'BaseResultSet':
        """
        Executes SQL and returns a database cursor for the rows.

        :param sql: The SQL statement to execute.
        :param params: Any parameters to pass to the query.
        :return: The :class:`.BaseResultSet` returned from the query, if applicable.
        """

    def create_savepoint(self, name: str):
        """
        Creates a savepoint in the current transaction.

        .. warning::
            This is not supported in all DB engines. If so, this will raise
            :class:`NotImplementedError`.

        :param name: The name of the savepoint to create.
        """
        raise NotImplementedError()

    def release_savepoint(self, name: str):
        """
        Releases a savepoint in the current transaction.

        :param name: The name of the savepoint to release.
        """
        raise NotImplementedError()


class BaseConnector(AsyncABC):
    """
    The base class for a connector. Every connector class should use this as its parent class.

    Children classes must implement:

        - :meth:`.BaseConnector.connect`
        - :meth:`.BaseConnector.close`
        - :meth:`.BaseConnector.emit_param`
        - :meth:`.BaseConnector.get_transaction`
        - :meth:`.BaseConnector.get_db_server_version`
    """

    def __init__(self, dsn: ParseResult):
        """
        :param dsn: The :class:`urllib.parse.ParseResult` created from parsing a DSN.
        """
        self._parse_result = dsn
        self.dsn = dsn.geturl()

        # The individual connection details, pulled from the parsed DSN.
        self.host = dsn.hostname
        self.port = dsn.port
        self.username = dsn.username
        self.password = dsn.password
        # The database name is the URL path without its first character.
        self.db = dsn.path[1:]

        # parse_qs maps each key to a list of values; only the first value of each is kept.
        query_args = parse_qs(dsn.query)
        self.params = {name: values[0] for name, values in query_args.items()}

    @abstractmethod
    async def connect(self, **kwargs) -> 'BaseConnector':
        """
        Connects the current connector to the database server. This is called automatically by
        the :class:`.DatabaseInterface`.

        :return: The original BaseConnector instance.
        """

    @abstractmethod
    async def close(self):
        """
        Closes the current Connector.
        """

    @abstractmethod
    def get_transaction(self) -> BaseTransaction:
        """
        Gets a new transaction object for this connection.

        :return: A new :class:`~.BaseTransaction` object attached to this connection.
        """

    @abstractmethod
    def emit_param(self, name: str) -> str:
        """
        Emits a parameter that can be used as a substitute during a query.

        :param name: The name of the parameter.
        :return: A string that represents the substitute to be placed in the query.
        """

    @abstractmethod
    async def get_db_server_version(self) -> str:
        """
        Gets the version of the DB server running.
        """


class DictRow(OrderedDict):
    """
    Represents a row returned from a base result set, in dict form.

    Items can be read and written either by key or by integer position.
    """

    def __getitem__(self, item):
        """
        Gets a value by key, or by position when ``item`` is an :class:`int`.

        :raises KeyError: If the key is missing, or the position is out of range.
        """
        if not isinstance(item, int):
            return super().__getitem__(item)

        try:
            return list(self.values())[item]
        except IndexError:
            # Report a bad position the same way a missing key is reported.
            raise KeyError(item)

    def __setitem__(self, key, value, **kwargs):
        """
        Sets a value by key, or by position when ``key`` is an :class:`int`.

        :raises IndexError: If ``key`` is an integer position that is out of range.
        """
        if isinstance(key, int):
            # Translate the position into the real dict key stored at that position.
            key = list(self.keys())[key]

        return super().__setitem__(key, value, **kwargs)