关联图-研究多变量间的关系

  |  

摘要: 研究多变量之间关系的可视化图表

【对数据分析、人工智能、金融科技、风控服务感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:潮汐朝夕
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


本文我们来看一下研究多变量之间关系时,有哪些可视化图表可以选择。

散点图 plt.scatter()

数据 df 如下:

数据1

绘制代码如下,其中 plt.gcf()plt.gca() 分别获得当前的图表和子图,进而设置 x, y 轴的显示范围和标签。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Prepare Data 
# Create as many colors as there are unique df['category']
categories = np.unique(df['category'])
colors = [plt.cm.tab10(i / float(len(categories) - 1)) for i in range(len(categories))]

# Draw Plot for Each Category
plt.figure(figsize=(16, 10), dpi= 80, facecolor='w', edgecolor='k')

for i, category in enumerate(categories):
plt.scatter('area', 'poptotal',
data=df.loc[df.category==category, :],
s=20, c=colors[i], label=str(category))

# Decorations
plt.gca().set(xlim=(0.0, 0.1), ylim=(0, 90000),
xlabel='Area', ylabel='Population')

plt.xticks(fontsize=12); plt.yticks(fontsize=12)
plt.title("Scatterplot of df Area vs Population", fontsize=22)
plt.legend(fontsize=12)
plt.show()

带边界的气泡图 ax.add_patch(poly)

数据 df 还是前面的【数据1】,绘制代码还是以 plt.scatter() 为主体,画边界的函数为 encircle

  • np.r_ 是按列连接两个矩阵,就是把两矩阵上下相加,要求列数相等,类似于 pd.concat()
  • np.c_ 是按行连接两个矩阵,就是把两矩阵左右相加,要求行数相等,类似于 pd.merge()
  • ConvexHull 是给定二维平面点集求凸包。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from matplotlib import patches
from scipy.spatial import ConvexHull
import warnings
warnings.simplefilter('ignore')
sns.set_style("white")

# As many colors as there are unique df['category']
categories = np.unique(df['category'])
colors = [plt.cm.tab10(i / float(len(categories) - 1)) for i in range(len(categories))]

# Draw Scatterplot with unique color for each category
fig = plt.figure(figsize=(16, 10), dpi=80, facecolor='w', edgecolor='k')

for i, category in enumerate(categories):
plt.scatter('area', 'poptotal', data=df.loc[df.category==category, :]
,s='dot_size', c=colors[i], label=str(category)
,edgecolors='black', linewidths=.5)

# Encircling
def encircle(x, y, ax=None, **kw):
if not ax:
ax = plt.gca()
p = np.c_[x,y]
hull = ConvexHull(p)
poly = plt.Polygon(p[hull.vertices,:], **kw)
ax.add_patch(poly)

# Select data to be encircled
df_encircle_data = df.loc[df.state=='IN', :]

# Draw polygon surrounding vertices
encircle(df_encircle_data.area, df_encircle_data.poptotal, ec="k", fc="gold", alpha=0.1)
encircle(df_encircle_data.area, df_encircle_data.poptotal, ec="firebrick", fc="none", linewidth=1.5)

# Step 4: Decorations
plt.gca().set(xlim=(0.0, 0.1), ylim=(0, 90000),
xlabel='Area', ylabel='Population')

plt.xticks(fontsize=12); plt.yticks(fontsize=12)
plt.title("Bubble Plot with Encircling", fontsize=22)
plt.legend(fontsize=12)
plt.show()

带线性回归最佳拟合的散点图 sns.lmplot()

数据 df 如下:

数据2

下面的例子展示了数据中各组之间最佳拟合线的差异。要禁用分组并仅为整个数据集绘制一条最佳拟合线,可以在 sns.lmplot() 调用中删除 hue ='cyl' 参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
# Import Data
df_select = df.loc[df.cyl.isin([4,8]), :]

# Plot
sns.set_style("white")
gridobj = sns.lmplot(x="displ", y="hwy", hue="cyl", data=df_select,
height=7, aspect=1.6, robust=True, palette='tab10',
scatter_kws=dict(s=60, linewidths=.7, edgecolors='black'))

# Decorations
gridobj.set(xlim=(0.5, 7.5), ylim=(0, 50))
plt.title("Scatterplot with line of best fit grouped by number of cylinders", fontsize=20)
plt.show()

抖动图 sns.stripplot()

有时多个数据点有相同的 X 和 Y 值,多个点绘制会重叠并隐藏。为避免这种情况,请将数据点稍微抖动,以便您可以直观地看到它们。

数据 df 用前面的【数据2】。

1
2
3
4
5
6
7
# Draw Stripplot
fig, ax = plt.subplots(figsize=(16,10), dpi=80)
sns.stripplot(x=df.cty, y=df.hwy, jitter=0.25, size=8, ax=ax, linewidth=.5, palette='Set1')

# Decorations
plt.title('Use jittered plots to avoid overlapping of points', fontsize=22)
plt.show()

计数图 plt.scatter(s=df.counts)

避免点重叠问题的另一个选择是增加点的大小,这取决于该点中有多少点。因此,点的大小越大,其周围的点的集中度越高。

数据 df 用前面的【数据2】。

reset_index 重置 DataFrame 的索引,并使用默认值。如果 DataFrame 具有 MultiIndex,则此方法可以删除一个或多个级别。

绘制代码还是以 plt.scatter() 为主体,计数与散点大小通过参数 s 这个参数设置:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
df_counts = df.groupby(['hwy', 'cty']).size().reset_index(name='counts')

