Decision Treeのvisualisation
Decision Treeを可視化
Decision Treeは図にして俯瞰すると大変理解しやすくなる。ありふれた(諸先輩方には既知)コードではあるが自分用のメモとして残しておく。
準備
Decision Treeの作り方をおさらい。scikit-learnの乳がんデータセット(load_breast_cancer)を使う。
- Impurity: Gini
- Max Depth: 5
# Fit a Gini-impurity decision tree (max depth 5) on scikit-learn's breast
# cancer dataset and report train/test accuracy.
# pandas was used below but never imported, so the cell failed on its own.
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Extract data: features as a DataFrame (keeps the column names), labels as a Series.
cn = load_breast_cancer()
X_cn_df = pd.DataFrame(data=cn.data, columns=cn.feature_names)
y_cn = pd.Series(data=cn.target)

# Model fit. random_state=0 makes both the split and the tree reproducible.
X_train, X_test, y_train, y_test = train_test_split(X_cn_df, y_cn, random_state=0)
dt = DecisionTreeClassifier(criterion='gini', max_depth=5, random_state=0)
dt.fit(X_train, y_train)

# Score: mean accuracy on each split. The former bare `dt.score(...)` line was
# dropped; mid-cell expressions are evaluated and discarded, never displayed.
print('Accuracy(Train):{:.3f}'.format(dt.score(X_train, y_train)))
print('Accuracy(Test):{:.3f}'.format(dt.score(X_test, y_test)))
Visualisation Code
# Render the fitted decision tree `dt` (from the previous cell) as an inline PNG.
# `sklearn.externals.six` was removed in scikit-learn 0.23, so the old
# `from sklearn.externals.six import StringIO` import fails on current versions.
# export_graphviz(out_file=None) returns the DOT source as a string, so no
# in-memory buffer is needed at all.
from IPython.display import Image
import pydotplus
from sklearn import tree

dot_data = tree.export_graphviz(dt, out_file=None)
visual = pydotplus.graph_from_dot_data(dot_data)
# NOTE(review): create_png() shells out to the Graphviz `dot` binary, which
# presumably must be installed separately from pydotplus; confirm locally.
# Keep Image(...) as the last expression so Jupyter displays it.
Image(visual.create_png())