Source code for asyncqlio.backends.postgresql
"""
PostgreSQL backends.
.. currentmodule:: asyncqlio.backends.postgresql
.. autosummary::
:toctree:
asyncpg
"""
# used for namespace packages
from pkgutil import extend_path
import io
import re
from asyncqlio.exc import DatabaseException
from asyncqlio.sentinels import NO_DEFAULT
from asyncqlio.backends.base import BaseDialect
from asyncqlio.orm.schema import column as md_column
from asyncqlio.orm.schema import index as md_index
from asyncqlio.orm.schema import types as md_types
__path__ = extend_path(__path__, __name__)
DEFAULT_CONNECTOR = "asyncpg"
idx_regex = re.compile(
r"CREATE( UNIQUE)? INDEX (\S+) ON (\S+).*\((.+)\)",
flags=re.IGNORECASE,
)
[docs]class PostgresqlDialect(BaseDialect):
"""
The dialect for Postgres.
"""
@property
def has_checkpoints(self):
return True
@property
def has_serial(self):
return True
@property
def lastval_method(self):
return "LASTVAL()"
@property
def has_returns(self):
return True
@property
def has_ilike(self):
return True
@property
def has_default(self):
return True
@property
def has_truncate(self):
return True
@property
def has_cascade(self):
return True
[docs] def get_primary_key_index_name(self, table_name):
return "{}_pkey".format(table_name)
[docs] def get_unique_column_index_name(self, table_name, column_name):
return "{}_{}_key".format(table_name, column_name)
[docs] def get_column_sql(self, table_name=None, *, emitter):
sql = '''
SELECT columns.*, (
SELECT COUNT(*)
FROM information_schema.table_constraints
AS constraints
JOIN information_schema.constraint_column_usage
AS usage
ON constraints.constraint_name=usage.constraint_name
WHERE constraints.constraint_type='PRIMARY KEY'
AND constraints.table_name=columns.table_name
AND usage.column_name=columns.column_name) AS primary_key
FROM information_schema.columns
AS columns'''
if table_name:
sql += " WHERE columns.table_name={}".format(emitter("table_name"))
return sql
[docs] def get_index_sql(self, table_name=None, *, emitter):
sql = "SELECT * FROM pg_indexes"
if table_name:
sql += (" WHERE tablename={}"
.format(emitter("table_name")))
return sql
[docs] def get_upsert_sql(self, table_name, *, on_conflict_update=True):
sql = io.StringIO()
params = {"insert", "col", "returning"}
sql.write("INSERT INTO ")
sql.write(table_name)
sql.write(" {insert} ON CONFLICT ({col}) DO ")
if on_conflict_update:
params.add("update")
sql.write("UPDATE SET {update} ")
else:
sql.write("NOTHING ")
sql.write("RETURNING {returning};")
return sql.getvalue(), params
def transform_rows_to_columns(self, *rows, table_name=None):
for row in rows:
table_name = row['table_name']
column_name = row['column_name']
primary_key = bool(row['primary_key'])
nullable = row["is_nullable"]
default = row["column_default"] or NO_DEFAULT
psql_type = row["data_type"]
if psql_type == "integer":
real_type = md_types.Integer
elif psql_type == "text":
real_type = md_types.Text
elif psql_type == "character varying":
real_type = md_types.String
elif psql_type == "smallint":
real_type = md_types.SmallInt
elif psql_type == "bigint":
real_type = md_types.BigInt
elif psql_type == "boolean":
real_type = md_types.Boolean
elif psql_type == "real":
real_type = md_types.Real
elif psql_type.startswith("timestamp"):
real_type = md_types.Timestamp
elif psql_type == "decimal":
real_type = md_types.Numeric
elif psql_type == "numeric":
real_type = md_types.Numeric
else:
raise DatabaseException("Cannot parse type {}".format(psql_type))
yield md_column.Column.with_name(
name=column_name,
type_=real_type(),
table=table_name,
nullable=nullable,
default=default,
primary_key=primary_key,
)
[docs] def transform_rows_to_indexes(self, *rows, table_name=None):
for row in rows:
groups = idx_regex.match(row["indexdef"]).groups()
unique, name, table, columns = groups
columns = columns.split(', ')
index = md_index.Index.with_name(name, *columns, table=table, unique=unique)
yield index