是否可以用 plotly 创建 mediapipe 人体姿态关键点(landmarks)及其连接的动画 3D 散点图?

Is it possible to create a plotly animated 3D scatter plot of mediapipe's body pose landmarks and connections?

我目前正在做一个项目,其中用到了 mediapipe 的人体姿态估计库。我知道 plotly 可以创建动画 3D 散点图,因此想知道:能否用 plotly 为 mediapipe 提供的人体姿态关键点及其连接创建动画图?

效果应与 mediapipe 的内置函数相同:

mp_draw.plot_landmarks()

但具有 plotly 的交互性和动画时间轴。

  • 查看 mp_draw.plot_landmarks() 的源码可知,它使用的是 matplotlib,因此可以照此编写一个 plotly 版本
  • 我已经实现了这一点:下面这个版本返回一个 plotly 图(figure)。我确信代码还可以进一步整理
  • 可以为多个结果分别构建帧(frames),从而生成动画

设置

import cv2
import math
import numpy as np
import plotly.express as px
from pathlib import Path
import mediapipe as mp

# Map the input image path to a placeholder payload; only the keys are
# used below.
uploaded = {
    Path.home() / "Downloads" / "thao-le-hoang-v4zceVZ5HK8-unsplash.jpg": "",
}
# Convert the Path keys to strings before handing them to cv2.imread.
uploaded = {str(path): payload for path, payload in uploaded.items()}

# Read images with OpenCV, keyed by their path string.
images = {name: cv2.imread(name) for name in uploaded}

# Short aliases for the MediaPipe solution modules.
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles

# Run MediaPipe Pose on every image. `results` is overwritten on each
# iteration, so after the loop it holds the output for the last image.
with mp_pose.Pose(
    static_image_mode=True,
    model_complexity=2,
    min_detection_confidence=0.5,
) as pose:
    for name, image in images.items():
        # The image is converted from BGR to RGB before being processed.
        results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

# Plot the 3D pose world landmarks with MediaPipe's own plot_landmarks helper.
mp_drawing.plot_landmarks(results.pose_world_landmarks, mp_pose.POSE_CONNECTIONS)

plotly 版本的 plot_landmarks()

import pandas as pd
import plotly.graph_objects as go

# A landmark whose `presence` field is set and below this value is not plotted.
_PRESENCE_THRESHOLD = 0.5
# A landmark whose `visibility` field is set and below this value is not plotted.
_VISIBILITY_THRESHOLD = 0.5


def plot_landmarks(
    landmark_list,
    connections=None,
):
    """Build a plotly 3D figure of pose landmarks and their connections.

    Landmarks whose ``visibility`` or ``presence`` field is set and falls
    below ``_VISIBILITY_THRESHOLD`` / ``_PRESENCE_THRESHOLD`` are skipped.
    Every remaining landmark is placed at ``(-z, x, -y)``. Landmark names
    shown on hover are looked up through the module-level
    ``mp_pose.PoseLandmark`` enum.

    Args:
        landmark_list: Object exposing a ``landmark`` sequence whose items
            provide ``x``, ``y``, ``z`` and ``HasField`` (for example
            ``results.pose_world_landmarks``).
        connections: Optional collection of ``(start_idx, end_idx)`` landmark
            index pairs. A pair is drawn only when both of its landmarks were
            plotted. When omitted or empty, only the markers are drawn.

    Returns:
        The plotly figure, or ``None`` when ``landmark_list`` is falsy.

    Raises:
        ValueError: If a connection refers to an index outside
            ``landmark_list.landmark``.
    """
    if not landmark_list:
        return None

    # Index -> plotted coordinates for every landmark that passes the
    # visibility/presence thresholds.
    plotted_landmarks = {}
    for idx, landmark in enumerate(landmark_list.landmark):
        if (
            landmark.HasField("visibility")
            and landmark.visibility < _VISIBILITY_THRESHOLD
        ) or (
            landmark.HasField("presence") and landmark.presence < _PRESENCE_THRESHOLD
        ):
            continue
        plotted_landmarks[idx] = (-landmark.z, landmark.x, -landmark.y)

    # Initialised unconditionally so that calling without connections no
    # longer hits an undefined name further down.
    line_coords = {"xs": [], "ys": [], "zs": []}
    if connections:
        num_landmarks = len(landmark_list.landmark)
        for connection in connections:
            start_idx = connection[0]
            end_idx = connection[1]
            if not (0 <= start_idx < num_landmarks and 0 <= end_idx < num_landmarks):
                raise ValueError(
                    f"Landmark index is out of range. Invalid connection "
                    f"from landmark #{start_idx} to landmark #{end_idx}."
                )
            # Draw the connection only if both end points are plotted.
            if start_idx in plotted_landmarks and end_idx in plotted_landmarks:
                start = plotted_landmarks[start_idx]
                end = plotted_landmarks[end_idx]
                # The trailing None breaks the line, so all segments can
                # share one trace without being joined to each other.
                for axis, key in enumerate(("xs", "ys", "zs")):
                    line_coords[key].extend((start[axis], end[axis], None))

    # One row per plotted landmark. Explicit column names keep the frame
    # well-formed even when no landmark passed the thresholds.
    df = pd.DataFrame.from_dict(
        plotted_landmarks, orient="index", columns=["z", "x", "y"]
    )
    df["lm"] = df.index.map(lambda s: mp_pose.PoseLandmark(s).name).values

    fig = (
        px.scatter_3d(df, x="z", y="x", z="y", hover_name="lm")
        .update_traces(marker={"color": "red"})
        .update_layout(
            margin={"l": 0, "r": 0, "t": 0, "b": 0},
            scene={"camera": {"eye": {"x": 2.1, "y": 0, "z": 0}}},
        )
    )
    if connections:
        fig.add_traces(
            [
                go.Scatter3d(
                    x=line_coords["xs"],
                    y=line_coords["ys"],
                    z=line_coords["zs"],
                    mode="lines",
                    line={"color": "black", "width": 5},
                    name="connections",
                )
            ]
        )

    return fig

试一试

# Build the plotly figure for the detected pose world landmarks.
plot_landmarks(results.pose_world_landmarks, mp_pose.POSE_CONNECTIONS)