3D Plot:: 如何设置图例和颜色条的方向和位置?

3D Plot:: How to set legend and colorbar orientation and position?

我想创建一个 3d 图,可视化 hidden_layer_sizesmax_iterScore 之间的相关性。我不得不 google 稍微绕一圈才能得到预期的情节,但现在我遇到了一些关于图例的问题:

  1. 我有两个传说
  2. 第二个图例很小

我的目标是将右边的图例移到底部。但它不起作用,我什至无法删除正确的图例。如果我设置 showlegend=False 只有突出显示的小图例消失,正确的图例仍然存在。

我确定这只是我缺乏剧情经验。如果有任何可能的帮助,我将不胜感激。


MWE

数据

import pandas as pd

df = pd.DataFrame({'hidden_layer_sizes': {0: 25,
  1: 25,  2: 25,  3: 25,  4: 25,  5: 50,  6: 50,  7: 50,  8: 50,  9: 50,  10: 75,
  11: 75,  12: 75,  13: 75,  14: 75,  15: 100,  16: 100,  17: 100,  18: 100,  19: 100,  20: 125,
  21: 125,  22: 125,  23: 125,  24: 125,  25: 150,  26: 150,  27: 150,  28: 150,  29: 150}, 
'max_iter': {0: 100,  1: 200,  2: 300,  3: 400,  4: 500,  5: 100,  6: 200,  7: 300,  8: 400,  9: 500,  
10: 100,  11: 200,  12: 300,  13: 400,  14: 500,  15: 100, 16: 200,  17: 300,  18: 400,  19: 500,  
20: 100,  21: 200,  22: 300,  23: 400,  24: 500,  25: 100,  26: 200,  27: 300,  28: 400,  29: 500}, 
'Score': {0: 0.9270832984321359,  1: 0.9172223807360554,  2: 0.9202868292420568,  3: 0.9187318693456508,
  4: 0.9263589700182026,  5: 0.9325454241272417,  6: 0.9351742112383672,  7: 0.934706441722599,
  8: 0.9350294733755595,  9: 0.9334167352798914,  10: 0.9355533396303661,  11: 0.9327821227628682,
  12: 0.9333376163633981,  13: 0.9322875868305249,  14: 0.9345524934883098,  15: 0.9341786678949748,
  16: 0.9306931295155753,  17: 0.9332227354795629,  18: 0.9312008571438402,  19: 0.9335295484755572,
  20: 0.9333167395841182,  21: 0.9315595511169302,  22: 0.9301811416101524,  23: 0.9314818362895073,
  24: 0.9308551601915486,  25: 0.9296559215457606,  26: 0.9284091216867709,  27: 0.9318823563281231,
  28: 0.9295666150206443,  29: 0.9291284919738931},
 'Time': {0: 119.91294360160828,  1: 256.4710912704468,  2: 266.6792154312134,  3: 326.7445312023163,
  4: 256.8881601810455,  5: 183.77022705078124,  6: 359.7090343952179,  7: 383.6012378692627,
  8: 416.3133870601654,  9: 425.7837643623352,  10: 225.39801173210145,  11: 516.9914848804474,
  12: 562.7134436607361,  13: 585.6752841472626,  14: 560.5802517414093,  15: 267.22873797416685,
  16: 646.1253435134888,  17: 811.1979314804078,  18: 780.6058969974517,  19: 789.9369702339172,
  20: 394.0711458206177,  21: 890.7988158226013,  22: 1065.5482338428496,  23: 996.5119229316712,
  24: 1096.0208141803741,  25: 524.0947244644165,  26: 1182.684538602829,  27: 1348.3343998908997,
  28: 1356.0255290508271,  29: 1053.8607951164245}})

创建情节的代码

import numpy as np
import plotly.graph_objects as go 
from scipy.interpolate import griddata
import plotly.io as pio

xi = np.linspace(min(df["hidden_layer_sizes"]), max(df["hidden_layer_sizes"]), num=100)
yi = np.linspace(min(df["max_iter"]), max(df["max_iter"]), num=100)

x_grid, y_grid = np.meshgrid(xi,yi)
z_grid = griddata((df["hidden_layer_sizes"],df["max_iter"]),df["Score"],(x_grid,y_grid),method="cubic")

fig = go.Figure(go.Surface(x=x_grid, y=y_grid, z=z_grid, showlegend=True))
fig.update_layout(title="Test",
                  width=600, height=600, template="none",
                  legend=dict(orientation="h"))

fig.show()

如果我没记错的话,您可以使用 update_traces 来设置颜色栏属性。类似于:

