Limiting SQL query to defined fields/columns in Graphene-SQLAlchemy

This question has been posted as a GitHub issue at https://github.com/graphql-python/graphene-sqlalchemy/issues/134 but I thought I'd post it here too to tap into the SO crowd.

A full working demo can be found under https://github.com/somada141/demo-graphql-sqlalchemy-falcon.

Consider the following SQLAlchemy ORM class:

class Author(Base, OrmBaseMixin):
    __tablename__ = "authors"

    author_id = sqlalchemy.Column(
        sqlalchemy.types.Integer(),
        primary_key=True,
    )

    name_first = sqlalchemy.Column(
        sqlalchemy.types.Unicode(length=80),
        nullable=False,
    )

    name_last = sqlalchemy.Column(
        sqlalchemy.types.Unicode(length=80),
        nullable=False,
    )

which is simply wrapped in an SQLAlchemyObjectType:

class TypeAuthor(SQLAlchemyObjectType):
    class Meta:
        model = Author

and exposed through:

author = graphene.Field(
    TypeAuthor,
    author_id=graphene.Argument(type=graphene.Int, required=False),
    name_first=graphene.Argument(type=graphene.String, required=False),
    name_last=graphene.Argument(type=graphene.String, required=False),
)

@staticmethod
def resolve_author(
    args,
    info,
    author_id: Union[int, None] = None,
    name_first: Union[str, None] = None,
    name_last: Union[str, None] = None,
):
    query = TypeAuthor.get_query(info=info)

    if author_id:
        query = query.filter(Author.author_id == author_id)

    if name_first:
        query = query.filter(Author.name_first == name_first)

    if name_last:
        query = query.filter(Author.name_last == name_last)

    author = query.first()

    return author

A GraphQL query such as:

query GetAuthor{
  author(authorId: 1) {
    nameFirst
  }
}

will cause the following raw SQL to be emitted (taken from the echo log of the SQLA engine):

SELECT authors.author_id AS authors_author_id, authors.name_first AS authors_name_first, authors.name_last AS authors_name_last
FROM authors
WHERE authors.author_id = ?
 LIMIT ? OFFSET ?
2018-05-24 16:23:03,669 INFO sqlalchemy.engine.base.Engine (1, 1, 0)

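As you can see, we only asked for the nameFirst field, i.e. the name_first column, yet the whole row is fetched. Of course, the GraphQL response only contains the requested fields, i.e.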

{
  "data": {
    "author": {
      "nameFirst": "Robert"
    }
  }
}

but we have still fetched the entire row, which becomes a major issue when dealing with wide tables.

Is there a way to automatically communicate to SQLAlchemy which columns are needed, so as to avoid this form of over-fetching?

My question was answered on the GitHub issue (https://github.com/graphql-python/graphene-sqlalchemy/issues/134).

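The idea is to identify the requested fields from the info argument (of type graphql.execution.base.ResolveInfo) that is passed to the resolver function, using a get_field_names function such as the one below: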

from graphql.language.ast import FragmentSpread


def get_field_names(info):
    """
    Parses a query info into a list of composite field names.
    For example the following query:
        {
          carts {
            edges {
              node {
                id
                name
                ...cartInfo
              }
            }
          }
        }
        fragment cartInfo on CartType { whatever }

    Will result in an array:
        [
            'carts',
            'carts.edges',
            'carts.edges.node',
            'carts.edges.node.id',
            'carts.edges.node.name',
            'carts.edges.node.whatever'
        ]
    """

    fragments = info.fragments

    def iterate_field_names(prefix, field):
        name = field.name.value

        if isinstance(field, FragmentSpread):
            _results = []
            new_prefix = prefix
            sub_selection = fragments[field.name.value].selection_set.selections
        else:
            _results = [prefix + name]
            new_prefix = prefix + name + "."
            if field.selection_set:
                sub_selection = field.selection_set.selections
            else:
                sub_selection = []

        for sub_field in sub_selection:
            _results += iterate_field_names(new_prefix, sub_field)

        return _results

    results = iterate_field_names('', info.field_asts[0])

    return results

The above function was taken from https://github.com/graphql-python/graphene/issues/348#issuecomment-267717809. That issue contains other versions of this function but I felt this was the most complete.

and to use the identified fields to limit the columns retrieved by the SQLAlchemy query, as follows:

fields = get_field_names(info=info)
query = TypeAuthor.get_query(info=info).options(load_only(*fields))

When applied to the example query above:

query GetAuthor{
  author(authorId: 1) {
    nameFirst
  }
}

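the get_field_names function returns ['author', 'author.nameFirst']. However, since the 'original' SQLAlchemy ORM fields are snake-cased, the output of get_field_names needs to be post-processed to strip the author prefix and to convert the field names with the graphene.utils.str_converters.to_snake_case function.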
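The post-processing described above could look roughly like the sketch below. The helper name to_column_names and its root parameter are my own and purely illustrative, and the fallback to_snake_case is only there so the snippet runs without graphene installed. It keeps only leaf fields directly under the root, so relationship names and their nested fields are left out.

```python
import re

try:
    from graphene.utils.str_converters import to_snake_case
except ImportError:
    # Stand-in used only when graphene is not installed (assumption: it
    # mirrors graphene's camel-case to snake-case conversion).
    def to_snake_case(name):
        s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()


def to_column_names(field_names, root):
    """Hypothetical helper: turn dotted GraphQL field names into column names.

    Only leaf fields sitting directly under `root` are kept.
    """
    prefix = root + "."
    columns = []
    for field_name in field_names:
        # Skip the root itself and anything outside of it.
        if not field_name.startswith(prefix):
            continue
        remainder = field_name[len(prefix):]
        # Skip fields nested under a relationship, e.g. 'books.title'.
        if "." in remainder:
            continue
        # Skip names that have children, e.g. the 'books' relationship itself.
        if any(other.startswith(field_name + ".") for other in field_names):
            continue
        columns.append(to_snake_case(remainder))
    return columns


print(to_column_names(["author", "author.nameFirst"], root="author"))
# ['name_first']
```

The resulting list is what would be passed on to load_only instead of the raw output of get_field_names.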

Long story short, the above approach yields a raw SQL query like the following:

INFO:sqlalchemy.engine.base.Engine:SELECT authors.author_id AS authors_author_id, authors.name_first AS authors_name_first
FROM authors
WHERE authors.author_id = ?
 LIMIT ? OFFSET ?
2018-06-09 13:22:16,396 INFO sqlalchemy.engine.base.Engine (1, 1, 0)

Update

Should anyone land here wondering about the implementation, this is how I implemented my own version of the get_query_fields function:

from typing import List, Dict, Union, Type

import graphql
from graphql.language.ast import FragmentSpread
from graphql.language.ast import Field
from graphene.utils.str_converters import to_snake_case
import sqlalchemy.orm

from demo.orm_base import OrmBaseMixin

def extract_requested_fields(
    info: graphql.execution.base.ResolveInfo,
    fields: List[Union[Field, FragmentSpread]],
    do_convert_to_snake_case: bool = True,
) -> Dict:
    """Extracts the fields requested in a GraphQL query by processing the AST
    and returns a nested dictionary representing the requested fields.

    Note:
        This function should support arbitrarily nested field structures
        including fragments.

    Example:
        Consider the following query passed to a resolver and running this
        function with the `ResolveInfo` object passed to the resolver.

        >>> query = "query getAuthor{author(authorId: 1){nameFirst, nameLast}}"
        >>> extract_requested_fields(info, info.field_asts, True)
        {'author': {'name_first': None, 'name_last': None}}

    Args:
        info (graphql.execution.base.ResolveInfo): The GraphQL query info passed
            to the resolver function.
        fields (List[Union[Field, FragmentSpread]]): The list of `Field` or
            `FragmentSpread` objects parsed out of the GraphQL query and stored
            in the AST.
        do_convert_to_snake_case (bool): Whether to convert the fields as they
            appear in the GraphQL query (typically in camel-case) back to
            snake-case (which is how they typically appear in ORM classes).

    Returns:
        Dict: The nested dictionary containing all the requested fields.
    """

    result = {}
    for field in fields:

        # Set the `key` as the field name.
        key = field.name.value

        # Convert the key from camel-case to snake-case (if required).
        if do_convert_to_snake_case:
            key = to_snake_case(name=key)

        # Initialize `val` to `None`. Fields without nested-fields under them
        # will have a dictionary value of `None`.
        val = None

        # If the field is of type `Field` then extract the nested fields under
        # the `selection_set` (if defined). These nested fields will be
        # extracted recursively and placed in a dictionary under the field
        # name in the `result` dictionary.
        if isinstance(field, Field):
            if (
                hasattr(field, "selection_set") and
                field.selection_set is not None
            ):
                # Extract field names out of the field selections.
                val = extract_requested_fields(
                    info=info,
                    fields=field.selection_set.selections,
                    do_convert_to_snake_case=do_convert_to_snake_case,
                )
            result[key] = val
        # If the field is of type `FragmentSpread` then retrieve the fragment
        # from `info.fragments` and recursively extract the nested fields but
        # as we don't want the name of the fragment appearing in the result
        # dictionary (since it does not match anything in the ORM classes) the
        # extracted fields are merged directly into the result.
        elif isinstance(field, FragmentSpread):
            # Retrieve referenced fragment.
            fragment = info.fragments[field.name.value]
            # Extract field names out of the fragment selections.
            val = extract_requested_fields(
                info=info,
                fields=fragment.selection_set.selections,
                do_convert_to_snake_case=do_convert_to_snake_case,
            )
            # Merge rather than overwrite so that fields collected before the
            # fragment spread are kept.
            result.update(val)

    return result

This parses the AST into a dict that preserves the structure of the query and (hopefully) matches the structure of the ORM.

Running it on the info object of a query such as:

query getAuthor{
  author(authorId: 1) {
    nameFirst,
    nameLast
  }
}

produces

{'author': {'name_first': None, 'name_last': None}}

while a more complex query like this:

query getAuthor{
  author(nameFirst: "Brandon") {
    ...authorFields
    books {
      ...bookFields
    }
  }
}

fragment authorFields on TypeAuthor {
  nameFirst,
  nameLast
}

fragment bookFields on TypeBook {
  title,
  year
}

produces:

{'author': {'books': {'title': None, 'year': None},
  'name_first': None,
  'name_last': None}}

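These dictionaries can now be used to tell apart the fields of the primary table (Author in this case), which have a value of None (e.g. name_first), from the fields on a relationship of that primary table, which sit in a nested dictionary (e.g. the field title on the books relationship).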
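As a small, plain-Python illustration of that distinction (no SQLAlchemy involved), the sketch below splits the dictionary from the example above into table fields and relationship fields. The name split_requested_fields is hypothetical, and only one level of relationships is handled.

```python
def split_requested_fields(requested):
    """Hypothetical helper: separate table fields from relationship fields.

    `requested` is the dictionary found under the top-level key, e.g. the
    value stored under 'author'.
    """
    table_fields = []
    relationship_fields = {}
    for key, val in requested.items():
        if val is None:
            # A `None` value marks a plain field on the primary table.
            table_fields.append(key)
        elif isinstance(val, dict):
            # A dictionary marks a relationship; keep its plain fields only.
            relationship_fields[key] = [k for k, v in val.items() if v is None]
    return table_fields, relationship_fields


fields = {
    "author": {
        "name_first": None,
        "name_last": None,
        "books": {"title": None, "year": None},
    }
}
table_fields, relationship_fields = split_requested_fields(fields["author"])
print(table_fields)         # ['name_first', 'name_last']
print(relationship_fields)  # {'books': ['title', 'year']}
```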

A simple way of applying these fields automatically can take the form of the following function:

def apply_requested_fields(
    info: graphql.execution.base.ResolveInfo,
    query: sqlalchemy.orm.Query,
    orm_class: Type[OrmBaseMixin]
) -> sqlalchemy.orm.Query:
    """Updates the SQLAlchemy Query object by limiting the loaded fields of the
    table and its relationship to the ones explicitly requested in the GraphQL
    query.

    Note:
        This function is fairly simplistic in that it assumes that (1) the
        SQLAlchemy query only selects a single ORM class/table and that (2)
        relationship fields are only one level deep, i.e., that requested fields
        are either table fields or fields of the table relationship, e.g., it
        does not support fields of relationship relationships.

    Args:
        info (graphql.execution.base.ResolveInfo): The GraphQL query info passed
            to the resolver function.
        query (sqlalchemy.orm.Query): The SQLAlchemy Query object to be updated.
        orm_class (Type[OrmBaseMixin]): The ORM class of the selected table.

    Returns:
        sqlalchemy.orm.Query: The updated SQLAlchemy Query object.
    """

    # Extract the fields requested in the GraphQL query.
    fields = extract_requested_fields(
        info=info,
        fields=info.field_asts,
        do_convert_to_snake_case=True,
    )

    # We assume that the top level of the `fields` dictionary only contains a
    # single key referring to the GraphQL resource being resolved.
    tl_key = list(fields.keys())[0]
    # We assume that any keys that have a value of `None` (as opposed to
    # dictionaries) are fields of the primary table.
    table_fields = [
        key for key, val in fields[tl_key].items()
        if val is None
    ]

    # We assume that any keys that have a value being a dictionary are
    # relationship attributes on the primary table with the keys in the
    # dictionary being fields on that relationship. Thus we create a list of
    # `[relationship_name, relationship_fields]` lists to be used in the
    # `joinedload` definitions.
    relationship_fieldsets = [
        [key, val.keys()]
        for key, val in fields[tl_key].items()
        if isinstance(val, dict)
    ]

    # Assemble a list of `joinedload` definitions on the defined relationship
    # attribute name and the requested fields on that relationship.
    options_joinedloads = []
    for relationship_fieldset in relationship_fieldsets:
        relationship = relationship_fieldset[0]
        rel_fields = relationship_fieldset[1]
        options_joinedloads.append(
            sqlalchemy.orm.joinedload(
                getattr(orm_class, relationship)
            ).load_only(*rel_fields)
        )

    # Update the SQLAlchemy query by limiting the loaded fields on the primary
    # table as well as by including the `joinedload` definitions.
    query = query.options(
        sqlalchemy.orm.load_only(*table_fields),
        *options_joinedloads
    )

    return query
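One caveat with the function above: it passes plain strings to load_only, and any requested name that does not exist on the ORM class (GraphQL's __typename, for instance) would end up in that list too. As far as I know, newer SQLAlchemy releases also favour mapped attributes over strings in loader options. A minimal sketch of a guard, assuming a hypothetical helper called resolve_attributes and using a plain class as a stand-in for a mapped ORM class:

```python
def resolve_attributes(orm_class, names):
    """Hypothetical helper: map field names to attributes of `orm_class`.

    Names that are not defined on the class are silently skipped.
    """
    attributes = []
    for name in names:
        attribute = getattr(orm_class, name, None)
        if attribute is not None:
            attributes.append(attribute)
    return attributes


class FakeAuthor:
    # Stand-in for a mapped class. On a real SQLAlchemy model these would be
    # instrumented column attributes rather than strings.
    name_first = "name_first column"
    name_last = "name_last column"


print(resolve_attributes(FakeAuthor, ["name_first", "__typename"]))
# ['name_first column']
```

With a real model the returned attributes could then be unpacked into load_only in place of the string names.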