Python

【plotlyチュートリアル】<>ペアプロット図(散布図行列)作成

本記事ではPythonのグラフライブラリとして使えるplotlyによる「ペアプロット図(散布図行列)」の作成について紹介します。

2021年現在も流行している機械学習コンテストなどで良く用いられるのがペアプロットです。

特徴として変数間の相関関係を一望できる点が挙げられます。

RやR^2といった具体的な統計指標をイメージしながら、全ての変数同士の組み合わせと比較できるので、どの変数に目をつけるべきか知れるツールとして使えます。

今回もplotlyチュートリアルのペアプロットから例題を抜粋しながら解説していきます。

①基本的なペアプロット図

②z方向の変数も色で表すペアプロット図

③対角部分を削除したペアプロット図

go.Splomを使用したペアプロット図

⑤おまけ:レイアウトの変更色々

今回紹介する例題ではグラフの元データを容易する必要がなく、下記プログラムを完コピすればグラフを作成することができます。

まずは実際に手を動かしてコードを書いて見てください。

AnacondaのインストールとJupyter notebookの準備

Anacondaのインストールがまだの方はこちらの記事を参考に準備してください。

Jupyter notebookの起動と使い方まで紹介しています。

【2020年版、Python3をはじめよう!(Mac&Win対応)】AnacondaとJupyter notebookの始め方 はじめに Jupyter notebookとは Jupyter notebook(ジュピター・ノートブック)はブラウザ上で...

plotlyとplotly_expressのインストール

導入から基本的な使い方まで紹介しています。

Pythonをグラフ作成で学ぶ【plotly_expressで1行プログラミング】 【本記事の目標】 たった1行で動くグラフを作成し、体感する! これからPythonを学びたい人向けに朗報です。 今から紹介...

それではペアプロットについて解説していきます。

①基本的なペアプロット図

plotly_express、plotly_graph_objectsを使用するので、まずはインポートします。

import plotly.express as px
import plotly.graph_objects as go

次にデータを用意します。

今回はplotly_expressで用意されているpx.data.iris()を使用します。

px.data.iris()にはアヤメ(iris)のデータが格納されています。

がくの長さ(sepal_length)、 がくの幅(sepal_width)、 花弁の長さ(petal_length)、 花弁の幅(petal_width)、 アヤメの品種(species)、 品種No.(species_id)の情報が入っています。

df = px.data.iris()

中身をみてみましょう。

px.scatter_matrixでペアプロット図が書けます。

fig=px.scatter_matrix(df)のたった1行でOKです。

figはただの変数なので何でも大丈夫です。figure=図を意味しているので良く使われる表現です。

fig.show()で図が表示されます。

fig = px.scatter_matrix(df)
fig.show()

②z方向の変数も色で表すペアプロット図

①と違う点はdimensions=[]で変数を指定しているところ、さらにcolor=””でz方向の変数を指定して色分けすることができます。

fig = px.scatter_matrix(df,
    dimensions=["sepal_width", "sepal_length", "petal_width", "petal_length"],
    color="species")
fig.show()

③レイアウトを工夫したペアプロット図

labels={}を使用すると、内包表記でreplaceを連続的に行うことで対角の同じ変数同士の部分を空白にできます。

fig = px.scatter_matrix(df,
    dimensions=["sepal_width", "sepal_length", "petal_width", "petal_length"],
    color="species", symbol="species",
    title="Scatter matrix of iris data set",
    labels={col:col.replace('_', ' ') for col in df.columns}) # remove underscore
fig.update_traces(diagonal_visible=False)
fig.show()

go.Splomを使用したペアプロット図

ここからはplotly_graph_objectsを使った例を紹介します。

その場合、go.Splomを使って作成します。

いくつか例を連続して表示していきます。

レイアウトについては私の別の記事を参考にしてください。

https://cafe-mickey.com/category/python/plotly/

df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/iris-data.csv')

# The Iris dataset contains four data variables, sepal length, sepal width, petal length,
# petal width, for 150 iris flowers. The flowers are labeled as `Iris-setosa`,
# `Iris-versicolor`, `Iris-virginica`.

# Define indices corresponding to flower categories, using pandas label encoding
index_vals = df['class'].astype('category').cat.codes

