将 Seaborn swarmplot 拆分成多行

Breaking a Seaborn swarmplot into rows

我正在尝试用 seaborn 制作这个 swarmplot。

我的问题是蜂群太宽了。我希望能够把它们拆分成多行,每行最多 3 个点。

这是我的代码:

# Import modules
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
###

# Import and clean dataset
url = "https://raw.githubusercontent.com/amirnakar/scratchboard/master/Goodreads/goodreads_library_export.csv"
Books = pd.read_csv(url)
# Keep only the books that have a read date. Copy the filtered frame so the
# column assignment below does not write into a slice of the original.
Books = Books[Books['Date Read'].notna()].copy()

# Year in which each book was read. (A previous pd.to_datetime call with the
# invalid format '%YYYY%mm%dd' only produced NaT values that were overwritten
# right here, so it has been dropped.)
Books['Year'] = pd.DatetimeIndex(Books['Date Read']).year
###

# Calculate mean rate by year
RateMeans = Books["My Rating"].groupby(Books["Year"]).mean()
Years = list(RateMeans.index.values)
Rates = list(RateMeans)
RateMeans = pd.DataFrame(
    {'Years': Years,
     'Rates': Rates
    })
###

# Colour palette
# Defined BEFORE plotting: sns.set_palette only affects plots drawn afterwards,
# so configuring it after the swarm plot left the custom colours unused.
colorset = ["#FAFF04", "#FFD500", "#9BFF00", "#0099FF", "#000BFF"]
colorset.reverse()
sns.set_palette(sns.color_palette(colorset))

# Plot
fig, ax = plt.subplots(figsize=(20, 10))

## Violin Plot: white outline showing the rating distribution per year
plot = sns.violinplot(
    data=Books,
    x="Year",
    y='My Rating',
    ax=ax,
    color="white",
    inner=None,
    )

## Swarm Plot: one point per book, coloured by its rating
plot = sns.swarmplot(
    data=Books,
    x="Year",
    y='My Rating',
    ax=ax,
    hue="My Rating",
    size=10,
    palette=colorset,  # pass explicitly so the colours do not depend on global state
    )

## Style

### Title (drawn as axes-relative text so title and subtitle can be styled separately)
ax.text(x=0.5, y=1.1, s='Book Ratings: Distribution per Year', fontsize=32, weight='bold', ha='center', va='bottom', transform=ax.transAxes)
ax.text(x=0.5, y=1.05, s='Source: Goodreads.com (amirnakar)', fontsize=24, alpha=0.75, ha='center', va='bottom', transform=ax.transAxes)

### Axis
# The x axis is categorical (positions 0, 1, 2, ...); a lower limit of 4.5
# hides the first five year categories.
ax.set(xlim=(4.5, None), ylim=(0, 6))
ax.set_ylabel('Rating (out of 5 stars)', fontsize=24)
ax.set_xlabel('Year', fontsize=24)
# Only resize the tick labels. Re-labelling the x ticks with ax.get_xticks()
# would replace the year names with their category positions.
ax.tick_params(axis='both', labelsize=20)

### Legend
plot.legend(loc="lower center", ncol=5)

# Save the plot
plt.savefig("Rate-Python.svg", format="svg")

这是输出:

我希望发生的事情:

我希望能够把每行的点数上限设为 3,如果超过就把多出的点换到新的一行。我在这里对两组进行了演示(在 PowerPoint 中手动完成),但我希望整张图都这样处理。

之前:

之后:

这里尝试将这些点稍微向上/向下移动。delta 的值是通过反复试验得到的。

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Import and clean dataset
url = "https://raw.githubusercontent.com/amirnakar/scratchboard/master/Goodreads/goodreads_library_export.csv"
Books = pd.read_csv(url)
# Keep only the books that have a read date; copy so that adding the 'Year'
# column does not write into a slice of the original frame.
Books = Books[Books['Date Read'].notna()].copy()
Books['Year'] = pd.DatetimeIndex(Books['Date Read']).year  # Take only years

# Number of books for every (year, rating) combination
RatePerYear = Books[["My Rating", "Year"]].groupby("Year")["My Rating"].value_counts()

MAX_PER_ROW = 3  # maximum number of points drawn side by side
delta = 0.2  # vertical distance between stacked rows (found by trial)

# Each entry: [year, plotted rating, number of points, original rating].
# Points beyond MAX_PER_ROW are moved to extra rows placed alternately above
# and below the true rating, each row again holding at most MAX_PER_ROW points.
modified_ratings = []
for (year, rating), count in RatePerYear.items():  # .iteritems() was removed in pandas 2.0
    overflow = max(0, count - MAX_PER_ROW)
    higher = (overflow + 1) // 2  # points moved above the centre row
    lower = overflow // 2         # points moved below the centre row
    modified_ratings.append([year, rating, count - higher - lower, rating])
    for direction, extra in ((1, higher), (-1, lower)):
        n_rows = -(-extra // MAX_PER_ROW)  # ceiling division
        for k in range(n_rows):
            n_points = min(MAX_PER_ROW, extra - k * MAX_PER_ROW)
            modified_ratings.append(
                [year, rating + direction * (k + 1) * delta, n_points, rating])

# reshape keeps the array two-dimensional even when there are no ratings at all
modified_ratings = np.array(modified_ratings, dtype=float).reshape(-1, 4)
points_per_row = modified_ratings[:, 2].astype(int)
modified_ratings_df = pd.DataFrame(
    {'Year': np.repeat(modified_ratings[:, 0].astype(int), points_per_row),
     'My Rating': np.repeat(modified_ratings[:, 1], points_per_row),
     # Use the stored original rating for the colour. Rounding the shifted
     # value would pick the neighbouring rating once the shift exceeds 0.5.
     'Rating': np.repeat(modified_ratings[:, 3].astype(int), points_per_row)})

fig, ax = plt.subplots(figsize=(20, 10))
sns.violinplot(data=Books, x="Year", y='My Rating', ax=ax, color="white", inner=None)
# list.reverse() reverses in place and returns None, so build the reversed
# list with a slice to actually hand a palette to seaborn.
palette = ["#FAFF04", "#FFD500", "#9BFF00", "#0099FF", "#000BFF"][::-1]
sns.swarmplot(data=modified_ratings_df, x="Year", y='My Rating', ax=ax, hue="Rating", size=10, palette=palette)

# The x axis is categorical; a lower limit of 4.5 hides the first five years.
ax.set(xlim=(4.5, None), ylim=(0, 6))
ax.legend(loc="lower center", ncol=5)
plt.tight_layout()
plt.show()