SQLAlchemy: Prevent implicit cross join

Posted by Moser on 02 Jan 2020

I really like to create SQL queries using SQLAlchemy’s explicit and declarative API. When I use it instead of raw strings, roughly half of the errors I tend to introduce are caught before the query is even sent.

Here is an example:

import sqlalchemy as sa

metadata = sa.MetaData()
a = sa.Table(
    "a",
    metadata,
    sa.Column("a_id", sa.Integer, primary_key=True, autoincrement=True),
    sa.Column("name", sa.String),
)
b = sa.Table(
    "b",
    metadata,
    sa.Column("b_id", sa.Integer, primary_key=True, autoincrement=True),
    sa.Column("a_id", sa.Integer, sa.ForeignKey(a.c.a_id)),
)

def create_select(additional_filters):
    return sa.select([a], whereclause=sa.and_(*additional_filters))

print(create_select([a.c.name == 'Foo']))
# SELECT a.a_id, a.name
# FROM a
# WHERE a.name = :name_1

A huge advantage over dealing with string queries is that you can create different parts of the query on their own and combine them. In the example, I create the SELECT statement in a central place and let callers pass in parts of the WHERE clause.
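The builder above is written against the SQLAlchemy 1.3 API of the time; the `select([a], whereclause=...)` form is deprecated in 1.4 and no longer accepted in 2.0. A minimal sketch of the same composable builder in the 1.4+/2.0 style, assuming SQLAlchemy 1.4 or later is installed (the name `create_select_v2` is hypothetical):

```python
import sqlalchemy as sa

metadata = sa.MetaData()
a = sa.Table(
    "a",
    metadata,
    sa.Column("a_id", sa.Integer, primary_key=True, autoincrement=True),
    sa.Column("name", sa.String),
)
b = sa.Table(
    "b",
    metadata,
    sa.Column("b_id", sa.Integer, primary_key=True, autoincrement=True),
    sa.Column("a_id", sa.Integer, sa.ForeignKey(a.c.a_id)),
)


def create_select_v2(additional_filters):
    # Select.where() accepts several criteria and ANDs them together, so the
    # filters can be built elsewhere and passed in as a list. An empty list
    # simply produces no WHERE clause.
    return sa.select(a).where(*additional_filters)


stmt = create_select_v2([a.c.name == "Foo"])
print(stmt)
# SELECT a.a_id, a.name
# FROM a
# WHERE a.name = :name_1
```

Passing the filters straight to `where()` also avoids calling `sa.and_()` with no arguments, which SQLAlchemy 1.4 deprecates.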

Unfortunately, this pattern has a dangerous property: when you pass a filter expression that references a column from a table that is not already part of the SELECT statement, SQLAlchemy silently adds that table to the FROM clause, which amounts to an implicit CROSS JOIN.

print(create_select([a.c.name == 'Foo', b.c.b_id == 1]))
# SELECT a.a_id, a.name
# FROM a, b
# WHERE a.name = :name_1 AND b.b_id = :b_id_1

On small tables this just produces an unexpected result set. But a cross join pairs every row of one table with every row of the other, so when the involved tables are large, the query might well exhaust the DB server’s resources. The concrete problem can be fixed by explicitly joining all the tables that should be allowed in the filters:

def create_select_corrected(additional_filters):
    return sa.select(
        [a], from_obj=a.join(b), whereclause=sa.and_(*additional_filters)
    )

print(create_select_corrected([a.c.name == 'Foo', b.c.b_id == 1]))
# SELECT a.a_id, a.name
# FROM a JOIN b ON a.a_id = b.a_id
# WHERE a.name = :name_1 AND b.b_id = :b_id_1
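To make the “unexpected result set” concrete, the following sketch runs the two SQL statements printed above against Python’s built-in sqlite3 module, without SQLAlchemy. The sample rows are made up, and the ORDER BY clauses are added only so the output is deterministic. Because the cross join pairs every row of a with every row of b, the filter on b no longer restricts which rows of a come back:

```python
import sqlite3

# Hypothetical sample data: two "Foo" rows in a, each referenced by its own row in b.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE a (a_id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE b (b_id INTEGER PRIMARY KEY, a_id INTEGER REFERENCES a (a_id));
    INSERT INTO a VALUES (1, 'Foo'), (2, 'Foo');
    INSERT INTO b VALUES (1, 1), (2, 2);
    """
)

# Shape of the statement generated by create_select(): a and b side by side.
cross_join = """
    SELECT a.a_id, a.name FROM a, b
    WHERE a.name = ? AND b.b_id = ?
    ORDER BY a.a_id
"""
# Shape of the statement generated by create_select_corrected(): explicit JOIN.
explicit_join = """
    SELECT a.a_id, a.name FROM a JOIN b ON a.a_id = b.a_id
    WHERE a.name = ? AND b.b_id = ?
    ORDER BY a.a_id
"""

cross_rows = conn.execute(cross_join, ("Foo", 1)).fetchall()
join_rows = conn.execute(explicit_join, ("Foo", 1)).fetchall()
print(cross_rows)  # [(1, 'Foo'), (2, 'Foo')]: b row 1 is paired with every a row
print(join_rows)   # [(1, 'Foo')]: only the a row that b row 1 points to

# Without any filter, the cross join yields len(a) * len(b) rows.
(pairs,) = conn.execute("SELECT COUNT(*) FROM a, b").fetchone()
print(pairs)  # 4
```

With two tables of a million rows each, the same unfiltered product would be 10^12 rows, which is where the resource problem comes from.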

The more general problem is that we can build queries that return unexpected results without noticing. A good but fairly expensive solution would be to validate the parts that are passed in, and depending on how variable the input can be, I would go for it. In other cases, though, it’s just me using my own query construction logic in new ways, so I prefer a cheaper solution that makes me aware of the problem.
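For comparison, the validating approach could look roughly like the sketch below: each filter is checked against a whitelist of tables that the statement actually joins. It assumes SQLAlchemy 1.4 or later and relies on `find_tables()` from `sqlalchemy.sql.util`, an internal helper that is not part of the documented public API and may change between releases. The table `c`, `ALLOWED_TABLES`, `validate_filters()` and `create_select_validated()` are hypothetical names for illustration:

```python
import sqlalchemy as sa
from sqlalchemy.sql.util import find_tables  # internal helper, not public API

metadata = sa.MetaData()
a = sa.Table("a", metadata, sa.Column("a_id", sa.Integer, primary_key=True),
             sa.Column("name", sa.String))
b = sa.Table("b", metadata, sa.Column("b_id", sa.Integer, primary_key=True),
             sa.Column("a_id", sa.Integer, sa.ForeignKey(a.c.a_id)))
# Hypothetical extra table that is NOT part of the join below.
c = sa.Table("c", metadata, sa.Column("c_id", sa.Integer, primary_key=True))

# Hypothetical whitelist: exactly the tables joined in create_select_validated().
ALLOWED_TABLES = {a, b}


def validate_filters(filters, allowed=ALLOWED_TABLES):
    """Raise ValueError if any filter references a table outside `allowed`."""
    for expr in filters:
        # check_columns=True also collects the table of every referenced column.
        unexpected = set(find_tables(expr, check_columns=True)) - allowed
        if unexpected:
            names = ", ".join(sorted(t.name for t in unexpected))
            raise ValueError("Filter {} uses unjoined table(s): {}".format(expr, names))


def create_select_validated(additional_filters):
    validate_filters(additional_filters)
    return sa.select(a).select_from(a.join(b)).where(*additional_filters)


print(create_select_validated([a.c.name == "Foo", b.c.b_id == 1]))
# create_select_validated([c.c.c_id == 1]) raises ValueError
```

The cost is visible even in this small sketch: the whitelist has to be kept in sync with every join the builder adds.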

Thankfully, SQLAlchemy’s events allow us to be notified when a query is about to be executed. We can create a listener that raises an exception when we try to run a problematic query.

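engine = sa.create_engine("sqlite://")  # stand-in: use your application's Engine here

# SQLAlchemy 1.3 listener signature; 1.4+ passes an extra execution_options argument.
def before_execute(conn, clauseelement, multiparams, params):
    # Select.froms lists every FROM element of the top-level SELECT. An explicit
    # join counts as one element, so more than one means tables side by side.
    if (
        isinstance(clauseelement, sa.sql.selectable.Select)
        and len(clauseelement.froms) > 1
    ):
        raise RuntimeError("Cross join detected:\n{}".format(clauseelement))

sa.event.listen(engine, "before_execute", before_execute)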

It’s not perfect, because it only inspects the top-level SELECT and does not check subqueries or CTEs, but it gives us a line of defense. I am also thinking of adding an assertion for this problem to the tests that exercise the query construction logic.

Here is the complete example code: sqlalchemy_implicit_cross_join.py.