{"id":2059,"date":"2014-02-26T14:37:08","date_gmt":"2014-02-26T14:37:08","guid":{"rendered":"http:\/\/www.garysieling.com\/blog\/?p=2059"},"modified":"2014-02-26T14:37:08","modified_gmt":"2014-02-26T14:37:08","slug":"rending-scikit-decision-trees-d3-js","status":"publish","type":"post","link":"https:\/\/www.garysieling.com\/blog\/rending-scikit-decision-trees-d3-js\/","title":{"rendered":"Rendering scikit Decision Trees in D3.js"},"content":{"rendered":"<p>Scikit-learn provides routines to <a href=\"http:\/\/scikit-learn.org\/stable\/modules\/generated\/sklearn.tree.export_graphviz.html\">export decision trees to a format called Graphviz<\/a>, although typically this is used to provide an image of a chart.<\/p>\n<p>For some applications this is valuable, but if the product of machine learning is a the ability to generate models (rather than predictions), it would be preferable to provide interactive models.<\/p>\n<p>Automatically construction decision trees, for instance, might allow you to discover patterns in underlying data, e.g. determining that many failures are caused by a particular device or vendor. 
In this scenario, being able to predict failure is relatively useless, since the goal is to take corrective action.<\/p>\n<p>There are many excellent interactive tree examples built with D3.js, and the example that follows shows how to link the two libraries together.<\/p>\n<p>Scikit-learn provides a function called export_graphviz, which I copied and changed to export JSON (the library would benefit from an API for iterating over its trees, which would make this workaround unnecessary).<\/p>\n<p>The whole function is a bit long, so feel free to skip to the next section.<\/p>\n<pre lang=\"python\">\nimport sklearn.tree\n\ndef viz(decision_tree, feature_names=None):\n  js = \"\"\n\n  def node_to_str(tree, node_id, criterion):\n    if not isinstance(criterion, str):\n      criterion = \"impurity\"\n\n    value = tree.value[node_id]\n    if tree.n_outputs == 1:\n      value = value[0, :]\n\n    # recurse() supplies the surrounding braces for both leaf and\n    # interior nodes, so neither branch emits them here\n    if tree.children_left[node_id] == sklearn.tree._tree.TREE_LEAF:\n      return '\"id\": \"%s\", \"criterion\": \"%s\", \"impurity\": \"%s\", \"samples\": \"%s\", \"value\": \"%s\"' \\\n             % (node_id, \n                criterion,\n                tree.impurity[node_id],\n                tree.n_node_samples[node_id],\n                value)\n    else:\n      if feature_names is not None:\n        feature = feature_names[tree.feature[node_id]]\n      else:\n        feature = tree.feature[node_id]\n\n      return '\"id\": \"%s\", \"rule\": \"%s <= %.4f\", \"%s\": \"%s\", \"samples\": \"%s\"' \\\n             % (node_id, \n                feature,\n                tree.threshold[node_id],\n                criterion,\n                tree.impurity[node_id],\n                tree.n_node_samples[node_id])\n\n  def recurse(tree, node_id, criterion, parent=None, depth=0):\n    tabs = \"  \" * depth\n    js = \"\"\n\n    left_child = tree.children_left[node_id]\n    right_child = tree.children_right[node_id]\n\n    js = js + \"\\n\" + \\
    tabs + \"{\\n\" + \\\n         tabs + \"  \" + node_to_str(tree, node_id, criterion)\n\n    if left_child != sklearn.tree._tree.TREE_LEAF and depth < 6:\n      js = js + \",\\n\" + \\\n           tabs + '  \"left\": ' + \\\n           recurse(tree, \\\n                   left_child, \\\n                   criterion=criterion, \\\n                   parent=node_id, \\\n                   depth=depth + 1) + \",\\n\" + \\\n           tabs + '  \"right\": ' + \\\n           recurse(tree, \\\n                   right_child, \\\n                   criterion=criterion, \\\n                   parent=node_id,\n                   depth=depth + 1)\n\n    js = js + tabs + \"\\n\" + \\\n         tabs + \"}\"\n\n    return js\n\n  if isinstance(decision_tree, sklearn.tree.tree.Tree):\n    js = js + recurse(decision_tree, 0, criterion=\"impurity\")\n  else:\n    js = js + recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion)\n\n  return js\n\ncols = dict()\nfor i, c in enumerate(income_trn.columns):\n  cols[i] = c.name\n\nprint viz(clf, feature_names=cols)\n<\/pre>\n<p>I'm extending a previous example I wrote, which <a href=\"http:\/\/www.garysieling.com\/blog\/building-decision-tree-python-postgres-data\">builds a decision tree for income data from Postgres data<\/a>. The 'rules' here are what determines whether the left or right path down a tree applies to a data point. 
The internal structure of the tree numbers these as column 1, column 2, etc, but the function above lets you provide a mapping, so these can be turned back into column headers.<\/p>\n<pre lang=\"javascript\">\nx = {\n  id: '0', rule: 'marital_status <= 12.0000', 'gini': '0.360846845083', 'samples': '16281',\n  left: \n  {\n    id: '1', rule: 'capital_gain <= 86.5000', 'gini': '0.0851021980364', 'samples': '5434',\n    left: \n    {\n      id: '2', rule: 'capital_loss <= 2372.0000', 'gini': '0.0634649568944', 'samples': '5212',\n      left: \n      {\n        id: '3', rule: 'education <= 46.0000', 'gini': '0.0584505343', 'samples': '5177'      \n      },\n      right: \n      {\n        id: '570', rule: 'capital_loss <= 2558.5000', 'gini': '0.489795918367', 'samples': '35'      \n      }    \n    },\n    right: \n    {\n      id: '577', rule: 'education_num <= 47.0000', 'gini': '0.43507020534', 'samples': '222',\n      left: \n      {\n        id: '578', rule: 'capital_gain <= 221.5000', 'gini': '0.272324674466', 'samples': '123'      \n      },\n      right: \n      {\n        id: '607', rule: 'capital_gain <= 817.5000', 'gini': '0.499540863177', 'samples': '99'      \n      }    \n    }  \n  },\n  right: \n  {\n    id: '632', rule: 'marital_status <= 55.0000', 'gini': '0.44372508662', 'samples': '10847',\n    left: \n    {\n      id: '633', rule: 'education_num <= 47.0000', 'gini': '0.493880410242', 'samples': '7403',\n      left: \n      {\n        id: '634', rule: 'capital_gain <= 22.0000', 'gini': '0.454579586462', 'samples': '4363'      \n      },\n      right: \n      {\n        id: '2885', rule: 'education_num <= 113.5000', 'gini': '0.486689750693', 'samples': '3040'      \n      }    \n    },\n    right: \n    {\n      id: '4292', rule: 'capital_gain <= 86.5000', 'gini': '0.164770726851', 'samples': '3444',\n      left: \n      {\n        id: '4293', rule: 'education_num <= 47.0000', 'gini': '0.126711048456', 'samples': '3207'      \n      },\n      right: 
\n      {\n        id: '4902', rule: 'education <= 46.0000', 'gini': '0.478627000659', 'samples': '237'      \n      }    \n    }  \n  }\n}\n<\/pre>\n<p>To make this render with D3, I <a href=\"http:\/\/mbostock.github.io\/d3\/talk\/20111018\/treemap.html\">picked one example<\/a> structure, which looks like this:<\/p>\n<p><a href=\"http:\/\/172.104.26.128\/wp-content\/uploads\/2014\/02\/sklearn1.png\"><img loading=\"lazy\" decoding=\"async\" src=\"http:\/\/www.garysieling.com\/blog\/wp-content\/uploads\/2014\/02\/sklearn1-578x299.png\" alt=\"sklearn1\" width=\"578\" height=\"299\" class=\"aligncenter size-large wp-image-2069\" \/><\/a><\/p>\n<p>The nice thing about this visualization is that while it shows leaf nodes, it groups them into the tree hierarchy and lets you selectively drill down by clicking to zoom.<\/p>\n<p>The JSON structure for every example isn't necessarily guaranteed to be the same, so I've written a function to restructure the above tree into exactly what this needs (this is easier than rewriting the above python code for every test case, or fixing the D3 examples)<\/p>\n<pre lang=\"Javascript\">\nfunction toJson(x) \n{\n  var result = {};\n  result.name = x.rule;\n\n\tif ( (!!x.left && !x.left.value) ||\n\t\t\t (!!x.right && !x.right.value) )\n    result.children = [];\n\telse\n    result.size = parseInt(x.samples);\n\n  var index = 0;\n  if (!!x.left && !x.left.value)\n    result.children[index++] = toJson(x.left);\n\n  if (!!x.right && !x.right.value)\n    result.children[index++] = toJson(x.right);\n\n  return result;\n}\n<\/pre>\n<p>Then, the only change you need to make to the D3 example is to add this function and add a call:<\/p>\n<pre lang=\"Javascript\">\n  node = root = toJson(data);\n<\/pre>\n<p><a href=\"http:\/\/172.104.26.128\/wp-content\/uploads\/2014\/02\/sklearn.png\"><img loading=\"lazy\" decoding=\"async\" src=\"http:\/\/www.garysieling.com\/blog\/wp-content\/uploads\/2014\/02\/sklearn-578x305.png\" alt=\"sklearn\" 
width=\"578\" height=\"305\" class=\"aligncenter size-large wp-image-2068\" \/><\/a><\/p>\n<p>Now, this isn't the prettiest and is only one view of the tree (leaves), but it wires up enough parts to get you set up to find the right visualization for what you're doing.<\/p>\n","protected":false},"excerpt":{"rendered":"<p>Scikit-learn provides routines to export decision trees to a format called Graphviz, although typically this is used to provide an image of a chart. For some applications this is valuable, but if the product of machine learning is a the ability to generate models (rather than predictions), it would be preferable to provide interactive models. &hellip; <\/p>\n<p class=\"link-more\"><a href=\"https:\/\/www.garysieling.com\/blog\/rending-scikit-decision-trees-d3-js\/\" class=\"more-link\">Continue reading<span class=\"screen-reader-text\"> &#8220;Rendering scikit Decision Trees in D3.js&#8221;<\/span><\/a><\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":{"om_disable_all_campaigns":false,"_monsterinsights_skip_tracking":false,"_monsterinsights_sitenote_active":false,"_monsterinsights_sitenote_note":"","_monsterinsights_sitenote_category":0,"footnotes":""},"categories":[4],"tags":[138,140,152,153,277,303,352,447,493],"aioseo_notices":[],"amp_enabled":true,"_links":{"self":[{"href":"https:\/\/www.garysieling.com\/blog\/wp-json\/wp\/v2\/posts\/2059"}],"collection":[{"href":"https:\/\/www.garysieling.com\/blog\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.garysieling.com\/blog\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.garysieling.com\/blog\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.garysieling.com\/blog\/wp-json\/wp\/v2\/comments?post=2059"}],"version-history":[{"count":0,"href":"https:\/\/www.garysieling.com\/blog\/wp-json\/wp\/v2\/posts\/2059\/revisions"
}],"wp:attachment":[{"href":"https:\/\/www.garysieling.com\/blog\/wp-json\/wp\/v2\/media?parent=2059"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.garysieling.com\/blog\/wp-json\/wp\/v2\/categories?post=2059"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.garysieling.com\/blog\/wp-json\/wp\/v2\/tags?post=2059"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}