Working With Multiple Subplots
A figure can have more than one subplot. Earlier, you learned that you can obtain your figure and axes objects by calling plt.subplots() and passing in a figure size. This function can take two additional arguments:
- The number of rows
- The number of columns
These arguments determine how many axes objects will belong to the figure, and by extension, how many axes objects will be returned to you. In the following code, nrows is set to 1, and ncols is set to 2:
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2)
This call returns two axes objects, which you unpack into the tuple (ax1, ax2).
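Here is a minimal, self-contained sketch of the call above. It assumes Matplotlib is installed. It also selects the non-interactive Agg backend, which is an assumption made only so the snippet runs without a display. With nrows=1 and ncols=2, plt.subplots() returns a one-dimensional NumPy array of two Axes, and tuple unpacking splits that array into ax1 and ax2.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# One row, two columns: the returned array holds two Axes objects
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))

print(isinstance(ax1, Axes), isinstance(ax2, Axes))
print(len(fig.axes))  # the figure owns both Axes
```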
00:00
A figure can have more than one subplot. Earlier, we learned how we can obtain our Figure and Axes objects with the plt.subplots() function, passing in a figure size.
00:15 This function can also take two additional arguments, the number of rows and the number of columns.
00:23
These arguments dictate how many Axes objects will belong to the Figure, and by extension, how many Axes objects will be returned to us.
00:34
In this example, I’ve set nrows=1 and ncols=2, and so this function returns two Axes objects, which I unpack into this tuple.
00:48
If I set nrows=2 and ncols=2, the function would return four Axes.
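You can check that the grid size determines how many Axes come back by looping over a few (nrows, ncols) pairs and counting the returned objects. The sketch below does this. It passes squeeze=False, which the video does not use; the option is chosen here so that the result is always a two-dimensional array, even for a 1 by 1 grid. The Agg backend is again an assumption for headless use.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt

counts = {}
for nrows, ncols in [(1, 1), (1, 2), (2, 2), (3, 2)]:
    # squeeze=False keeps the result 2-D so .size works for every grid shape
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    counts[(nrows, ncols)] = axes.size
    plt.close(fig)  # free the figure; only the count was needed

print(counts)  # {(1, 1): 1, (1, 2): 2, (2, 2): 4, (3, 2): 6}
```

In every case the number of Axes equals nrows times ncols.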
00:58
Let’s see how we can modify these two Axes independently, creating two new visualizations in the process.
01:06
I’m here in a new file called plot2.py, and to save time, I have already imported pyplot and numpy at the top. Just like before, we’re going to get our randomized data using numpy. We’ll create a new variable called x, and that will store the one-dimensional ndarray obtained by calling randint() with a lower limit of 1, an upper limit of 11 (which is exclusive), and a size of 50, just like before. Now, I’ll create a new variable called y, and that will store x plus a one-dimensional array of 50 random numbers from 1 to 4 inclusive.
01:50
This means that a random number from 1 to 4 will be added to each element in our x ndarray.
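Here is a sketch of the data-generation step described above. The video does not show the exact y call, so the line below is an assumption. It passes the limits 1 and 5 by position, and that form is equivalent to low=1, high=5. np.random.randint() treats high as exclusive. That is why an upper limit of 11 yields values from 1 to 10, and an upper limit of 5 yields offsets from 1 to 4. The fixed seed is also an assumption, added only for repeatability.

```python
import numpy as np

np.random.seed(444)  # assumption: fixed seed so reruns give the same data

# 50 integers in [1, 11): the upper limit is exclusive, so values run 1..10
x = np.random.randint(low=1, high=11, size=50)

# Add a random offset in [1, 5), i.e. 1..4 inclusive, to every element of x.
# randint(1, 5, ...) is the positional form of randint(low=1, high=5, ...).
y = x + np.random.randint(1, 5, size=x.size)

print(x.min(), x.max(), y.min(), y.max())
```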
01:59
I’m also going to create one more variable called data, and that will store the two-dimensional ndarray obtained by calling column_stack() with our x and y arrays.
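column_stack() pairs the two one-dimensional arrays element by element. The result has one row per observation and one column per variable, so the 50-element arrays in the video give a (50, 2) array. The sketch below uses a few hypothetical values so that the shape is easy to read.

```python
import numpy as np

x = np.array([1, 5, 10])   # hypothetical values
y = np.array([3, 8, 12])

data = np.column_stack((x, y))  # each input array becomes one column
print(data)
print(data.shape)  # (3, 2)
```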
02:13
Now we’re done obtaining our data points, so we can use pyplot to obtain our Figure and Axes objects. This time around, we’re going to have one Figure and two Axes, and so I’ll write fig, and then a tuple of (ax1, ax2).
02:35
Now we’ll call the subplots() function with an nrows value of 1, an ncols value of 2, and a figure size of (8, 4). The first Axes is going to be a scatter plot. I’ll write ax1.scatter(), passing in x for the x data and y for the y data.
03:03
Now we’ll tell the method how to style the scatter plot. For marker, I’ll give it a value of 'o'. The marker parameter sets the style of each point, such as circles, x’s, or crosses.
03:19
You can learn about all the different options by viewing the documentation for the .scatter() method. Now I’ll give these circles a color of red and an edge color of blue.
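Here is a sketch of the scatter call with the styling described above. The video only says "red" and "blue", so the single-letter codes 'r' and 'b' are an assumption; 'red' and 'blue' work as well. The fill color goes through c and the outline through edgecolors. The seeded data generation and the Agg backend are also assumptions.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(444)  # assumption: fixed seed for repeatable data
x = np.random.randint(low=1, high=11, size=50)
y = x + np.random.randint(1, 5, size=x.size)

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))

# marker='o' draws circles; c sets the fill color, edgecolors sets the outline
sc = ax1.scatter(x=x, y=y, marker='o', c='r', edgecolors='b')
```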
03:33
Just like before, I’m going to set this Axes’s title, x label, and y label. Matplotlib’s built-in mathtext engine understands a subset of TeX syntax, so placing text between dollar signs ('$') renders it as math, which italicizes it.
03:50
It supports basic TeX math formatting. Let’s give this Axes some grid lines so it’s easier to match each point to an x and y value. To do this, I’m first going to call the .set_axisbelow() method with a value of True, which tells this Axes to draw its ticks and grid lines behind the points.
04:16
Then, we can actually enable the grid by calling the .grid() method with a linestyle of two dashes ('--'). Once again, you can check out the documentation to see all of the different line styles.
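Taken together, the labeling and grid steps for the scatter Axes might look like the sketch below. The title and label strings and the three data points are illustrative assumptions. The dollar signs pass each $x$ and $y$ to Matplotlib's mathtext parser, which sets them in an italic math font when the figure is drawn.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt

fig, ax1 = plt.subplots(figsize=(4, 4))
ax1.scatter([1, 2, 3], [2, 4, 5], marker='o', c='r', edgecolors='b')  # hypothetical points

# Text between dollar signs is rendered by mathtext (italic math font)
ax1.set_title('Scatter: $x$ versus $y$')
ax1.set_xlabel('$x$')
ax1.set_ylabel('$y$')

# Draw ticks and grid lines beneath the plotted points, then enable a dashed grid
ax1.set_axisbelow(True)
ax1.grid(linestyle='--')

fig.canvas.draw()  # force a render so the mathtext strings are actually parsed
```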
04:31
Great. Our first Axes is done. The second Axes will be a histogram, and so I’ll call the .hist() method on our ax2 object.
04:43
I’ll pass in our data, which is the two-dimensional ndarray that resulted from stacking our x and y arrays side by side as columns.
04:53
We’ll also pass a bins argument, which sets the bin edges along the x-axis. I’ll use np.arange() with a lower bound of data.min() and an upper bound of data.max(). Because np.arange() excludes its upper bound, this generates an ndarray counting from 1 to 13 inclusive.
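Because np.arange() stops before its upper bound, np.arange(data.min(), data.max()) produces edges only up to data.max() - 1. Any value larger than the last edge is not counted by the histogram. The quick check below uses hypothetical data spanning 1 to 14. It also shows one possible adjustment, which is an assumption and not what the video does: extend the stop by two so the largest value gets a bin of its own. np.histogram is used here because it applies the same bin-edge rules that .hist() relies on.

```python
import numpy as np

# Hypothetical data whose values span 1..14, like the x and y arrays in the video
data = np.array([[1, 3], [7, 9], [10, 14]])

edges = np.arange(data.min(), data.max())  # runs from 1 to 13
counts, _ = np.histogram(data, bins=edges)
print(counts.sum())  # 5: the value 14 falls past the last edge

# Extending the stop by 2 gives every integer from min to max its own bin
edges_full = np.arange(data.min(), data.max() + 2)
counts_full, _ = np.histogram(data, bins=edges_full)
print(counts_full.sum())  # 6
```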
05:19
Finally, I’ll pass the .hist() method a label argument of ('x', 'y'), one label for each column of data.
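When you pass a two-dimensional array to .hist(), it treats each column as a separate dataset and draws the datasets side by side in every bin. The label argument supplies one entry per column, in the same order. Here is a sketch with small hypothetical data; the bins expression follows the adjusted form shown earlier, which is an assumption.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.array([1, 2, 2, 3, 5])      # hypothetical values
y = np.array([2, 3, 4, 4, 6])
data = np.column_stack((x, y))     # shape (5, 2): column 0 is x, column 1 is y

fig, ax2 = plt.subplots()
# Each column becomes its own group of bars; labels are matched by column order
n, bins, patches = ax2.hist(data, bins=np.arange(data.min(), data.max() + 2),
                            label=('x', 'y'))
```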
05:26 This histogram will have two colored bars, so it’s important we distinguish between which one represents x and which one represents y. To do this, we can add a legend to our plot.
05:41
That’s as easy as calling ax2.legend(),
05:46
and I’ll give it a loc of 0, which is equivalent to 'best' and tells matplotlib to render the legend in the place where it overlaps the bars drawn on the plot the least.
05:58 Note that this search can slow down rendering when you work with very large datasets. In that case, you can set the position manually by passing in a tuple of two floats, which gives the coordinates of the legend’s lower-left corner in axes coordinates.
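Here is a sketch of the two legend options just discussed. loc=0 is the numeric code for 'best'. A tuple of two floats fixes the legend's lower-left corner at that point, measured in axes coordinates from 0 to 1 along each axis, so Matplotlib skips the 'best' search. The plotted lines and the (0.05, 0.7) position are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt

fig, (left, right) = plt.subplots(nrows=1, ncols=2)
for ax in (left, right):
    ax.plot([1, 2, 3], [1, 4, 9], label='x')   # hypothetical data
    ax.plot([1, 2, 3], [2, 3, 5], label='y')

# loc=0 is the same as loc='best': Matplotlib looks for the least-overlapping spot
best = left.legend(loc=0)

# A 2-tuple of floats places the lower-left corner directly, in axes coordinates
fixed = right.legend(loc=(0.05, 0.7))
```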
06:12
I almost always use 0, though. Just like before, I will give this Axes a title with the variable names italicized. Because we’ll be plotting two Axes side by side, it would look nicer if the y-axis ticks for this second Axes were on the right side of the plot instead of on the left.
06:36
We can change this with ax2.yaxis.tick_right(). To finish up, I’ll add the dashed lines to our grid and call plt.show() so we actually get some output on our screen. And when I run this, you see both Axes show up within our figure.
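The remaining steps for the histogram Axes might look like the sketch below. The title text and the data are assumptions. The script saves the figure to an in-memory PNG buffer instead of calling plt.show(), only so that it runs without a display. In an interactive session, plt.show() would open the window as in the video.

```python
import io
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

data = np.column_stack(([1, 2, 2, 3, 5], [2, 3, 4, 4, 6]))  # hypothetical data

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
ax2.hist(data, bins=np.arange(data.min(), data.max() + 2), label=('x', 'y'))
ax2.legend(loc=0)
ax2.set_title('Frequencies of $x$ and $y$')
ax2.yaxis.tick_right()   # move the y-axis ticks and tick labels to the right side
ax2.grid(linestyle='--')

buf = io.BytesIO()
fig.savefig(buf, format='png')  # stand-in for plt.show() when no display is available
```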
07:00
The scatter plot simply takes the elements at each position of the x and y arrays and plots a point there. As for the histogram, it shows the frequency of x and y, meaning how often a specific number along the x-axis occurs in either the x array or the y array. Notice how the tick marks and the labels for the y-axis are on the right side, just as we specified. This code shows why it’s important to use a stateless (object-oriented) approach with matplotlib.
07:38
If we relied on pyplot alone, it would be very difficult to customize each Axes independently, because we wouldn’t have a direct object reference to each Axes.
07:50
We’d have to dig deep into pyplot and find each Axes ourselves, which is even more difficult considering that pyplot keeps track of only one current Axes at a time, and we have two of them to deal with.
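The following sketch illustrates that point. Pyplot-level calls such as plt.title() act on whichever Axes is current. With two Axes, you have to switch the current one with plt.sca() before each call. Object-oriented calls name their target directly. The titles and labels are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2)

# Stateful style: make each Axes "current" before calling a pyplot function
plt.sca(ax1)
plt.title('Left (via pyplot)')
plt.sca(ax2)
plt.title('Right (via pyplot)')

# Stateless (object-oriented) style: call methods on the Axes you mean
ax1.set_xlabel('$x$')
ax2.set_xlabel('$x$')
```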
08:05
Before I close off this video, I want to show you one other way we could use pyplot to get one Figure and multiple Axes objects.
08:16
I’m here in a blank Python shell, and I’m going to start by importing matplotlib.pyplot as plt, as we always do. I want to create a figure with four Axes in a 2 by 2 grid, so I’ll write fig, ax = plt.subplots(nrows=2, ncols=2) and I’ll give it a figure size of, let’s say, (7, 7).
08:52
The figure size doesn’t really matter here since we won’t be displaying the figure on the screen. Notice here that I wrote ax instead of a nested tuple such as ((ax1, ax2), (ax3, ax4)).
09:09
We could do that, but it might get difficult to manage if we have a figure with a lot of Axes in it. But then if we’re trying to create four Axes, what is ax? To answer that, I’ll use the Python type() function, passing in ax.
09:28
As you can see, ax is not an Axes, but a numpy.ndarray. If I just write ax to see its value, we can see that it’s a two-dimensional ndarray containing all four of our Axes objects. To get the first object, we can use square bracket notation, ax[0][0].
09:53
I put two zeros because it’s a two-dimensional array: the first index picks the row and the second picks the column. And there’s our object. I mentioned in the NumPy video that each ndarray has a shape, which can help us better visualize the multi-dimensional array. To see the shape,
10:10
we can write ax.shape, and we see it’s 2 by 2, because we have 2 rows and 2 columns.
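The shell checks above can be reproduced in a script, as in the sketch below. Besides the chained ax[0][0] form, NumPy also accepts a single index tuple, ax[0, 0], which refers to the same object. A 2 by 2 grid reads the same either way, so the sketch adds a hypothetical 2 by 3 grid to show that .shape is ordered (rows, columns).

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(7, 7))
print(type(ax))   # <class 'numpy.ndarray'>
print(ax.shape)   # (2, 2)

# Chained indexing and NumPy's tuple indexing reach the same element
first = ax[0][0]
same = ax[0, 0]

# A non-square grid shows that .shape is ordered (rows, columns)
fig2, grid = plt.subplots(nrows=2, ncols=3)
print(grid.shape)  # (2, 3)
bottom_right = grid[1, 2]  # second row, third column
```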
10:20
If I ever wanted to store each Axes object from the array in its own separate variable, I could write ax1, ax2, ax3, ax4 = ax.flatten().
10:38
Now, we’ve got variables that reference each Axes object individually. I’ll run ax1, and you see there’s our first Axes.
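The sketch below shows two ways to get individual variables. flatten() returns a new one-dimensional array in row-major order, holding the same Axes objects. As a result, ax1 through ax4 run left to right along the top row and then along the bottom row. Nested tuple unpacking, which mirrors the grid's shape, also works. A flat four-name tuple on the left of plt.subplots(nrows=2, ncols=2) fails, because unpacking the 2 by 2 array yields only its two rows.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(7, 7))

# Row-major order: top-left, top-right, bottom-left, bottom-right
ax1, ax2, ax3, ax4 = ax.flatten()

# Alternative: unpack with a nested tuple that mirrors the 2 x 2 grid
fig2, ((bx1, bx2), (bx3, bx4)) = plt.subplots(nrows=2, ncols=2)

# A flat four-name tuple raises ValueError: the 2-D array yields only two rows
try:
    fig3, (cx1, cx2, cx3, cx4) = plt.subplots(nrows=2, ncols=2)
    flat_unpack_failed = False
except ValueError:
    flat_unpack_failed = True
```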
10:51
This is all just another way to create a Figure with multiple Axes. It’s especially helpful if you have a Figure with many Axes, as you can manage them within an ndarray instead of having variables for each one.
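Keeping the Axes in an array lets a single loop configure all of them. Here is a hypothetical example that gives each panel a title with its row and column and enables the dashed grid used earlier. np.ndenumerate() yields each index tuple together with the Axes stored there.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(7, 7))

# np.ndenumerate yields ((row, col), Axes) pairs in row-major order
for (row, col), a in np.ndenumerate(ax):
    a.set_title(f'row {row}, col {col}')
    a.grid(linestyle='--')
```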
Ranit Pradhan on April 5, 2020
File "<ipython-input-14-4e84e83ed077>", line 19
ax2.legend(loc=(0))
^
SyntaxError: invalid syntax
What's the solution?
ycc on Nov. 22, 2020
Regarding the invalid syntax, I faced a similar problem as well. I had missed a closing ")" on line 17, ax2.hist(data, bins..., which is not obvious at first. After correcting it, I managed to run the code without error.
alistairmclaren1 on March 21, 2020
Why, when setting the lower and upper limits for x, are the arguments written as "low=" and "high=", while for y only the integers 1 and 5 are used?