draw binary tree jupyter notebook
I have attempted to display binary tree naturally with matplotlib. My tree has form of like this.
{'root': [{'left1': [{'left1_2': ['res1', 'res2']}, 'res1']}, {'right1': [{'left_2': ['res1', 'res2']}, 'res2']}]}
As you can see, the tree is something like nested dictionary that has list as a value for each key. And the first element of list has to be locateded at left and the second element of list has to be located at right from parent node. That is, I want the tree to be displayed like this.
My code as follows.
Define Node class.
class Node: def __init__(self,text='',x=None,y=None,isRoot=False,parentNode=None, leftNode=None,rightNode=None,textBox=None,boxWidth=None, ): self.x = x self.y = y self.text = text self.IsRoot = isRoot self.parentNode = parentNode self.leftNode = leftNode self.rightNode = rightNode self.textBox = textBox self.boxWidth = boxWidth self.isTerminal=False def getParentNode(self): return self.parentNode def setParentNode(self,parentNode): self.parentNode = parentNode def getIsTerminal(self): return self.isTerminal def setIsTerminal(self, isTerminal): self.isTerminal = isTerminal return def getLeftNode(self): return self.leftNode def setLeftNode(self, leftNode): self.leftNode = leftNode def getRightNode(self): return self.rightNode def setRightNode(self, rightNode): self.rightNode = rightNode def getText(self): return self.text def setText(self,text): self.text = text def getX(self): return self.x def setX(self,x): self.x = x def getY(self): return self.y def setY(self,y): self.y = y def __str__(self): return f'{self.x}, {self.y}, {self.text}' def __repr__(self): return f'TreeNode({self.x}, {self.y}, {self.text})'
Define Tree class.
class Tree: def __init__(self): import numpy as np self.depth = 0 self.width = 1 self.height = 1 self.verticalSpace = 0.5 self.xMax = -np.infty self.xMin = np.infty self.yMax = -np.infty self.yMin = np.infty def createNode(self,text='',x=None,y=None,isRoot=False, parentNode=None,leftNode=None,rightNode=None, isTerminal=False): return Node(text,x,y,isRoot,parentNode,leftNode,rightNode,isTerminal) def addNode(self,node=None,isLeft=True,text='',x=None,y=None, isRoot=False,parentNode=None,isTerminal=False): if x+self.width+2 > self.xMax: self.setXMax(x+self.width) if x-self.width < self.xMin: self.setXMin(x-self.width) if y > self.yMax: self.setYMax(y) if y-self.height-self.verticalSpace < self.yMin: self.setYMin(y-self.height) if node is None: return self.createNode(text,x,y,isRoot=True) assert isinstance(node,Node) if isLeft: node.leftNode = self.createNode(text,x,y,parentNode=node,isTerminal=isTerminal) return node.leftNode else: node.rightNode = self.createNode(text,x,y,parentNode=node,isTerminal=isTerminal) return node.rightNode def getXMax(self): return self.xMax def setXMax(self,xMax): self.xMax = xMax return def getXMin(self): return self.xMin def setXMin(self,xMin): self.xMin = xMin return def getYMax(self): return self.yMax def setYMax(self,yMax): self.yMax = yMax return def getYMin(self): return self.yMin def setYMin(self,yMin): self.yMin = yMin return
Define "building_tree_node" function which generate decsendant node from the predifined root node.
def building_tree_node(tree, node, counter=0): box_x = 2 counter += 1 for key, value in tree.items(): node.setText(key) if contain_dict(value): if isinstance(value[0], dict): left_node = TREE.addNode(node=node,x=-box_x*counter,y=-1*counter, isLeft=True,text=list(value[0].keys())[0]) building_tree_node(tree=value[0], node=left_node, counter=counter) if isinstance(value[1], dict): right_node = TREE.addNode(node=node,x=box_x*counter,y=-1*counter, isLeft=False,text=list(value[1].keys())[0]) building_tree_node(tree=value[1], node=right_node, counter=counter) else: temp = TREE.addNode(node=node,x= box_x*counter,y=-1*counter, isLeft=False,text=value[1]) temp.setIsTerminal(True) if isinstance(value[1], dict): right_node = TREE.addNode(node=node,x=box_x*counter,y=-1*counter, isLeft=False,text=list(value[1].keys())[0]) building_tree_node(tree=value[1], node=right_node, counter=counter) temp = TREE.addNode(node=node,x=-box_x*counter,y=-1*counter, isLeft=True,text=value[0]) temp.setIsTerminal(True) else: temp = TREE.addNode(node=node,x=-box_x*counter,y=-1*counter, isLeft=True,text=value[0]) temp.setIsTerminal(True) temp = TREE.addNode(node=node,x= box_x*counter,y=-1*counter, isLeft=False,text=value[1]) temp.setIsTerminal(True) final_tree = TREE final_root = node return final_tree, final_root
Helper function
def contain_dict(l): res = False for v in l: if isinstance(v, dict): res = True return res
Generate tree and TREE class and root Node.
tree = dict() left1 = dict() right1 = dict() left1['left1'] = [{'left1_2':['res1','res2']},'res1'] right1['right1'] = [{'left2_2':['res1','res2']},'res2'] tree['root'] = [left1, right1] TREE = Tree() root = TREE.addNode(x=0,y=0,text='root')
Define function which draws a tree.
import matplotlib.pyplot as plt ft, fr = building_tree_node(tree,root) ## final_tree, root node has all decsendants fig = plt.figure(figsize=(10,10)) renderer = fig.canvas.get_renderer() ax = fig.add_subplot() ax.set_xlim(ft.getXMin(),ft.getXMax()) ax.set_ylim(ft.getYMin(),ft.getYMax()+1) width = ft.width height = ft.height def drawNode(node): if node is not None: if node.getIsTerminal(): bbox=dict(boxstyle='square',fc='green') else: bbox=dict(boxstyle='square',fc='yellow') text_box = ax.text(node.getX(),node.getY(),node.getText(),bbox=bbox,fontsize=15,ha='center',va='center') node.textBox = text_box if node.parentNode is not None: parentTextBox = node.parentNode.textBox pbb = parentTextBox.get_window_extent(renderer=renderer).inverse_transformed(ax.transData) currentTextBox = node.textBox cbb = currentTextBox.get_window_extent(renderer=renderer).inverse_transformed(ax.transData) ax.plot((node.parentNode.x, node.x), (node.parentNode.y-pbb.height*0.7, node.y),color='k') drawNode(node.leftNode) drawNode(node.rightNode)
When I run my code as follows:
drawNode(fr)
I have the result as below:
As you can see, right child was displayed appropriatly(of coure, not perfect location).
But left child did not split properly. Obviously, I think the problem is in "building_tree_node" function.
However, I had no idea because I was not familar with recursive coding. I want to know that the first left child of the tree is splited properly. My gratitude goes to any comment and idea. Additionally, I want to know about an algorithm for displaying binary tree naturally. Thank you for watching this.
Source: https://stackoverflow.com/questions/65255664/drawing-binary-tree-with-matplotlib
0 Response to "draw binary tree jupyter notebook"
Post a Comment