fig.update_traces(
    colorbar_orientation='h',
    colorbar_y=0
)

您可能需要调整位置本身,以免出现任何重叠。

你在这里说的是两个不同的东西:legendcolorbar,前者是图形布局的属性,后者是图形数据或轨迹的属性.要获得您在这里的目标,只需包含以下内容:

fig.update_layout(legend = dict(orientation="h", x = -0.25, y = -0.10))
fig.update_traces(colorbar = dict(orientation='h', y = -0.25, x = 0.5))

情节 1

也就是说,如果您想完全保留“小”图例。如果没有,只需使用:

fig.update_layout(showlegend = False)

情节 2

完整代码:

import numpy as np
import plotly.graph_objects as go 
from scipy.interpolate import griddata
import plotly.io as pio

import pandas as pd

df = pd.DataFrame({'hidden_layer_sizes': {0: 25,
  1: 25,  2: 25,  3: 25,  4: 25,  5: 50,  6: 50,  7: 50,  8: 50,  9: 50,  10: 75,
  11: 75,  12: 75,  13: 75,  14: 75,  15: 100,  16: 100,  17: 100,  18: 100,  19: 100,  20: 125,
  21: 125,  22: 125,  23: 125,  24: 125,  25: 150,  26: 150,  27: 150,  28: 150,  29: 150}, 
'max_iter': {0: 100,  1: 200,  2: 300,  3: 400,  4: 500,  5: 100,  6: 200,  7: 300,  8: 400,  9: 500,  
10: 100,  11: 200,  12: 300,  13: 400,  14: 500,  15: 100, 16: 200,  17: 300,  18: 400,  19: 500,  
20: 100,  21: 200,  22: 300,  23: 400,  24: 500,  25: 100,  26: 200,  27: 300,  28: 400,  29: 500}, 
'Score': {0: 0.9270832984321359,  1: 0.9172223807360554,  2: 0.9202868292420568,  3: 0.9187318693456508,
  4: 0.9263589700182026,  5: 0.9325454241272417,  6: 0.9351742112383672,  7: 0.934706441722599,
  8: 0.9350294733755595,  9: 0.9334167352798914,  10: 0.9355533396303661,  11: 0.9327821227628682,
  12: 0.9333376163633981,  13: 0.9322875868305249,  14: 0.9345524934883098,  15: 0.9341786678949748,
  16: 0.9306931295155753,  17: 0.9332227354795629,  18: 0.9312008571438402,  19: 0.9335295484755572,
  20: 0.9333167395841182,  21: 0.9315595511169302,  22: 0.9301811416101524,  23: 0.9314818362895073,
  24: 0.9308551601915486,  25: 0.9296559215457606,  26: 0.9284091216867709,  27: 0.9318823563281231,
  28: 0.9295666150206443,  29: 0.9291284919738931},
 'Time': {0: 119.91294360160828,  1: 256.4710912704468,  2: 266.6792154312134,  3: 326.7445312023163,
  4: 256.8881601810455,  5: 183.77022705078124,  6: 359.7090343952179,  7: 383.6012378692627,
  8: 416.3133870601654,  9: 425.7837643623352,  10: 225.39801173210145,  11: 516.9914848804474,
  12: 562.7134436607361,  13: 585.6752841472626,  14: 560.5802517414093,  15: 267.22873797416685,
  16: 646.1253435134888,  17: 811.1979314804078,  18: 780.6058969974517,  19: 789.9369702339172,
  20: 394.0711458206177,  21: 890.7988158226013,  22: 1065.5482338428496,  23: 996.5119229316712,
  24: 1096.0208141803741,  25: 524.0947244644165,  26: 1182.684538602829,  27: 1348.3343998908997,
  28: 1356.0255290508271,  29: 1053.8607951164245}})


xi = np.linspace(min(df["hidden_layer_sizes"]), max(df["hidden_layer_sizes"]), num=100)
yi = np.linspace(min(df["max_iter"]), max(df["max_iter"]), num=100)

x_grid, y_grid = np.meshgrid(xi,yi)
z_grid = griddata((df["hidden_layer_sizes"],df["max_iter"]),df["Score"],(x_grid,y_grid),method="cubic")

fig = go.Figure(go.Surface(x=x_grid, y=y_grid, z=z_grid, showlegend=True))
fig.update_layout(title="Test",
                  width=600, height=600, template="none",
                  # legend=dict(orientation="h")
                 )

fig.update_layout(legend = dict(orientation="h", x = -0.25, y = -0.10))
fig.update_traces(colorbar = dict(orientation='h', y = -0.25, x = 0.5))
fig.update_layout(showlegend = False)

fig.show()