Your First Plot
You’re ready to start coding your first plot! This will be a stack plot showing the combined debt growth over time for three different regions. You’ll be using random numbers rather than real data.
You’ll import the necessary modules, generate random data with numpy, and plot that data with matplotlib:
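import matplotlib.pyplot as plt
import numpy as np

# Seed NumPy's global generator so every run produces the same data
np.random.seed(444)
rng = np.arange(50)  # 0, 1, ..., 49
rnd = np.random.randint(0, 10, size=(3, rng.size))  # 3 rows of 50 ints in [0, 10)
yrs = 1950 + rng  # x-axis values: 1950 through 1999
print(rng + rnd)  # y data: rng is broadcast across each row of rnd

fig, ax = plt.subplots(figsize=(5, 3))  # width and height in inches
ax.stackplot(yrs, rng+rnd, labels=['Eastasia', 'Eurasia', 'Oceania'])
ax.set_title('Combined debt growth over time')
ax.legend(loc='upper left')
ax.set_ylabel('Total debt')
ax.set_xlim(xmin=yrs[0], xmax=yrs[-1])  # no empty margin on either side
fig.tight_layout()  # trim surplus whitespace around the Axes
plt.show()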
00:00 Finally, we are ready to start coding our first plot. This will be a stack plot showing the combined debt growth over time for three different regions. Of course, this won’t be real data—just random numbers.
00:16 We’ll start by importing the necessary modules, and then we’ll use NumPy to generate our pseudorandom data in the form of ndarrays, and then finally, we’ll use Matplotlib in order to actually plot that data.
00:31 I’m working here in a conda environment called LearningMPL, which has matplotlib and all of its dependencies installed. I’ve also created a new Python file called plot1.py that we can work in.
00:46 We’ll start by importing pyplot: import matplotlib.pyplot as plt.
00:55 We’re not going to rely on pyplot exclusively, but we do need it to grab our initial Figure and Axes objects. We’ll also import numpy as np.
01:08 Now, let’s set the random seed so that we get the same pseudorandom output each time we run the program. We can do this with np.random.seed(), and I’ll give it a value of 444.
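To see what the seed buys you, here is a minimal sketch that seeds NumPy’s global generator with the lesson’s value of 444, draws five integers, re-seeds, and draws again. The names first and second are illustrative only.

```python
import numpy as np

# Seed the legacy global generator and draw five integers in [0, 10).
np.random.seed(444)
first = np.random.randint(0, 10, size=5)

# Re-seed with the same value: the draws repeat exactly.
np.random.seed(444)
second = np.random.randint(0, 10, size=5)

print(np.array_equal(first, second))  # True
```

Newer NumPy code often prefers a local generator created with np.random.default_rng(), but the lesson sticks with the global np.random.seed() call.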
01:24 Next, we have to generate the data to plot. Let’s create a new variable called rng, which will store the result of np.arange(), passing in 50.
01:38 This will be a one-dimensional ndarray counting from 0 to 49 inclusive. Now, let’s create another variable called rnd. For this, I want to create a 2D ndarray containing three one-dimensional ndarrays.
01:58 Each of those arrays should contain 50 numbers from 0 to 9 inclusive. The upper bound passed to randint() is exclusive, so the function call looks like this: np.random.randint(0, 10, size=(3, rng.size)).
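A quick way to confirm the shapes described above is to print them. This sketch repeats the lesson’s two calls and checks that rng is one-dimensional and rnd has three rows of 50 values.

```python
import numpy as np

np.random.seed(444)
rng = np.arange(50)                                 # 0, 1, ..., 49
rnd = np.random.randint(0, 10, size=(3, rng.size))  # high=10 is exclusive

print(rng.shape, rng[0], rng[-1])  # (50,) 0 49
print(rnd.shape)                   # (3, 50)
```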
02:23 This is the data that will actually be plotted along the y-axis. We’ll create one more ndarray to represent the year labels along the x-axis.
02:34 This one should be identical to the rng array, but with 1950 added to each element. We can do this by simply adding 1950 to the existing ndarray: yrs = 1950 + rng. Before we start plotting, I want to print the result of rng + rnd to the console.
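Adding a plain number to an ndarray applies the addition to every element, which is all that yrs = 1950 + rng relies on. A minimal check:

```python
import numpy as np

rng = np.arange(50)
yrs = 1950 + rng  # the scalar is added to every element of rng

print(yrs[:3], yrs[-1])  # [1950 1951 1952] 1999
```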
02:57 This is the data that will be plotted along the y-axis of the stack plot, and having those ndarrays onscreen will help me explain how the stack plot works later on. To start plotting, we need two of Matplotlib’s objects: the Figure and the Axes, which is contained inside the Figure.
03:21 We don’t use pyplot for much because we’re taking the stateless approach, but we can use it to generate those objects. To do that, let’s create two variables called fig and ax and grab those objects with plt.subplots(), passing in a tuple of (5, 3) for the figsize.
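The sketch below checks what plt.subplots() hands back and how figsize is stored. It assumes the non-interactive "Agg" backend so it can run without a display; that choice is not part of the lesson.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure

fig, ax = plt.subplots(figsize=(5, 3))  # (width, height) in inches

print(isinstance(fig, Figure), isinstance(ax, Axes))  # True True
print(fig.get_size_inches())                          # [5. 3.]
```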
03:46 This figsize sets the width and height of the figure in inches. Now that we have the Figure and Axes objects, we can start adding to the Axes, which is the region of the figure where the stack plot itself is drawn. To create the stack plot, we can call the .stackplot() method on our ax object.
04:07 First, it needs a 1D array to populate the x-axis, so I’ll pass in our yrs variable. It also needs a 2D array to plot along the y-axis.
04:20 This will be rng+rnd. NumPy broadcasts the 50-element rng across each of the three rows of rnd, so this increments the first element in each subarray of rnd by 0, the second element in each by 1, the third by 2, and so on until the last element is incremented by 49.
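Broadcasting is easier to follow on a smaller array. In this sketch, rng is shortened to four elements and rnd holds made-up values (not the lesson’s seeded output), one row per region; the 1D array is added to each row in turn.

```python
import numpy as np

rng = np.arange(4)             # shortened stand-in for np.arange(50)
rnd = np.array([[3, 1, 4, 1],  # made-up values, one row per region
                [8, 5, 9, 2],
                [5, 3, 5, 8]])

total = rng + rnd  # shape (4,) is broadcast across each row of shape (3, 4)
print(total)
# [[ 3  2  6  4]
#  [ 8  6 11  5]
#  [ 5  4  7 11]]
```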
04:42 We’ll see this sum printed to the console, which will make things a little clearer later on. Finally, we’ll give this method a labels argument with one label per row of the y data, so let’s write labels= followed by a list containing 'Eastasia', 'Eurasia', and 'Oceania'.
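To see what .stackplot() produces, this sketch rebuilds the lesson’s data and inspects the return value. It assumes the off-screen "Agg" backend; the name polys is illustrative. Matplotlib returns one PolyCollection per row of the y data, and each one carries the matching label for the legend.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend, no window needed
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(444)
rng = np.arange(50)
rnd = np.random.randint(0, 10, size=(3, rng.size))
yrs = 1950 + rng

fig, ax = plt.subplots(figsize=(5, 3))
# One x array of length 50 and a (3, 50) y array give three stacked areas.
polys = ax.stackplot(yrs, rng + rnd, labels=['Eastasia', 'Eurasia', 'Oceania'])

print(len(polys))                         # 3
print(ax.get_legend_handles_labels()[1])  # ['Eastasia', 'Eurasia', 'Oceania']
```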
05:02 And that’s the most difficult line of code in this entire program. Next, we can set some other properties of the Axes object. We can set the title of the Axes with .set_title(), and I’ll pass in the string 'Combined debt growth over time'.
05:23 We can also give it a legend with the .legend() method, and we can pass in a location for it. Let’s say 'upper left'. I also want to give this plot a y label, which we can do with the .set_ylabel() method.
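Each of these setters has a matching getter, which makes the settings easy to verify. This sketch uses a tiny made-up stack plot and the off-screen "Agg" backend (both assumptions for illustration) and reads back the title, y label, and legend entries.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(5, 3))
# Made-up data: two regions over three years.
ax.stackplot([1950, 1951, 1952], [[1, 2, 3], [2, 2, 2]],
             labels=['Eastasia', 'Eurasia'])

ax.set_title('Combined debt growth over time')
ax.legend(loc='upper left')
ax.set_ylabel('Total debt')

print(ax.get_title())   # Combined debt growth over time
print(ax.get_ylabel())  # Total debt
print([t.get_text() for t in ax.get_legend().get_texts()])  # ['Eastasia', 'Eurasia']
```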
05:39 I’ll pass in the string 'Total debt'. This will be displayed along the vertical axis. Next, we need to set the upper and lower limits for the years displayed along the x-axis.
05:53 This will ensure that our actual graphic is limited to just the year data that we’re working with. To do this, we can use the .set_xlim() method, which takes both an xmin and an xmax.
06:10 The xmin should be the first element in our yrs array, which we can access as if it were a Python list, with square bracket notation.
06:21 I’ll say yrs[0]. For the xmax, we want the last element, which is yrs[-1]. Great! Our data is plotted. The last thing we have to do is clean up the Figure, which will remove any extra whitespace.
06:43 We can do that with the .tight_layout() method, which we call directly on our Figure object instead of the Axes.
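Here is a minimal sketch of the last two steps, assuming the off-screen "Agg" backend and placeholder y data of all ones. After .set_xlim(), the Axes reports exactly the first and last year as its limits, and .tight_layout() is called on the Figure rather than on the Axes.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend
import matplotlib.pyplot as plt
import numpy as np

yrs = 1950 + np.arange(50)
fig, ax = plt.subplots(figsize=(5, 3))
ax.stackplot(yrs, np.ones((3, yrs.size)))  # placeholder data, three layers

# Pin the x-axis to the first and last year so no empty margin remains.
ax.set_xlim(xmin=yrs[0], xmax=yrs[-1])
fig.tight_layout()  # a Figure method: adjusts padding so labels fit

left, right = ax.get_xlim()
print(left, right)  # 1950.0 1999.0
```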
06:53 Now, because we’re not running this in any sort of interactive mode, we have to call plt.show() to actually tell pyplot to show our plot onscreen.
07:06 I will run this program, and we get our plot! We see it contains everything we’d expect—our titles, the label, the years along the x-axis, and our stack plot.
07:19 But how exactly does this stack plot work? To answer that question, let’s look at the ndarray that was printed to the console. This ndarray contains three one-dimensional subarrays, one for each region in the plot here.
07:36 The first array starts with a 3, so the first plot point, representing the blue part of the graph, should be 3 units from the bottom of the y-axis.
07:49 The second array starts with 8, but that doesn’t mean that the next point, representing the orange graph, sits at 8. Rather, it sits 8 units above the previous point, which puts it at 11.
08:04 That’s why it’s called a stack plot: each data point stacks on top of the previous one. Finally, the last array starts at 5, so the last point, representing the green region, sits at 16, or 3 + 8 + 5.
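The stacking arithmetic above is a running total down the rows. This sketch uses only the three first-column values quoted in the walkthrough (3, 8, and 5); np.cumsum() returns the top edge of each layer.

```python
import numpy as np

# First column of the printed array, one value per region (from the walkthrough).
first_column = np.array([3, 8, 5])

# Each layer's top edge is the sum of its own value and every layer beneath it.
tops = np.cumsum(first_column)
print(tops)  # [ 3 11 16]
```

For the full (3, 50) array, np.cumsum(rng + rnd, axis=0) gives the top edge of every layer at every year in the same way.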
08:24 This pattern continues for every x value in the plot. Just to show that this is accurate, I’m going to move over to my plot onscreen and I’ll use the zoom tool, which is this little magnifying glass.
08:39 I’ll draw a small box around the leftmost region of the plot, which will resize the entire window to that plot region—effectively, zooming in.
08:50 If I hover my mouse over the first data point for each color, you can see in the bottom right corner the y value of my mouse coordinate. Hovering over the first blue point shows 3, the first orange point shows 11, and the first green point shows 16—just as we’d expect.
09:11 We can use the other tools down here to pan around the graph, change between previous views, configure the spacing in the window, and even save a PNG file of the plot.
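The toolbar’s save button has a programmatic counterpart, Figure.savefig(). This sketch assumes the off-screen "Agg" backend and writes the PNG into an in-memory buffer so nothing lands on disk. A path such as "plot1.png" would work in place of the buffer.

```python
import io

import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(5, 3))
ax.stackplot([1950, 1951, 1952], [[3, 4, 5], [8, 7, 6]])  # made-up data

# Render the figure as PNG bytes into a buffer.
buf = io.BytesIO()
fig.savefig(buf, format="png")

# Every PNG file begins with the same 8-byte signature.
print(buf.getvalue()[:8] == b"\x89PNG\r\n\x1a\n")  # True
```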
Austin Cepalia RP Team on Oct. 27, 2019
Oh yeah, there is a minor typo in the description. Unfortunately, I can’t edit that directly, but I’ll see if I can get that fixed.
Dan Bader RP Team on Oct. 27, 2019
I just fixed the typo, thanks for the heads up @avinashhm :)
avinashhm on Oct. 29, 2019
no worries .. thanks for fixing @dan, @austin !!
alazejha on Jan. 2, 2021
Dear Austin,
I tried the code, but the error message pops up:
AttributeError Traceback (most recent call last)
<ipython-input-2-d96da0308728> in <module>
9 print(rng+rnd)
10
---> 11 fig, ax = plt.subplots(figsize =(5,3))
12
13 ax.stackplot(yrs, rnd+rng , labels =["Asia", "Europe", "Oceania"])
AttributeError: module 'matplotlib' has no attribute 'subplots'
What am I doing wrong? A
Bartosz Zaczyński RP Team on Jan. 4, 2021
@alazejha It looks like you might have aliased the wrong module. Make sure you have the following import statement at the top of your code:
import matplotlib.pyplot as plt
Judging by the error message, I’m guessing you did this instead:
import matplotlib as plt
However, this is only my guess since you haven’t shared the full code snippet. Let me know if that helps.
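The difference between the two imports can be checked directly. subplots() is defined in the matplotlib.pyplot submodule, not in the top-level matplotlib package, which is why aliasing the package as plt produces that AttributeError.

```python
import matplotlib
import matplotlib.pyplot as plt

# subplots() lives in the pyplot submodule, not the top-level package.
print(hasattr(matplotlib, "subplots"))  # False
print(hasattr(plt, "subplots"))         # True
```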
rachelannkirkland on Feb. 7, 2021
I am using Jupyter notebooks and the plot does not appear with plt.show(). How do I view the figure? I have previously used pyplot to create simple plots and it has worked with Jupyter notebooks.
I also have Visual Studio Code and PyCharm free versions on my computer which I could use instead. What do you recommend? I looked for LearningMPL but couldn’t find it on Google.
Thanks for your help.
jwsayyes on Aug. 1, 2022
@rachelannkirkland, you may want to try just typing fig instead of plt.show().
Alireza637 on Sept. 21, 2023
Is that zoom thingy at the end specific to your IDE / env? ’Cause I don’t have it in my Jupyter Notebook!
Bartosz Zaczyński RP Team on Sept. 21, 2023
@Alireza637 When you call plt.show() in a Jupyter Notebook, it renders the plot using an alternative backend. Normally, the plot ends up shown in Matplotlib’s popup window, which has this toolbar with the zoom icon.
Alireza637 on Sept. 22, 2023
@Bartosz Zaczyński Thanks for your response. Still, it renders the plot in the notebook itself and not in a popup. Tho, it’s not that big of a deal, was wondering what may be different in my env than the one in the tutorial.
Bartosz Zaczyński RP Team on Sept. 22, 2023
@Alireza637 Matplotlib renders your plot in the notebook because that’s the default backend it uses in Jupyter notebooks. You can change the backend using the %matplotlib magic in your cell:
%matplotlib qt
import matplotlib.pyplot as plt
plt.plot([5, 3, 6, 2, 8, 5])
plt.show()
This will display the popup from the video with the toolbar and the zoom icon. Before you try it, however, make sure that you have one of these libraries installed in your virtual environment: PyQt6, PySide6, PyQt5, PySide2. If not, then try this command:
(venv) $ python -m pip install pyqt6
Available Matplotlib backends are:
tk
gtk
gtk3
gtk4
wx
qt4
qt5
qt6
qt
osx
nbagg
webagg
notebook
agg
svg
pdf
ps
inline
ipympl
widget
Hope this helps!
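The %matplotlib magic only exists in IPython and Jupyter. In a plain .py script like plot1.py, the equivalent is matplotlib.use(), called before importing pyplot. This sketch picks the non-interactive "Agg" backend so it runs without a display; swapping in "QtAgg" should give the popup window instead, assuming a recent Matplotlib and one of the Qt bindings mentioned above.

```python
import matplotlib

# Select the backend before importing pyplot.
matplotlib.use("Agg")  # non-interactive; "QtAgg" would open a window instead
import matplotlib.pyplot as plt

plt.plot([5, 3, 6, 2, 8, 5])
print(matplotlib.get_backend())
```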
Alireza637 on Sept. 25, 2023
@Bartosz Zaczyński It did, thanks a bunch
Bartosz Zaczyński RP Team on Sept. 26, 2023
You’re welcome 😊
avinashhm on Oct. 25, 2019
Hi there, great tutorial!
Minor correction, I guess: plt.sobplots was intended to be plt.subplots?