fig = go.Figure(data=go.Splom(
                dimensions=[dict(label='sepal length',
                                 values=df['sepal length']),
                            dict(label='sepal width',
                                 values=df['sepal width']),
                            dict(label='petal length',
                                 values=df['petal length']),
                            dict(label='petal width',
                                 values=df['petal width'])],
                text=df['class'],
                marker=dict(color=index_vals,
                            showscale=False, # colors encode categorical variables
                            line_color='white', line_width=0.5)
                ))


fig.update_layout(
    title='Iris Data set',
    dragmode='select',
    width=600,
    height=600,
    hovermode='closest',
)

fig.show()
df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/iris-data.csv')
index_vals = df['class'].astype('category').cat.codes

fig = go.Figure(data=go.Splom(
                dimensions=[dict(label='sepal length',
                                 values=df['sepal length']),
                            dict(label='sepal width',
                                 values=df['sepal width']),
                            dict(label='petal length',
                                 values=df['petal length']),
                            dict(label='petal width',
                                 values=df['petal width'])],
                diagonal_visible=False, # remove plots on diagonal
                text=df['class'],
                marker=dict(color=index_vals,
                            showscale=False, # colors encode categorical variables
                            line_color='white', line_width=0.5)
                ))


fig.update_layout(
    title='Iris Data set',
    width=600,
    height=600,
)

fig.show()
df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/iris-data.csv')
index_vals = df['class'].astype('category').cat.codes

fig = go.Figure(data=go.Splom(
                dimensions=[dict(label='sepal length',
                                 values=df['sepal length']),
                            dict(label='sepal width',
                                 values=df['sepal width']),
                            dict(label='petal length',
                                 values=df['petal length']),
                            dict(label='petal width',
                                 values=df['petal width'])],
                showupperhalf=False, # remove plots on diagonal
                text=df['class'],
                marker=dict(color=index_vals,
                            showscale=False, # colors encode categorical variables
                            line_color='white', line_width=0.5)
                ))


fig.update_layout(
    title='Iris Data set',
    width=600,
    height=600,
)

fig.show()
df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/iris-data.csv')
index_vals = df['class'].astype('category').cat.codes

fig = go.Figure(data=go.Splom(
                dimensions=[dict(label='sepal length',
                                 values=df['sepal length']),
                            dict(label='sepal width',
                                 values=df['sepal width'],
                                 visible=False),
                            dict(label='petal length',
                                 values=df['petal length']),
                            dict(label='petal width',
                                 values=df['petal width'])],
                text=df['class'],
                marker=dict(color=index_vals,
                            showscale=False, # colors encode categorical variables
                            line_color='white', line_width=0.5)
                ))


fig.update_layout(
    title='Iris Data set',
    width=600,
    height=600,
)

fig.show()

⑤おまけ:レイアウトの変更色々

dfd = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/diabetes.csv')
textd = ['non-diabetic' if cl==0 else 'diabetic' for cl in dfd['Outcome']]

fig = go.Figure(data=go.Splom(
                  dimensions=[dict(label='Pregnancies', values=dfd['Pregnancies']),
                              dict(label='Glucose', values=dfd['Glucose']),
                              dict(label='BloodPressure', values=dfd['BloodPressure']),
                              dict(label='SkinThickness', values=dfd['SkinThickness']),
                              dict(label='Insulin', values=dfd['Insulin']),
                              dict(label='BMI', values=dfd['BMI']),
                              dict(label='DiabPedigreeFun', values=dfd['DiabetesPedigreeFunction']),
                              dict(label='Age', values=dfd['Age'])],
                  marker=dict(color=dfd['Outcome'],
                              size=5,
                              colorscale='Bluered',
                              line=dict(width=0.5,
                                        color='rgb(230,230,230)')),
                  text=textd,
                  diagonal=dict(visible=False)))

title = "Scatterplot Matrix (SPLOM) for Diabetes Dataset<br>Data source:"+\
        " <a href='https://www.kaggle.com/uciml/pima-indians-diabetes-database/data'>[1]</a>"
fig.update_layout(title=title,
                  dragmode='select',
                  width=1000,
                  height=1000,
                  hovermode='closest')

fig.show()

以上で解説は終わりです。

Plotlyに関する書籍紹介

↓Plotlyについて学べる数少ない参考書です。

Mickey@コーヒー好きエンジニア

【製造業×プログラミング×AI】Python/VBAを活用した業務改善、Streamlit/Plotlyを活用したWebアプリ開発について初心者向けに発信中|趣味は自家焙煎コーヒー作り|noteでは焙煎理論を発信|ココナラではプログラミングに関する相談,就職/転職やコーヒーに関する相談などのサービスをやっています