使用 Peewee 和 Marshmallow 序列化多对多关系

Serializing a Many to Many Relationship w/ Peewee and Marshmallow

我有一个 PostgreSQL 数据库,其中包含多对多用户与下表的标记关系:

我正在尝试构建一个简单的 API 来使用 Flask、Peewee 和 Marshmallow 访问此数据库中的数据。我们现在可以忽略 Flask,但我正在尝试为 social_user 创建一个模式,这将允许我在 returns 一个或多个用户的查询中转储带有各自的标签。我正在寻找类似于以下内容的回复:

{
    "id": "[ID]",
    "handle": "[HANDLE]",
    "local_id": "[LOCAL_ID]",
    "platform_slug": "[PLATFORM_SLUG]",
    "tags": [
        {
            "id": "[ID]",
            "title": "[TITLE]",
            "tag_type": "[TAG_TYPE]"
        },
        {
            "id": "[ID]",
            "title": "[TITLE]",
            "tag_type": "[TAG_TYPE]"
        }
    ]
}

我设法通过包含第二个查询来做到这一点,该查询将标签拉入 @post_dump 包装函数中的 social_user 模式中,然而,这感觉像是一个 hack,而且对于大量用户来说似乎会很慢(更新:这非常慢,我在 369 个用户上测试过)。我想我可以用 Marshmallow 的 fields.Nested field type 做点什么。有没有更好的方法只用一个 Peewee 查询来序列化这种关系?我的代码如下:

# just so you are aware of my namespaces
import marshmallow as marsh
import peewee as pw

Peewee 模型

db = postgres_ext.PostgresqlExtDatabase(
    register_hstore = False,
    **json.load(open('postgres.json'))
)

class Base_Model(pw.Model):
    class Meta:
        database = db

class Tag(Base_Model):
    title = pw.CharField()
    tag_type = pw.CharField(db_column = 'type')

    class Meta:
        db_table = 'tag'

class Social_User(Base_Model):
    handle = pw.CharField(null = True)
    local_id = pw.CharField()
    platform_slug = pw.CharField()

    class Meta:
        db_table = 'social_user'

class User_Tag(Base_Model):
    social_user_id = pw.ForeignKeyField(Social_User)
    tag_id = pw.ForeignKeyField(Tag)

    class Meta:
        primary_key = pw.CompositeKey('social_user_id', 'tag_id')
        db_table = 'user_tag'

棉花糖模式

class Tag_Schema(marsh.Schema):
    id = marsh.fields.Int(dump_only = True)
    title = marsh.fields.Str(required = True)
    tag_type = marsh.fields.Str(required = True, default = 'descriptive')

class Social_User_Schema(marsh.Schema):
    id = marsh.fields.Int(dump_only = True)
    local_id = marsh.fields.Str(required = True)
    handle = marsh.fields.Str()
    platform_slug = marsh.fields.Str(required = True)
    tags = marsh.fields.Nested(Tag_Schema, many = True, dump_only = True)

    def _get_tags(self, user_id):
        query = Tag.select().join(User_Tag).where(User_Tag.social_user_id == user_id)
        tags, errors = tags_schema.dump(query)
        return tags

    @marsh.post_dump(pass_many = True)
    def post_dump(self, data, many):
        if many:
            for datum in data:
                datum['tags'] = self._get_tags(datum['id']) if datum['id'] else []
        else:
            data['tags'] = self._get_tags(data['id'])
        return data

user_schema = Social_User_Schema()
users_schema = Social_User_Schema(many = True)
tags_schema = Tag_Schema(many = True)

以下是一些演示功能的测试:

db.connect()
query = Social_User.get(Social_User.id == 825)
result, errors = user_schema.dump(query)
db.close()
pprint(result)
{'handle': 'test',
 'id': 825,
 'local_id': 'test',
 'platform_slug': 'tw',
 'tags': [{'id': 20, 'tag_type': 'descriptive', 'title': 'this'},
          {'id': 21, 'tag_type': 'descriptive', 'title': 'that'}]}
db.connect()
query = Social_User.select().where(Social_User.platform_slug == 'tw')
result, errors = users_schema.dump(query)
db.close()
pprint(result)
[{'handle': 'test',
  'id': 825,
  'local_id': 'test',
  'platform_slug': 'tw',
  'tags': [{'id': 20, 'tag_type': 'descriptive', 'title': 'this'},
           {'id': 21, 'tag_type': 'descriptive', 'title': 'that'}]},
 {'handle': 'test2',
  'id': 826,
  'local_id': 'test2',
  'platform_slug': 'tw',
  'tags': []}]

看起来这可以使用 Peewee 模型中的 ManyToMany field 并手动设置 through_model 来完成。 ManyToMany 字段允许您向模型添加一个字段,将两个 table 相互关联,通常它会自动创建关系 table (through_model) 本身,但是你可以手动设置它。

