# Working With Multiple Subplots

A **figure** can have more than one **subplot**. Earlier, you learned that you can obtain your figure and axes objects by calling `plt.subplots()` and passing in a figure size. This function can take two additional arguments:

- The number of **rows**
- The number of **columns**

These arguments determine how many axes objects will belong to the figure, and by extension, how many axes objects will be returned to you. In the following code, `nrows` is set to `1`, and `ncols` is set to `2`:

```python
import matplotlib.pyplot as plt
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2)
```

This function returns two axes objects, which you store in a tuple.

**00:00**
A figure can have more than one subplot. Earlier, we learned how we can obtain our `Figure` and `Axes` objects with the `plt.subplots()` function, passing in a figure size.

**00:15**
This function can also take two additional arguments, the number of rows and the number of columns.

**00:23**
These arguments dictate how many `Axes` objects will belong to the `Figure`, and by extension, how many `Axes` objects will be returned to us.

**00:34**
In this example, I’ve set `nrows=1` and `ncols=2`, and so this function returns two `Axes` objects, which I store in this tuple.

**00:48**
If I set `nrows=2` and `ncols=2`, the function would return four `Axes`.
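For the 2-by-2 case, the four `Axes` come back arranged as a grid, so one way to unpack them is with nested tuples, one per row. A minimal sketch, assuming the non-interactive `Agg` backend so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend, so no window is needed
import matplotlib.pyplot as plt

# With nrows=2 and ncols=2 the Axes come back as a 2x2 grid,
# so the unpacking pattern mirrors it: one inner tuple per row.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2)

print(len(fig.axes))  # 4
```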

**00:58**
Let’s see how we can modify these two `Axes` independently, creating two new visualizations in the process.

**01:06**
I’m here in a new file called `plot2.py`, and to save time, I’ve already imported `pyplot` and `numpy` at the top. Just like before, we’re going to get our randomized data using `numpy`. We’ll create a new variable called `x`, which will store the one-dimensional `ndarray` obtained by calling `randint()` with a lower limit of `1`, an upper limit of `11`, and a size of `50`, just like before. Because the upper limit is exclusive, the values run from `1` to `10`. Now, I’ll create a new variable called `y`, which will store `x` plus a one-dimensional array of 50 random numbers from `1` to `4` inclusive.

**01:50**
This means that a random number from `1` to `4` will be added to each element in our `x` `ndarray`.
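A minimal sketch of these two steps. The `x` call follows the description above; for `y`, the exact call isn’t spelled out in the transcript, so `np.random.randint(1, 5, size=50)` (upper bound exclusive, giving 1 through 4) is an assumption that matches the described range:

```python
import numpy as np

# x: 50 integers drawn from [1, 11), i.e. 1 through 10 (high is exclusive)
x = np.random.randint(low=1, high=11, size=50)

# y: each element of x plus its own random offset from [1, 5), i.e. 1 through 4
y = x + np.random.randint(1, 5, size=50)
```

Because the offsets are drawn element by element, every `y[i]` sits between `x[i] + 1` and `x[i] + 4`.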

**01:59**
I’m also going to create one more variable called `data`, which will store the two-dimensional `ndarray` obtained by calling `column_stack()` with our `x` and `y` arrays.
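To see what `np.column_stack()` produces, here is a small, deterministic illustration with made-up values rather than the lesson’s random data. Each input array becomes one column, so two 50-element arrays would give an array of shape `(50, 2)`:

```python
import numpy as np

# Two 1-D arrays of equal length (illustrative values only)
x = np.array([1, 2, 3])
y = np.array([4, 5, 6])

# column_stack turns each 1-D array into one column of a 2-D array
data = np.column_stack((x, y))
print(data)
# [[1 4]
#  [2 5]
#  [3 6]]
print(data.shape)  # (3, 2)
```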

**02:13**
Now we’re done obtaining our data points, so we can use `pyplot` to obtain our `Figure` and `Axes` objects. This time around, we’re going to have one `Figure` and two `Axes`, and so I’ll write `fig,` and then a tuple of `(ax1, ax2)`.

**02:35**
Now we’ll call the `subplots()` function with an `nrows` value of `1`, an `ncols` value of `2`, and a figure size of `(8, 4)`. The first `Axes` is going to be a scatter plot. I’ll write `ax1.scatter()`, passing in `x` for the *x* data and `y` for the *y* data.
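Put together, the setup described so far might look like this. It’s a sketch, assuming the `Agg` backend and the data calls sketched earlier; `figsize` is the `(width, height)` of the whole figure in inches:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.random.randint(low=1, high=11, size=50)
y = x + np.random.randint(1, 5, size=50)

# One row, two columns; figsize is the whole figure's (width, height) in inches
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))

# First Axes: one point per (x[i], y[i]) pair
ax1.scatter(x, y)
```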

**03:03**
Now the method needs to know how to style the scatter plot. For `marker`, I’ll give it a value of `'o'`. The `marker` parameter sets the shape of each point, such as circles, x’s, or crosses.

**03:19**
You can learn about all the different options by viewing the documentation for the `.scatter()` method. Now I’ll give these circles a color of red and an edge color of blue.
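Matplotlib’s single-letter color codes `'r'` and `'b'` are one way to ask for red and blue. The exact arguments used in the video aren’t visible in this transcript, so `c='r'` and `edgecolor='b'` below are assumptions:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.random.randint(low=1, high=11, size=50)
y = x + np.random.randint(1, 5, size=50)

fig, ax1 = plt.subplots()

# marker='o' draws circles; c sets the fill color and edgecolor the outline
points = ax1.scatter(x, y, marker='o', c='r', edgecolor='b')
```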

**03:33**
Just like before, I’m going to set the `Axes` title, *x* label, and *y* label. Matplotlib renders text placed between dollar signs (`'$'`) with its built-in mathtext engine, which understands a subset of LaTeX, so a single variable name such as `$x$` comes out italicized.
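A short sketch of the labeling step. The title string here is illustrative, not necessarily the one typed in the video:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax1 = plt.subplots()

# Text between dollar signs goes through mathtext, so x and y render in italics
ax1.set_title('Scatter: $x$ versus $y$')
ax1.set_xlabel('$x$')
ax1.set_ylabel('$y$')
```

The getters return the raw strings, dollar signs included; only the rendered output is italicized.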

**03:50**
Matplotlib respects basic LaTeX formatting options here. Let’s give this `Axes` some grid lines so it’s easier to match each point to an *x* and *y* value. To do this, I’m first going to call the `.set_axisbelow()` method with a value of `True`, which tells this `Axes` to draw its ticks and grid lines behind the points.

**04:16**
Then, we can actually enable the grid by calling the `.grid()` method with a `linestyle` of two dashes (`'--'`). Once again, you can check out the documentation to see all of the different line styles.
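The two grid calls, sketched on a throwaway `Axes` with made-up points (the `Agg` backend is assumed):

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax1 = plt.subplots()
ax1.scatter([1, 2, 3, 4], [2, 3, 5, 6])  # illustrative points

ax1.set_axisbelow(True)   # draw ticks and grid lines underneath the points
ax1.grid(linestyle='--')  # switch the grid on, using dashed lines
```

Order doesn’t matter here: `set_axisbelow()` only controls drawing order, so it can be called before or after `grid()`.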

**04:31**
Great. Our first `Axes` is done. The second `Axes` will be a histogram, and so I’ll call the `.hist()` method on our `ax2` object.

**04:43**
I’ll pass in our data, which is the two-dimensional `ndarray` that resulted from stacking our `x` and `y` arrays side by side as columns.

**04:53**
The method also takes a `bins` argument, which sets the bin edges along the *x*-axis. I’ll use `np.arange()` with a lower bound of `data.min()` and an upper bound of `data.max()`. Because `np.arange()` excludes its upper bound, this generates an `ndarray` that stops at `data.max() - 1`; with the data from this run, it counts from `1` to `13` inclusive.
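The exclusive upper bound is easy to check directly. This sketch assumes `data.min()` is `1` and `data.max()` is `14`, the largest value `y` can take with this data (10 + 4):

```python
import numpy as np

# Assumed extremes of the stacked data (14 = largest x of 10 + largest offset of 4)
low, high = 1, 14

# np.arange stops *before* its upper bound
bins = np.arange(low, high)
print(bins)  # [ 1  2  3  4  5  6  7  8  9 10 11 12 13]
```

These values are bin *edges*, so 13 edges make 12 bins. `.hist()` bins the data with `np.histogram`, which ignores values beyond the last edge, so any `14` would not be counted; passing `data.max() + 1` as the upper bound is one way to include it, though that is a suggestion rather than what the lesson does.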

**05:19**
Finally, I’ll give the `.hist()` method a `label` argument of `('x', 'y')`, one label for each column of our data.
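Here is a sketch of the histogram call, assuming the data steps sketched earlier and the `Agg` backend. A 2-D input is split by column, so `.hist()` draws one dataset per column and returns one row of counts for each:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.random.randint(low=1, high=11, size=50)
y = x + np.random.randint(1, 5, size=50)
data = np.column_stack((x, y))  # shape (50, 2)

fig, ax2 = plt.subplots()

# A 2-D input is split by column, giving one dataset (and one bar color) per
# column; label supplies a name for each dataset, in column order.
counts, edges, patches = ax2.hist(
    data, bins=np.arange(data.min(), data.max()), label=('x', 'y')
)
```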

**05:26**
This histogram will have two sets of colored bars, so it’s important to distinguish which one represents *x* and which one represents *y*. To do this, we can add a legend to our plot.

**05:41**
That’s as easy as `ax2.legend()`,

**05:46**
and I’ll give it a location of `0`, which is the numeric code for `'best'` and tells Matplotlib to place the legend where it overlaps the bars drawn on the screen the least.
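A sketch of the legend step, using small made-up columns so the result is predictable (the `Agg` backend is assumed). Each label passed to `.hist()` becomes one legend entry:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Small made-up columns, just so there is something to label
data = np.column_stack(([1, 2, 2, 3], [2, 3, 3, 4]))

fig, ax2 = plt.subplots()
ax2.hist(data, bins=np.arange(1, 6), label=('x', 'y'))

# loc=0 is the numeric code for loc='best'
ax2.legend(loc=0)
```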

**05:58**
Note that this can slow down rendering your figure when you’re working with very large data sets, so in that case, you can set the position manually by passing in a tuple of two floats, which gives the position of the legend’s lower-left corner in axes coordinates.
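A sketch of manual placement; the coordinates `(0.65, 0.8)` are arbitrary illustrative values, and the `Agg` backend is assumed:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 9], label='data')

# A 2-tuple of floats pins the legend's lower-left corner at that point in
# axes coordinates: (0, 0) is the bottom-left of the Axes, (1, 1) the top-right.
ax.legend(loc=(0.65, 0.8))
```

A fixed position skips the overlap search that `'best'` performs, which is where the speedup on large data sets comes from.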

**06:12**
I almost always use `0`, though. Just like before, I’ll give this `Axes` a title with the variable names italicized. Because we’ll be plotting two `Axes` side by side, it would look nicer if the *y*-axis ticks for this second `Axes` were on the right side of the plot instead of on the left.

**06:36**
We can change this with `ax2.yaxis.tick_right()`. To finish up, I’ll add the dashed grid lines and call `plt.show()` so we actually get some output on our screen. When I run this, you can see both `Axes` show up within our figure.
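A sketch of the finishing touches, assuming the `Agg` backend. Since there’s no screen to show on in that case, the figure is rendered into an in-memory buffer instead of calling `plt.show()`:

```python
import io

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))

ax2.yaxis.tick_right()    # move ax2's y ticks and tick labels to the right
ax2.grid(linestyle='--')  # dashed grid, as on the first Axes

# plt.show() would open a window; with Agg, render to an in-memory PNG instead
buffer = io.BytesIO()
fig.savefig(buffer, format='png')
```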

**07:00**
The scatter plot simply takes the elements at each position of the `x` and `y` arrays and plots a point at that pair of coordinates. The histogram, on the other hand, shows the frequencies of `x` and `y`, meaning how often each value along the *x*-axis occurs in the `x` array or the `y` array. Notice how the tick marks and labels for the *y*-axis are on the right side, just as we specified. This code shows why it’s important to use the stateless (object-oriented) approach with Matplotlib.

**07:38**
If we relied on `pyplot` alone, it would be very difficult to customize each `Axes` independently, because we wouldn’t have a direct reference to each `Axes` object.

**07:50**
We’d have to dig deep into `pyplot` and find each `Axes` ourselves, which is even harder considering that `pyplot` keeps track of a single current `Axes`, and we have two of them to deal with.
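To make the contrast concrete, here is a sketch of the state-based style. `plt.sca()` sets the current `Axes`, and functions like `plt.title()` act on whichever one is current, so you have to switch back and forth (the `Agg` backend is assumed):

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2)

# State-based pyplot calls act on whichever Axes is "current".
plt.sca(ax1)                # make ax1 the current Axes
plt.title('Left subplot')   # ...so this titles ax1
plt.sca(ax2)                # switch to ax2
plt.title('Right subplot')  # ...and this titles ax2

# The object-oriented equivalent needs no switching:
# ax1.set_title('Left subplot'); ax2.set_title('Right subplot')
```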

**08:05**
Before I close off this video, I want to show you one other way we could use `pyplot` to get one `Figure` and multiple `Axes` objects.

**08:16**
I’m here in a blank Python shell, and I’m going to start by importing `matplotlib.pyplot as plt`, as we always do. I want to create a figure with four axes in a 2 by 2 grid, so I’ll write `fig, ax = plt.subplots(nrows=2, ncols=2)`, and I’ll give it a figure size of, let’s say, `(7, 7)`.

**08:52**
The figure size doesn’t really matter here, since we won’t be displaying the figure on the screen. Notice that I wrote `ax` instead of unpacking the result into four variables with nested tuples, `((ax1, ax2), (ax3, ax4))`.

**09:09**
We could do that, but it gets difficult to manage when a figure has a lot of axes in it. So if we’re creating four axes, what is `ax`? To answer that, I’ll use Python’s built-in `type()` function, passing in `ax`.

**09:28**
As you can see, `ax` is not an `Axes`, but a `numpy.ndarray`. If I just write `ax` to see its value, we can see that it’s a two-dimensional `ndarray` containing all four of our `Axes` objects. To get the first object, we can use square bracket notation: `ax[0][0]`.

**09:53**
I put two zeros because it’s a two-dimensional array. And there’s our object. I mentioned in the NumPy video that each `ndarray` has a shape, which can help us better visualize a multi-dimensional array. To see the shape,

**10:10**
we can write `ax.shape`, and we see it’s `(2, 2)`, because we have `2` rows and `2` columns.
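The shell session above can be sketched as a script (the `Agg` backend is assumed). Besides the chained `ax[0][0]` form, NumPy’s tuple indexing `ax[0, 0]` reaches the same object:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(7, 7))

print(type(ax))  # <class 'numpy.ndarray'>
print(ax.shape)  # (2, 2), i.e. (rows, columns)

top_left = ax[0][0]          # chained indexing, as in the video
print(top_left is ax[0, 0])  # True: tuple indexing reaches the same Axes
```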

**10:20**
If I ever wanted to store each `Axes` object in the array in its own separate variable, I could write `ax1, ax2, ax3, ax4 = ax.flatten()`.
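`flatten()` walks the grid in row-major order, which determines which variable gets which subplot. A sketch, assuming the `Agg` backend:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=2)

# flatten() copies the 2-D array into 1-D in row-major order:
# top-left, top-right, bottom-left, bottom-right
ax1, ax2, ax3, ax4 = ax.flatten()
```

`ax.ravel()` and the `ax.flat` iterator visit the `Axes` in the same order.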

**10:38**
Now, we’ve got variables that reference each `Axes` object individually. I’ll evaluate `ax1`, and you can see there’s our first `Axes`.

**10:51**
This is all just another way to create a `Figure` with multiple `Axes`. It’s especially helpful if you have a `Figure` with many `Axes`, since you can manage them within an `ndarray` instead of having a separate variable for each one.
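For a larger grid, a loop over the array replaces a dozen separate variables. A sketch with an arbitrary 3-by-4 layout and placeholder data, assuming the `Agg` backend:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=3, ncols=4, figsize=(10, 7))

# ax.flat iterates over every Axes in row-major order, whatever the grid size
for i, a in enumerate(ax.flat, start=1):
    a.set_title(f'Subplot {i}')
    a.plot([0, 1], [0, i])  # placeholder data
```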

**Ranit Pradhan** on April 5, 2020

`File "<ipython-input-14-4e84e83ed077>", line 19 ax2.legend(loc=(0)) ^ SyntaxError: invalid syntax`

What’s the solution?

**ycc** on Nov. 22, 2020

Regarding the invalid syntax, I faced a similar problem. After correcting line 17, `ax2.hist(data, bins...`, where I had missed one `)` that wasn’t obvious at first, I managed to run it without error.


**alistairmclaren1** on March 21, 2020

Why, when setting the lower and upper limits for `x`, are the arguments `low=` and `high=`, whereas for variable `y` only the integers 1 and 5 are used?