本記事ではPythonのグラフライブラリとして使えるplotlyによる「ペアプロット図(散布図行列)」の作成について紹介します。
2021年現在も流行している機械学習コンテストなどで良く用いられるのがペアプロットです。
特徴として変数間の相関関係を一望できる点が挙げられます。
RやR^2といった具体的な統計指標をイメージしながら、全ての変数同士の組み合わせと比較できるので、どの変数に目をつけるべきか知れるツールとして使えます。
今回もplotlyチュートリアルのペアプロットから例題を抜粋しながら解説していきます。
①基本的なペアプロット図
②z方向の変数も色で表すペアプロット図
③対角部分を削除したペアプロット図
④go.Splomを使用したペアプロット図
⑤おまけ:レイアウトの変更色々
今回紹介する例題ではグラフの元データを容易する必要がなく、下記プログラムを完コピすればグラフを作成することができます。
まずは実際に手を動かしてコードを書いて見てください。
Anacondaのインストールがまだの方はこちらの記事を参考に準備してください。
Jupyter notebookの起動と使い方まで紹介しています。
導入から基本的な使い方まで紹介しています。
それではペアプロットについて解説していきます。
plotly_express、plotly_graph_objectsを使用するので、まずはインポートします。
import plotly.express as px
import plotly.graph_objects as go
次にデータを用意します。
今回はplotly_expressで用意されているpx.data.iris()を使用します。
px.data.iris()にはアヤメ(iris)のデータが格納されています。
がくの長さ(sepal_length)、 がくの幅(sepal_width)、 花弁の長さ(petal_length)、 花弁の幅(petal_width)、 アヤメの品種(species)、 品種No.(species_id)の情報が入っています。
df = px.data.iris()
中身をみてみましょう。
px.scatter_matrixでペアプロット図が書けます。
fig=px.scatter_matrix(df)のたった1行でOKです。
figはただの変数なので何でも大丈夫です。figure=図を意味しているので良く使われる表現です。
fig.show()で図が表示されます。
fig = px.scatter_matrix(df)
fig.show()
①と違う点はdimensions=[]で変数を指定しているところ、さらにcolor=””でz方向の変数を指定して色分けすることができます。
fig = px.scatter_matrix(df,
dimensions=["sepal_width", "sepal_length", "petal_width", "petal_length"],
color="species")
fig.show()
labels={}を使用すると、内包表記でreplaceを連続的に行うことで対角の同じ変数同士の部分を空白にできます。
fig = px.scatter_matrix(df,
dimensions=["sepal_width", "sepal_length", "petal_width", "petal_length"],
color="species", symbol="species",
title="Scatter matrix of iris data set",
labels={col:col.replace('_', ' ') for col in df.columns}) # remove underscore
fig.update_traces(diagonal_visible=False)
fig.show()
ここからはplotly_graph_objectsを使った例を紹介します。
その場合、go.Splomを使って作成します。
いくつか例を連続して表示していきます。
レイアウトについては私の別の記事を参考にしてください。
https://cafe-mickey.com/category/python/plotly/
df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/iris-data.csv')
# The Iris dataset contains four data variables, sepal length, sepal width, petal length,
# petal width, for 150 iris flowers. The flowers are labeled as `Iris-setosa`,
# `Iris-versicolor`, `Iris-virginica`.
# Define indices corresponding to flower categories, using pandas label encoding
index_vals = df['class'].astype('category').cat.codes
fig = go.Figure(data=go.Splom(
dimensions=[dict(label='sepal length',
values=df['sepal length']),
dict(label='sepal width',
values=df['sepal width']),
dict(label='petal length',
values=df['petal length']),
dict(label='petal width',
values=df['petal width'])],
text=df['class'],
marker=dict(color=index_vals,
showscale=False, # colors encode categorical variables
line_color='white', line_width=0.5)
))
fig.update_layout(
title='Iris Data set',
dragmode='select',
width=600,
height=600,
hovermode='closest',
)
fig.show()
df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/iris-data.csv')
index_vals = df['class'].astype('category').cat.codes
fig = go.Figure(data=go.Splom(
dimensions=[dict(label='sepal length',
values=df['sepal length']),
dict(label='sepal width',
values=df['sepal width']),
dict(label='petal length',
values=df['petal length']),
dict(label='petal width',
values=df['petal width'])],
diagonal_visible=False, # remove plots on diagonal
text=df['class'],
marker=dict(color=index_vals,
showscale=False, # colors encode categorical variables
line_color='white', line_width=0.5)
))
fig.update_layout(
title='Iris Data set',
width=600,
height=600,
)
fig.show()
df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/iris-data.csv')
index_vals = df['class'].astype('category').cat.codes
fig = go.Figure(data=go.Splom(
dimensions=[dict(label='sepal length',
values=df['sepal length']),
dict(label='sepal width',
values=df['sepal width']),
dict(label='petal length',
values=df['petal length']),
dict(label='petal width',
values=df['petal width'])],
showupperhalf=False, # remove plots on diagonal
text=df['class'],
marker=dict(color=index_vals,
showscale=False, # colors encode categorical variables
line_color='white', line_width=0.5)
))
fig.update_layout(
title='Iris Data set',
width=600,
height=600,
)
fig.show()
df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/iris-data.csv')
index_vals = df['class'].astype('category').cat.codes
fig = go.Figure(data=go.Splom(
dimensions=[dict(label='sepal length',
values=df['sepal length']),
dict(label='sepal width',
values=df['sepal width'],
visible=False),
dict(label='petal length',
values=df['petal length']),
dict(label='petal width',
values=df['petal width'])],
text=df['class'],
marker=dict(color=index_vals,
showscale=False, # colors encode categorical variables
line_color='white', line_width=0.5)
))
fig.update_layout(
title='Iris Data set',
width=600,
height=600,
)
fig.show()
dfd = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/diabetes.csv')
textd = ['non-diabetic' if cl==0 else 'diabetic' for cl in dfd['Outcome']]
fig = go.Figure(data=go.Splom(
dimensions=[dict(label='Pregnancies', values=dfd['Pregnancies']),
dict(label='Glucose', values=dfd['Glucose']),
dict(label='BloodPressure', values=dfd['BloodPressure']),
dict(label='SkinThickness', values=dfd['SkinThickness']),
dict(label='Insulin', values=dfd['Insulin']),
dict(label='BMI', values=dfd['BMI']),
dict(label='DiabPedigreeFun', values=dfd['DiabetesPedigreeFunction']),
dict(label='Age', values=dfd['Age'])],
marker=dict(color=dfd['Outcome'],
size=5,
colorscale='Bluered',
line=dict(width=0.5,
color='rgb(230,230,230)')),
text=textd,
diagonal=dict(visible=False)))
title = "Scatterplot Matrix (SPLOM) for Diabetes Dataset<br>Data source:"+\
" <a href='https://www.kaggle.com/uciml/pima-indians-diabetes-database/data'>[1]</a>"
fig.update_layout(title=title,
dragmode='select',
width=1000,
height=1000,
hovermode='closest')
fig.show()
以上で解説は終わりです。
↓Plotlyについて学べる数少ない参考書です。