Wednesday, March 24, 2010

Arrows in matplotlib



The last example would have looked better with arrows, so here is a first attempt. The getArrow function is a bit complicated because I wanted to adjust the start and end positions of the arrows so that they don't overlap the points. The commented-out print statements were useful in debugging that function. (It took me a lot longer than I care to admit.)

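I have a small problem: the arrows are drawn after, and therefore on top of, the points. It looks as if plt does not respect the z-order I gave the arrows, so they cover the points, and I haven't figured out how to fix that yet. (Note: matplotlib draws artists in order of increasing zorder, so the points stay on top only if their zorder is larger than that of every arrow.)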


import numpy as np
import matplotlib.pyplot as plt

def getArrow(p1, p2, i, dr=0.03, width=0.05):
    """Build a grey arrow patch pointing from ``p1`` to ``p2``.

    Both ends of the arrow are pulled in along the line joining the two
    points, so the arrow does not overlap the markers drawn at them.

    Parameters
    ----------
    p1, p2 : objects with ``x`` and ``y`` attributes
        Start and end points of the arrow, in data coordinates.
    i : int
        Layering index; the patch is created with ``zorder=i + 1``.
    dr : float, optional
        Distance trimmed from each end of the arrow (default 0.03, the value
        that used to be hard-coded). If the points are closer than ``4 * dr``,
        the trim shrinks to a quarter of their distance. This keeps the arrow
        pointing from ``p1`` to ``p2`` instead of collapsing or reversing.
    width : float, optional
        Arrow width passed to ``plt.Arrow`` (default 0.05).

    Returns
    -------
    The ``plt.Arrow`` patch, not yet added to any axes.
    """
    w = p2.x - p1.x
    h = p2.y - p1.y
    length = np.hypot(w, h)

    if length > 0:
        # Step along the unit direction (w, h) / length. Capping the trim at a
        # quarter of the length means at most half the arrow is removed, so
        # the remaining segment always points from p1 towards p2.
        trim = min(dr, 0.25 * length)
        dx = trim * w / length
        dy = trim * h / length
    else:
        # Coincident points: there is no direction to trim along.
        dx = dy = 0.0

    a = plt.Arrow(p1.x + dx, p1.y + dy, w - 2 * dx, h - 2 * dy,
                  width=width, zorder=i + 1)
    a.set_facecolor('0.7')
    a.set_edgecolor('w')
    return a

class Point:
    """Empty record for a 2-D point; callers assign ``x`` and ``y`` after creating it."""

# Scatter N random points in the unit square and join consecutive ones with arrows.
N = 10
L = np.random.uniform(size=N*2)

# Consecutive values (L[0], L[1]), (L[2], L[3]), ... become the (x, y) of each point.
pL = []
for x, y in L.reshape(N, 2):
    p = Point()
    p.x, p.y = x, y
    pL.append(p)

ax = plt.axes()
for i, p in enumerate(pL):
    if i:
        ax.add_patch(getArrow(pL[i-1], p, i))
    # Arrow i is created with zorder i + 1 (at most N), and matplotlib draws
    # artists in increasing zorder. Giving the markers zorder N + 1 keeps
    # them on top of every arrow. Scatter is called once per point so each
    # point still gets its own colour from the colour cycle.
    ax.scatter(p.x, p.y, s=250, zorder=N + 1)

ax.set_xlim(-0.01, 1.01)
ax.set_ylim(-0.01, 1.01)
plt.savefig('example.png')