PyInv

プログラミングのメモ、海外投資のメモ

Decision Treeのvisualisation

Decision Treeを可視化

Decision Treeは図にして俯瞰すると大変理解しやすくなる。ありふれた(諸先輩方には既知)コードではあるが自分用のメモとして残しておく。


準備

Decision Treeの作り方をおさらい。sklearnのcancer dataを使う

  • Impurity: Gini
  • Max Depth 5
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

 #Extract Data
cn = load_breast_cancer()
X_cn_df = pd.DataFrame(data=cn.data, columns=cn.feature_names)
y_cn = pd.Series(data=cn.target)

#model fit
X_train, X_test, y_train, y_test = train_test_split(X_cn_df, y_cn, random_state=0)
dt=DecisionTreeClassifier(criterion='gini', max_depth=5, random_state=0)
dt.fit(X_train,y_train)

#score
dt.score(X_train,y_train)
print('Accuracy(Train):{:.3f}'.format(dt.score(X_train,y_train)))
print('Accuracy(Test):{:.3f}'.format(dt.score(X_test,y_test)))

Visualisation Code

from sklearn import tree
import pydotplus
from sklearn.externals.six import StringIO
from IPython.display import Image

data = StringIO()
tree.export_graphviz(dt, out_file=data)
visual = pydotplus.graph_from_dot_data(data.getvalue())
Image(visual.create_png())