我正在使用 3.0 alpha of Peewee,但我相信很多人都在使用当前的 stable 版本,所以我将包括这两个版本。我们将使用 DeferredThroughModel 对象和 ManyToMany 字段,在 Peewee 2.x 中,它们在 3.x 中的 "playhouse" 中,它们是主要 Peewee 的一部分发布。我们还将删除 @post_dump 包装函数:

Peewee 模型

# Peewee 2.x
# from playhouse import fields
# User_Tag_Proxy = fields.DeferredThroughModel()

# Peewee 3.x
User_Tag_Proxy = pw.DeferredThroughModel()

class Tag(Base_Model):
    title = pw.CharField()
    tag_type = pw.CharField(db_column = 'type')

    class Meta:
        db_table = 'tag'

class Social_User(Base_Model):
    handle = pw.CharField(null = True)
    local_id = pw.CharField()
    platform_slug = pw.CharField()
    # Peewee 2.x
    # tags = fields.ManyToManyField(Tag, related_name = 'users', through_model = User_Tag_Proxy)

    # Peewee 3.x
    tags = pw.ManyToManyField(Tag, backref = 'users', through_model = User_Tag_Proxy)

    class Meta:
        db_table = 'social_user'

class User_Tag(Base_Model):
    social_user = pw.ForeignKeyField(Social_User, db_column = 'social_user_id')
    tag = pw.ForeignKeyField(Tag, db_column = 'tag_id')

    class Meta:
        primary_key = pw.CompositeKey('social_user', 'tag')
        db_table = 'user_tag'

User_Tag_Proxy.set_model(User_Tag)

棉花糖模式

class Social_User_Schema(marsh.Schema):
    id = marsh.fields.Int(dump_only = True)
    local_id = marsh.fields.Str(required = True)
    handle = marsh.fields.Str()
    platform_slug = marsh.fields.Str(required = True)
    tags = marsh.fields.Nested(Tag_Schema, many = True, dump_only = True)

user_schema = Social_User_Schema()
users_schema = Social_User_Schema(many = True)

在实践中,它的工作原理与使用 @post_dump 包装函数完全相同。不幸的是,虽然这似乎是解决此问题的 "right" 方法,但实际上速度稍慢。

--更新--

我已经设法用 1/100 的时间完成了同样的事情。这有点麻烦,可以进行一些清理,但它确实有效!我没有对模型进行更改,而是在将数据传递给模式进行序列化之前调整了收集和处理数据的方式。

Peewee 模型

class Tag(Base_Model):
    title = pw.CharField()
    tag_type = pw.CharField(db_column = 'type')

    class Meta:
        db_table = 'tag'

class Social_User(Base_Model):
    handle = pw.CharField(null = True)
    local_id = pw.CharField()
    platform_slug = pw.CharField()

    class Meta:
        db_table = 'social_user'

class User_Tag(Base_Model):
    social_user = pw.ForeignKeyField(Social_User, db_column = 'social_user_id')
    tag = pw.ForeignKeyField(Tag, db_column = 'tag_id')

    class Meta:
        primary_key = pw.CompositeKey('social_user', 'tag')
        db_table = 'user_tag'

棉花糖模式

class Social_User_Schema(marsh.Schema):
    id = marsh.fields.Int(dump_only = True)
    local_id = marsh.fields.Str(required = True)
    handle = marsh.fields.Str()
    platform_slug = marsh.fields.Str(required = True)
    tags = marsh.fields.Nested(Tag_Schema, many = True, dump_only = True)

user_schema = Social_User_Schema()
users_schema = Social_User_Schema(many = True)

查询

对于新查询,我们将加入 (LEFT_OUTER) 三个 tables (Social_User, Tag,以及 User_Tag),其中 Social_User 作为我们的真实来源。我们想确保我们得到每个用户,无论他们是否有标签。这将 return 用户多次,具体取决于他们拥有的标签数量,因此我们需要通过遍历每个标签并使用字典来存储对象来减少这种情况。在每个新的 Social_User 对象中,我们将添加一个 tags 列表,我们将在其中附加 Tag 个对象。

db.connect()
query = (Social_User.select(User_Tag, Social_User, Tag)
    .join(User_Tag, pw.JOIN.LEFT_OUTER)
    .join(Tag, pw.JOIN.LEFT_OUTER)
    .order_by(Social_User.id))

users = {}
last = None
for result in query:
    user_id = result.id
    if (user_id not in users):
        # creates a new Social_User object matching the user data
        users[user_id] = Social_User(**result.__data__)
        users[user_id].tags = []
    try:
        # extracts the associated tag
        users[user_id].tags.append(result.user_tag.tag)
    except AttributeError:
        pass

result, errors = users_schema.dump(users.values())
db.close()
pprint(result)