ctys = np.unique(df['cty'])
colors = [plt.cm.tab10(i / float(len(ctys) - 1)) for i in range(len(ctys))]

# Draw Plot for Each Category
plt.figure(figsize=(16, 10), dpi= 80, facecolor='w', edgecolor='k')

for i, cty in enumerate(ctys):
plt.scatter('cty', 'hwy',
data=df_counts.loc[df.cty==cty, :],
s=df_counts.loc[df.cty==cty, :].counts * 5,
c=colors[i], label=str(cty))

plt.gca().set(xlabel='cty', ylabel='hwy')

# Decorations
plt.title('Counts Plot - Size of circle is bigger as more points overlap', fontsize=22)
plt.show()

边缘直方图 plt.scatter + plt.hist

边缘直方图具有沿 X 和 Y 轴变量的直方图。这用于可视化 X 和 Y 之间的关系以及单独的 X 和 Y 的单变量分布。

数据 df 用前面的【数据2】。

使用了 3 个 Axes,绘图的方法为 sns.boxplot()plt.hist()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Create Fig and gridspec
fig = plt.figure(figsize=(16, 10), dpi= 80)
grid = plt.GridSpec(4, 4, hspace=0.5, wspace=0.2)

# Define the axes
ax_main = fig.add_subplot(grid[:-1, :-1])
ax_right = fig.add_subplot(grid[:-1, -1], xticklabels=[], yticklabels=[])
ax_bottom = fig.add_subplot(grid[-1, 0:-1], xticklabels=[], yticklabels=[])

# Scatterplot on main ax
ax_main.scatter('displ', 'hwy', s=df.cty*4,
c=df.manufacturer.astype('category').cat.codes,
alpha=.9, data=df, cmap="tab10", edgecolors='gray', linewidths=.5)

# histogram on the right
ax_bottom.hist(df.displ, 40, histtype='stepfilled', orientation='vertical', color='deeppink')
ax_bottom.invert_yaxis()

# histogram in the bottom
ax_right.hist(df.hwy, 40, histtype='stepfilled', orientation='horizontal', color='deeppink')

# Decorations
ax_main.set(title='Scatterplot with Histograms \n displ vs hwy', xlabel='displ', ylabel='hwy')
ax_main.title.set_fontsize(20)
for item in ([ax_main.xaxis.label, ax_main.yaxis.label] + ax_main.get_xticklabels() + ax_main.get_yticklabels()):
item.set_fontsize(14)

xlabels = ax_main.get_xticks().tolist()
ax_main.set_xticklabels(xlabels)
plt.show()

边缘箱形图 plt.scatter + sns.boxplot

边缘箱图与边缘直方图相似。箱线图有助于精确定位 X 和 Y 的中位数、第 25 和第 75 百分位数。

使用了 3 个 Axes,绘图的方法为 sns.boxplot()plt.scatter()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Create Fig and gridspec
fig = plt.figure(figsize=(16, 10), dpi= 80)
grid = plt.GridSpec(4, 4, hspace=0.5, wspace=0.2)

# Define the axes
ax_main = fig.add_subplot(grid[:-1, :-1])
ax_right = fig.add_subplot(grid[:-1, -1], xticklabels=[], yticklabels=[])
ax_bottom = fig.add_subplot(grid[-1, 0:-1], xticklabels=[], yticklabels=[])

# Scatterplot on main ax
ax_main.scatter('displ', 'hwy', s=df.cty*5, c=df.manufacturer.astype('category').cat.codes, alpha=.9, data=df, cmap="Set1", edgecolors='black', linewidths=.5)

# Add a graph in each part
sns.boxplot(df.hwy, ax=ax_right, orient="v")
sns.boxplot(df.displ, ax=ax_bottom, orient="h")

# Decorations ------------------
# Remove x axis name for the boxplot
ax_bottom.set(xlabel='')
ax_right.set(ylabel='')

# Main Title, Xlabel and YLabel
ax_main.set(title='Scatterplot with Histograms \n displ vs hwy', xlabel='displ', ylabel='hwy')

# Set font size of different components
ax_main.title.set_fontsize(20)
for item in ([ax_main.xaxis.label, ax_main.yaxis.label] + ax_main.get_xticklabels() + ax_main.get_yticklabels()):
item.set_fontsize(14)

plt.show()

相关图 sns.heatmap

相关图用于直观地查看给定数据框中所有可能的数值变量对之间的相关性。

数据 df 如下:

数据3

1
2
3
4
5
6
7
8
9
10
# Plot
plt.figure(figsize=(12,10), dpi= 80)
sns.heatmap(df.corr(), xticklabels=df.corr().columns, yticklabels=df.corr().columns,
cmap='RdYlGn', center=0, annot=True)

# Decorations
plt.title('Correlogram of mtcars', fontsize=22)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()

矩阵图 sns.pairplot

矩阵图是探索性分析中常用的,用于理解所有可能的数值变量对之间的关系。是双变量分析的必备工具。

数据 df 如下:

1
2
3
4
# Plot
plt.figure(figsize=(10,8), dpi= 80)
sns.pairplot(df, kind="scatter", hue="Species", plot_kws=dict(s=80, edgecolor="white", linewidth=2.5))
plt.show()

1
2
3
4
# Plot
plt.figure(figsize=(10,8), dpi= 80)
sns.pairplot(df, kind="reg", hue="Species")
plt.show()


Share