
Your First Plot


You’re ready to start coding your first plot! This will be a stack plot showing the combined debt growth over time for three different regions. You’ll be using random numbers rather than real data.

You’ll import the necessary modules, generate random data with NumPy, and plot that data with Matplotlib:

import matplotlib.pyplot as plt
import numpy as np

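np.random.seed(444)  # fixed seed: same "random" numbers on every run

rng = np.arange(50)  # 0, 1, ..., 49
rnd = np.random.randint(0, 10, size=(3, rng.size))  # 3 rows of 50 ints in 0..9
yrs = 1950 + rng  # x-axis years: 1950 through 1999

print(rng + rnd)  # rng is added to each row of rnd

fig, ax = plt.subplots(figsize=(5, 3))  # (width, height) in inches

# Each row of rng + rnd becomes one stacked layer
ax.stackplot(yrs, rng+rnd, labels=['Eastasia', 'Eurasia', 'Oceania'])
ax.set_title('Combined debt growth over time')
ax.legend(loc='upper left')
ax.set_ylabel('Total debt')
ax.set_xlim(xmin=yrs[0], xmax=yrs[-1])  # trim the x-axis to the data's years

fig.tight_layout()  # reduce extra whitespace around the Axes
plt.show()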

00:00 Finally, we are ready to start coding our first plot. This will be a stack plot showing the combined debt growth over time for three different regions. Of course, this won’t be real data—just random numbers.

00:16 We’ll start by importing the necessary modules, and then we’ll use NumPy to generate our pseudorandom data in the form of ndarrays, and then finally, we’ll use Matplotlib in order to actually plot that data.

00:31 I’m working here in a conda environment called LearningMPL, which has matplotlib and all of its dependencies installed. I’ve also created a new Python file called plot1.py that we can work in.

00:46 We’ll start by importing pyplot. import matplotlib.pyplot as plt.

00:55 We’re not going to rely on pyplot exclusively, but we do need it to grab our initial Figure and Axes objects. We’ll also import numpy as np.

01:08 Now, let’s set the random seed so that we get the same pseudorandom output each time we run the program. We can do this with np.random.seed() and I’ll give it a value of 444.
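To see what seeding buys you, here is a small sketch (not part of the lesson's script) that re-seeds NumPy's legacy global generator with the same value and draws twice. The names `first` and `second` are just illustrative.

```python
import numpy as np

# Seed the legacy global generator and draw five integers.
np.random.seed(444)
first = np.random.randint(0, 10, size=5)

# Re-seeding with the same value resets the generator's state.
np.random.seed(444)
second = np.random.randint(0, 10, size=5)

# Both draws are therefore identical.
print(np.array_equal(first, second))  # True
```

Without the second `np.random.seed(444)` call, `second` would continue the sequence instead of repeating it.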

01:24 Next, we have to generate the data to plot. Let’s create a new variable called rng, and that will store the result of np.arange(), and we’ll pass in 50.

01:38 This will be a one-dimensional ndarray counting from 0 to 49 inclusive. Now, let’s create another variable called rnd. For this, I want to create a 2D ndarray containing three one-dimensional ndarrays.

01:58 Each of those arrays should contain 50 numbers, from 0 to 9 inclusive. That function call looks like this: np.random.randint(0, 10, size=(3, rng.size)).
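As a quick check of the shapes and value range described above, this sketch rebuilds `rng` and `rnd` exactly as the lesson does and inspects them. Note that `randint`'s upper bound is exclusive, which is why passing 10 yields values from 0 to 9.

```python
import numpy as np

np.random.seed(444)
rng = np.arange(50)                                  # 0, 1, ..., 49
rnd = np.random.randint(0, 10, size=(3, rng.size))   # high=10 is exclusive

print(rng.shape)                                  # (50,)
print(rnd.shape)                                  # (3, 50)
print(int(rnd.min()) >= 0, int(rnd.max()) <= 9)   # True True
```

If you prefer NumPy's newer Generator API, `np.random.default_rng(444).integers(0, 10, size=(3, 50))` produces an array of the same shape and range, but it uses a different algorithm, so its numbers won't match the ones printed in this lesson.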

02:23 This is the data that will actually be plotted along the y-axis. We’ll create one more ndarray to represent the year labels along the x-axis.

02:34 This one should be identical to the rng array, but with 1950 added to each element. We can do this by simply adding 1950 to the existing ndarray, just like this. Before we start plotting, I want to print the result of adding rng + rnd to the console.
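Adding a plain number to an ndarray applies it to every element, so `yrs` runs from 1950 to 1999. This short sketch confirms the endpoints, which are the same values used later with .set_xlim().

```python
import numpy as np

rng = np.arange(50)
yrs = 1950 + rng   # the scalar 1950 is added to every element

print(yrs[0], yrs[-1], yrs.size)  # 1950 1999 50
```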

02:57 This is the data that will be plotted along the y-axis of the stack plot, and having those ndarrays onscreen is going to help me to explain how the stack plot works later on. To start plotting, we need two of Matplotlib’s objects, the Figure and the Axes, which is contained inside of the Figure.

03:21 We don’t use pyplot for much because we’re taking the stateless approach, but we can use it to generate those objects. To do that, let’s create two variables called fig and ax and grab those objects with plt.subplots(), passing in a tuple of (5, 3) for the figsize.
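To contrast the two styles, here is a minimal sketch with made-up data (`x` and `y` are illustrative, not the lesson's arrays). The stateful version calls pyplot functions that act on whatever Axes is "current," while the stateless version calls methods on explicit Figure and Axes objects. The `matplotlib.use("Agg")` line is an assumption added only so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(5)
y = np.array([[1, 2, 3, 4, 5],
              [2, 2, 2, 2, 2]])

# Stateful (pyplot) style: each call acts on the "current" Axes.
plt.figure()
plt.stackplot(x, y)
plt.title("Stateful")
stateful_title = plt.gca().get_title()

# Stateless (object-oriented) style: call methods on explicit objects.
fig, ax = plt.subplots()
ax.stackplot(x, y)
ax.set_title("Stateless")

print(stateful_title, ax.get_title())  # Stateful Stateless
plt.close("all")
```

The explicit style scales better once a figure holds several Axes, because each call says exactly which Axes it modifies.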

03:46 The figsize is the (width, height) of the figure in inches. Now that we have the Figure and Axes objects, we can start adding to the Axes, which is the plotting area where our stack plot will appear onscreen. To create the stack plot, we can call the .stackplot() method on our ax object.
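The relationship between the Figure, its Axes, and figsize can be inspected directly. This sketch (again using the headless Agg backend as an added assumption) shows that the size is stored in inches and that the Axes belongs to the Figure that created it.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed only so this runs anywhere
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(5, 3))   # (width, height) in inches

print(fig.get_size_inches().tolist())    # [5.0, 3.0]
print(ax.figure is fig)                  # True: the Axes lives inside this Figure
print(fig.get_axes() == [ax])            # True: it is the Figure's only Axes
plt.close(fig)
```

The size in pixels depends on the figure's dpi setting, which is why figsize is expressed in inches rather than pixels.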

04:07 First, it needs a 1D array to populate the x-axis, and so I’ll pass in our yrs variable. It also needs a 2D array to plot along the y-axis.

04:20 This will be rng+rnd. Basically, this will increment the first element in each subarray in rnd by 0, then the second element in each by 1, then the third by 2, all the way until the last element is incremented by 49.
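The column-by-column increment described above is NumPy broadcasting: the 1D `rng` is added to each row of the 2D `rnd`. This sketch uses hypothetical 5-column stand-ins instead of the lesson's 50 columns so the effect is easy to see, and compares the result with an explicit loop over the rows.

```python
import numpy as np

# Hypothetical small stand-ins: 5 columns instead of 50.
rng = np.arange(5)                  # [0, 1, 2, 3, 4]
rnd = np.array([[3, 1, 4, 1, 5],
                [8, 2, 6, 0, 9],
                [5, 7, 2, 3, 1]])   # shape (3, 5)

total = rng + rnd                   # rng is broadcast across each of the 3 rows

# The same result, written as an explicit loop over the rows.
manual = np.array([row + rng for row in rnd])

print(np.array_equal(total, manual))  # True
print(total[:, 0].tolist())           # [3, 8, 5]  (0 added)
print(total[:, 4].tolist())           # [9, 13, 5] (4 added)
```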

04:42 We’ll see this sum printed to the console, which will make things a little clearer later on. Finally, we’ll pass a labels argument so that the legend can name each region, so let’s write labels= and then a list containing 'Eastasia', 'Eurasia', and 'Oceania'.

05:02 And that’s the most difficult line of code in this entire program. Next, we can set some other properties of the Axes object. We can set the title of the Axes with .set_title() and I’ll pass in the string 'Combined debt growth over time'.

05:23 We can also give it a legend with the .legend() method, and we can pass in a location for it. Let’s say 'upper left'. I also want to give this plot a y label, which we can do with the .set_ylabel() method.

05:39 I’ll pass in the string 'Total debt'. This will be displayed along the vertical axis. Next, we need to set the upper and lower limits for the years displayed along the x-axis.
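Each setter used so far has a matching getter, which is a handy way to confirm what the Axes actually holds. This sketch uses placeholder data (`data` is a hypothetical array of ones) and the headless Agg backend, which is an assumption for running outside an interactive session.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for non-interactive runs
import matplotlib.pyplot as plt
import numpy as np

yrs = 1950 + np.arange(5)
data = np.ones((3, 5))                      # hypothetical placeholder values
labels = ["Eastasia", "Eurasia", "Oceania"]

fig, ax = plt.subplots(figsize=(5, 3))
ax.stackplot(yrs, data, labels=labels)
ax.set_title("Combined debt growth over time")
ax.legend(loc="upper left")
ax.set_ylabel("Total debt")

# Read the properties back with the matching getters.
legend_labels = [t.get_text() for t in ax.get_legend().get_texts()]
print(ax.get_title())   # Combined debt growth over time
print(legend_labels)    # ['Eastasia', 'Eurasia', 'Oceania']
print(ax.get_ylabel())  # Total debt
plt.close(fig)
```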

05:53 This will ensure that our actual graphic is limited to just the year data that we’re working with. To do this, we can use the .set_xlim() method, and this will take both an xmin and an xmax.

06:10 The xmin should be the first element in our yrs array, which we can access as if it were a list in Python, with square bracket notation.

06:21 I’ll say yrs[0]. For the xmax, we want the last element, which is yrs[-1]. Great! Our data is plotted. The last thing we have to do is clean up the Figure, which will remove any extra whitespace.
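This sketch repeats the .set_xlim() call on placeholder data and reads the limits back with .get_xlim(). In Matplotlib, the xmin and xmax keywords are documented aliases for left and right, so `ax.set_xlim(yrs[0], yrs[-1])` would behave the same way. The Agg backend line is an added assumption for headless runs.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for non-interactive runs
import matplotlib.pyplot as plt
import numpy as np

yrs = 1950 + np.arange(50)
data = np.ones((3, yrs.size))   # hypothetical placeholder values

fig, ax = plt.subplots()
ax.stackplot(yrs, data)

ax.set_xlim(xmin=yrs[0], xmax=yrs[-1])     # xmin/xmax are aliases for left/right
limits = ax.get_xlim()
print(float(limits[0]), float(limits[1]))  # 1950.0 1999.0
plt.close(fig)
```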

06:43 We can do that with the .tight_layout() method, which we call directly on our Figure object instead of the Axes.

06:53 Now, because we’re not running this in any sort of interactive mode, we have to say plt.show() in order to actually tell pyplot to show our plot onscreen.

07:06 I will run this program, and we get our plot! We see it contains everything we’d expect: our title, the legend, the y-axis label, the years along the x-axis, and our stack plot.

07:19 But how exactly does this stack plot work? To answer that question, let’s look at the ndarray that was printed to the console. This ndarray contains three one-dimensional subarrays, one for each region in the plot here.

07:36 The first array starts with a 3, and so the first plot point—representing the blue part of the graph—should be 3 units from the bottom of the y-axis.

07:49 The second array starts with 8, but that doesn’t mean that the next point, representing the orange region, is at 8. Rather, it sits 8 units above the previous point, at 11.

08:04 That’s why it’s called a stack plot: each data point stacks on top of the previous one. Finally, the last array starts with 5, so the last point, representing the green region, is at 16, or 3 + 8 + 5.
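The stacking arithmetic is a running total down each column, which NumPy computes with `np.cumsum()`. This sketch uses the first-column values quoted in the transcript (3, 8, 5) plus a hypothetical second column to show the top edge of each layer.

```python
import numpy as np

# First column of the printed array, as quoted in the transcript.
first_column = np.array([3, 8, 5])   # Eastasia, Eurasia, Oceania

# Each layer sits on top of the ones below it, so the top edge of
# each layer is the cumulative sum of the values so far.
tops = np.cumsum(first_column)
print(tops.tolist())  # [3, 11, 16]

# For a full 2D array, sum down the rows (axis=0) for every x value.
# The second column here is hypothetical.
data = np.array([[3, 1],
                 [8, 2],
                 [5, 7]])
print(np.cumsum(data, axis=0).tolist())  # [[3, 1], [11, 3], [16, 10]]
```

These are the same y values you should see when hovering over the first point of each color in the zoomed plot.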

08:24 This pattern continues for every x value in the plot. Just to show that this is accurate, I’m going to move over to my plot onscreen and I’ll use the zoom tool, which is this little magnifying glass.

08:39 I’ll draw a small box around the leftmost region of the plot, which will resize the entire window to that plot region—effectively, zooming in.

08:50 If I hover my mouse over the first data point for each color, you can see in the bottom right corner the y value of my mouse coordinate. Hovering over the first blue point shows 3, the first orange point shows 11, and the first green point shows 16—just as we’d expect.

09:11 We can use the other tools down here to pan around the graph, change between previous views, configure the spacing in the window, and even save a PNG file of the plot.

avinashhm on Oct. 25, 2019

Hi there, great tutorial!

Minor correction: I guess plt.sobplots was intended to be plt.subplots?

Austin Cepalia RP Team on Oct. 27, 2019

Oh yeah, there is a minor typo in the description. Unfortunately I can’t edit that directly, but I’ll see if I can get that fixed

Dan Bader RP Team on Oct. 27, 2019

I just fixed the typo, thanks for the heads up @avinashhm :)

avinashhm on Oct. 29, 2019

No worries, thanks for fixing, @dan, @austin!

alazejha on Jan. 2, 2021

Dear Austin,

I tried the code, but the error message pops up:

AttributeError                            Traceback (most recent call last)
<ipython-input-2-d96da0308728> in <module>
      9 print(rng+rnd)
     10 
---> 11 fig, ax = plt.subplots(figsize =(5,3))
     12 
     13 ax.stackplot(yrs, rnd+rng , labels =["Asia", "Europe", "Oceania"])

AttributeError: module 'matplotlib' has no attribute 'subplots'

What am I doing wrong? A

Bartosz Zaczyński RP Team on Jan. 4, 2021

@alazejha It looks like you might have aliased the wrong module. Make sure you have the following import statement at the top of your code:

import matplotlib.pyplot as plt

Judging by the error message, I’m guessing you did this instead:

import matplotlib as plt

However, this is only my guess since you haven’t shared the full code snippet. Let me know if that helps.
