Watch Now This tutorial has a related video course created by the Real Python team. Watch it together with the written tutorial to deepen your understanding: Using plt.scatter() to Visualize Data in Python
An important part of working with data is being able to visualize it. Python has several third-party modules you can use for data visualization. One of the most popular modules is Matplotlib and its submodule pyplot, often referred to using the alias plt. Matplotlib provides a very versatile tool called plt.scatter() that allows you to create both basic and more complex scatter plots.
Below, you’ll walk through several examples that will show you how to use the function effectively.
In this tutorial you’ll learn how to:
- Create a scatter plot using plt.scatter()
- Use the required and optional input parameters
- Customize scatter plots for basic and more advanced plots
- Represent more than two dimensions on a scatter plot
To get the most out of this tutorial, you should be familiar with the fundamentals of Python programming and the basics of NumPy and its ndarray object. You don’t need to be familiar with Matplotlib to follow this tutorial, but if you’d like to learn more about the module, then check out Python Plotting With Matplotlib (Guide).
Creating Scatter Plots
A scatter plot is a visual representation of how two variables relate to each other. You can use scatter plots to explore the relationship between two variables, for example by looking for any correlation between them.
In this section of the tutorial, you’ll become familiar with creating basic scatter plots using Matplotlib. In later sections, you’ll learn how to further customize your plots to represent more complex data using more than two dimensions.
Getting Started With plt.scatter()
Before you can start working with plt.scatter(), you’ll need to install Matplotlib. You can do so using Python’s standard package manager, pip, by running the following command in the console:
$ python -m pip install matplotlib
Now that you have Matplotlib installed, consider the following use case. A café sells six different types of bottled orange drinks. The owner wants to understand the relationship between the price of the drinks and how many of each one he sells, so he keeps track of how many of each drink he sells every day. You can visualize this relationship as follows:
import matplotlib.pyplot as plt
price = [2.50, 1.23, 4.02, 3.25, 5.00, 4.40]
sales_per_day = [34, 62, 49, 22, 13, 19]
plt.scatter(price, sales_per_day)
plt.show()
In this Python script, you import the pyplot submodule from Matplotlib using the alias plt. This alias is generally used by convention to shorten the module and submodule names. You then create lists with the price and average sales per day for each of the six orange drinks sold.
Finally, you create the scatter plot by using plt.scatter() with the two variables you wish to compare as input arguments. As you’re using a Python script, you also need to explicitly display the figure by using plt.show().
When you’re using an interactive environment, such as an IPython console or a Jupyter Notebook, you usually don’t need to call plt.show(). In this tutorial, all the examples will be in the form of scripts and will include the call to plt.show().
Here’s the output from this code:
This plot shows that, in general, the more expensive a drink is, the fewer items are sold. However, the drink that costs $4.02 is an outlier, which may show that it’s a particularly popular product. When using scatter plots in this way, close inspection can help you explore the relationship between variables. You can then carry out further analysis, whether it’s using linear regression or other techniques.
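As a first step toward such an analysis, you could quantify the relationship with NumPy alone. The sketch below is one possible approach and is not part of the original example: np.corrcoef() gives the Pearson correlation coefficient, and np.polyfit() with deg=1 fits a least-squares straight line, which is then overlaid on the scatter plot. With only six data points and an outlier at $4.02, treat the numbers as exploratory:

```python
import matplotlib.pyplot as plt
import numpy as np

price = np.asarray([2.50, 1.23, 4.02, 3.25, 5.00, 4.40])
sales_per_day = np.asarray([34, 62, 49, 22, 13, 19])

# Pearson correlation coefficient between price and sales (range -1 to 1)
r = np.corrcoef(price, sales_per_day)[0, 1]

# Least-squares straight line: sales ≈ slope * price + intercept
slope, intercept = np.polyfit(price, sales_per_day, deg=1)
print(f"r = {r:.2f}, slope = {slope:.1f} sales per currency unit")

# Overlay the fitted line on the scatter plot
fit_x = np.linspace(price.min(), price.max(), 50)
plt.scatter(price, sales_per_day)
plt.plot(fit_x, slope * fit_x + intercept, color="gray", linestyle="--")
plt.show()
```

A coefficient close to -1 would indicate a strong negative linear relationship. With this little data, the value is a hint rather than a conclusion.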
Comparing plt.scatter() and plt.plot()
You can also produce the scatter plot shown above using another function within matplotlib.pyplot. Matplotlib’s plt.plot() is a general-purpose plotting function that allows you to create various line or marker plots.
You can achieve the same scatter plot as the one you obtained in the section above with the following call to plt.plot(), using the same data:
plt.plot(price, sales_per_day, "o")
plt.show()
In this case, you had to include the marker "o" as a third argument, as otherwise plt.plot() would plot a line graph. The plot you created with this code is identical to the plot you created earlier with plt.scatter().
For a basic scatter plot like the one in this example, using plt.plot() may be preferable. You can compare the efficiency of the two functions using the timeit module:
import timeit
import matplotlib.pyplot as plt
price = [2.50, 1.23, 4.02, 3.25, 5.00, 4.40]
sales_per_day = [34, 62, 49, 22, 13, 19]
print(
"plt.scatter()",
timeit.timeit(
"plt.scatter(price, sales_per_day)",
number=1000,
globals=globals(),
),
)
print(
"plt.plot()",
timeit.timeit(
"plt.plot(price, sales_per_day, 'o')",
number=1000,
globals=globals(),
),
)
The performance will vary on different computers, but when you run this code, you’ll find that plt.plot() is significantly more efficient than plt.scatter(). When running the example above on my system, plt.plot() was over seven times faster.
If you can create scatter plots using plt.plot(), and it’s also much faster, why should you ever use plt.scatter()? You’ll find the answer in the rest of this tutorial. Most of the customizations and advanced uses you’ll learn about in this tutorial are only possible when using plt.scatter(). Here’s a rule of thumb you can use:
- If you need a basic scatter plot, use plt.plot(), especially if you want to prioritize performance.
- If you want to customize your scatter plot by using more advanced plotting features, use plt.scatter().
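To see why this rule of thumb holds, compare how each function handles marker sizes. In the sketch below, the plot call takes a single scalar markersize for all markers, whereas the scatter call accepts an array with one size per point. It uses the Axes methods ax.plot() and ax.scatter(), which behave like their plt counterparts; the sizes array is illustrative, and the two side-by-side axes from plt.subplots() are just a convenient way to compare the results:

```python
import matplotlib.pyplot as plt
import numpy as np

price = np.asarray([2.50, 1.23, 4.02, 3.25, 5.00, 4.40])
sales_per_day = np.asarray([34, 62, 49, 22, 13, 19])
sizes = np.asarray([200, 350, 400, 200, 275, 150])  # Illustrative per-point sizes

fig, (ax_plot, ax_scatter) = plt.subplots(1, 2, sharey=True)

# plot(): one marker size for the whole line (markersize is a scalar)
(line,) = ax_plot.plot(price, sales_per_day, "o", markersize=8)

# scatter(): an array of sizes, one value per point
points = ax_scatter.scatter(price, sales_per_day, s=sizes)

print(line.get_markersize())  # A single number shared by all markers
print(points.get_sizes())  # One size per marker
plt.show()
```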
In the next section, you’ll start exploring more advanced uses of plt.scatter().
Customizing Markers in Scatter Plots
You can visualize more than two variables on a two-dimensional scatter plot by customizing the markers. There are four main features of the markers used in a scatter plot that you can customize with plt.scatter():
- Size
- Color
- Shape
- Transparency
In this section of the tutorial, you’ll learn how to modify all these properties.
Changing the Size
Let’s return to the café owner you met earlier in this tutorial. The different orange drinks he sells come from different suppliers and have different profit margins. You can show this additional information in the scatter plot by adjusting the size of the marker. The profit margin is given as a percentage in this example:
import matplotlib.pyplot as plt
import numpy as np
price = np.asarray([2.50, 1.23, 4.02, 3.25, 5.00, 4.40])
sales_per_day = np.asarray([34, 62, 49, 22, 13, 19])
profit_margin = np.asarray([20, 35, 40, 20, 27.5, 15])
plt.scatter(x=price, y=sales_per_day, s=profit_margin * 10)
plt.show()
You’ll notice a few changes from the first example. Instead of lists, you’re now using NumPy arrays. You can use any array-like data structure for the data, and NumPy arrays are commonly used in these types of applications because they support efficient element-wise operations. NumPy is a dependency of Matplotlib, which is why you don’t need to install it manually.
You’ve also used named parameters as input arguments in the function call. The parameters x and y are required, but all other parameters are optional.
The parameter s denotes the size of the marker, given as an area in points squared. In this example, you use the profit margin as a variable to determine the size of the marker and multiply it by 10 to display the size difference more clearly.
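The multiplication by 10 scales every marker because profit_margin is a NumPy array. The same operation on a plain list would repeat the list instead, as this short comparison shows:

```python
import numpy as np

margins_list = [20, 35, 40, 20, 27.5, 15]
margins_array = np.asarray(margins_list)

# Multiplying a list repeats it: 6 items become 60
repeated = margins_list * 10
# Multiplying an array scales every element
scaled = margins_array * 10

print(len(repeated))  # 60
print(scaled.tolist())  # [200.0, 350.0, 400.0, 200.0, 275.0, 150.0]
```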
You can see the scatter plot created by this code below:
The size of the marker indicates the profit margin for each product. The two orange drinks that sell most are also the ones that have the highest profit margin. This is good news for the café owner!
Changing the Color
Many of the customers of the café like to read the labels carefully, especially to find out the sugar content of the drinks they’re buying. The café owner wants to emphasize his selection of healthy foods in his next marketing campaign, so he categorizes the drinks based on their sugar content and uses a traffic light system to indicate low, medium, or high sugar content for the drinks.
You can add color to the markers in the scatter plot to show the sugar content of each drink:
# ...
low = (0, 1, 0)
medium = (1, 1, 0)
high = (1, 0, 0)
sugar_content = [low, high, medium, medium, high, low]
plt.scatter(
x=price,
y=sales_per_day,
s=profit_margin * 10,
c=sugar_content,
)
plt.show()
You define the variables low, medium, and high to be tuples, each containing three values between 0 and 1 that represent the red, green, and blue color components, in that order. These are RGB color values. The tuples for low, medium, and high represent green, yellow, and red, respectively.
You then define the variable sugar_content to classify each drink. You use the optional parameter c in the function call to define the color of each marker. Here’s the scatter plot produced by this code:
The café owner has already decided to remove the most expensive drink from the menu, as it doesn’t sell well and has a high sugar content. Should he also stop stocking the cheapest drink to boost the health credentials of the business, even though it sells well and has a good profit margin?
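If you’d rather not write RGB tuples by hand, c also accepts named colors. The function matplotlib.colors.to_rgb() converts a color specification to an RGB tuple, so you can check which names match the low, medium, and high tuples. Note that Matplotlib’s "green" is the darker CSS green, and the name for (0, 1, 0) is "lime":

```python
from matplotlib.colors import to_rgb

low = (0, 1, 0)
medium = (1, 1, 0)
high = (1, 0, 0)

# Compare the traffic-light tuples with Matplotlib's named colors
for name, rgb in [("lime", low), ("yellow", medium), ("red", high)]:
    print(name, to_rgb(name) == rgb)

# "green" is the darker CSS green, not (0, 1, 0)
print("green", to_rgb("green"))
```

Under this mapping, sugar_content = ["lime", "red", "yellow", "yellow", "red", "lime"] should give the same colors as the list of tuples.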
Changing the Shape
The café owner has found this exercise very useful, and he wants to investigate another product. In addition to the orange drinks, you’ll now also plot similar data for the range of cereal bars available in the café:
import matplotlib.pyplot as plt
import numpy as np
low = (0, 1, 0)
medium = (1, 1, 0)
high = (1, 0, 0)
price_orange = np.asarray([2.50, 1.23, 4.02, 3.25, 5.00, 4.40])
sales_per_day_orange = np.asarray([34, 62, 49, 22, 13, 19])
profit_margin_orange = np.asarray([20, 35, 40, 20, 27.5, 15])
sugar_content_orange = [low, high, medium, medium, high, low]
price_cereal = np.asarray([1.50, 2.50, 1.15, 1.95])
sales_per_day_cereal = np.asarray([67, 34, 36, 12])
profit_margin_cereal = np.asarray([20, 42.5, 33.3, 18])
sugar_content_cereal = [low, high, medium, low]
plt.scatter(
x=price_orange,
y=sales_per_day_orange,
s=profit_margin_orange * 10,
c=sugar_content_orange,
)
plt.scatter(
x=price_cereal,
y=sales_per_day_cereal,
s=profit_margin_cereal * 10,
c=sugar_content_cereal,
)
plt.show()
In this code, you refactor the variable names to take into account that you now have data for two different products. You then plot both scatter plots in a single figure. This gives the following output:
Unfortunately, you can no longer figure out which data points belong to the orange drinks and which to the cereal bars. You can change the shape of the marker for one of the scatter plots:
import matplotlib.pyplot as plt
import numpy as np
low = (0, 1, 0)
medium = (1, 1, 0)
high = (1, 0, 0)
price_orange = np.asarray([2.50, 1.23, 4.02, 3.25, 5.00, 4.40])
sales_per_day_orange = np.asarray([34, 62, 49, 22, 13, 19])
profit_margin_orange = np.asarray([20, 35, 40, 20, 27.5, 15])
sugar_content_orange = [low, high, medium, medium, high, low]
price_cereal = np.asarray([1.50, 2.50, 1.15, 1.95])
sales_per_day_cereal = np.asarray([67, 34, 36, 12])
profit_margin_cereal = np.asarray([20, 42.5, 33.3, 18])
sugar_content_cereal = [low, high, medium, low]
plt.scatter(
x=price_orange,
y=sales_per_day_orange,
s=profit_margin_orange * 10,
c=sugar_content_orange,
)
plt.scatter(
x=price_cereal,
y=sales_per_day_cereal,
s=profit_margin_cereal * 10,
c=sugar_content_cereal,
marker="d",
)
plt.show()
You keep the default marker shape for the orange drink data. The default marker is "o", which represents a dot. For the cereal bar data, you set the marker shape to "d", which represents a thin diamond marker. You can find the list of all markers you can use in the documentation page on markers. Here are the two scatter plots superimposed on the same figure:
You can now distinguish the data points for the orange drinks from those for the cereal bars. But there is one problem with the last plot you created that you’ll explore in the next section.
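The marker codes listed in the documentation are also available from Python. The dictionary matplotlib.markers.MarkerStyle.markers maps each marker code to a descriptive name, which confirms, for example, that lowercase "d" is a thin diamond and uppercase "D" a regular one:

```python
from matplotlib.markers import MarkerStyle

# Look up the descriptive names of a few common marker codes
for code in ["o", "d", "D", "s", "^", "x"]:
    print(repr(code), MarkerStyle.markers[code])
```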
Changing the Transparency
One of the data points for the orange drinks has disappeared. There should be six orange drinks, but only five round markers can be seen in the figure. One of the cereal bar data points is hiding an orange drink data point.
You can fix this visualization problem by making the data points partially transparent using the alpha value:
# ...
plt.scatter(
x=price_orange,
y=sales_per_day_orange,
s=profit_margin_orange * 10,
c=sugar_content_orange,
alpha=0.5,
)
plt.scatter(
x=price_cereal,
y=sales_per_day_cereal,
s=profit_margin_cereal * 10,
c=sugar_content_cereal,
marker="d",
alpha=0.5,
)
plt.title("Sales vs Prices for Orange Drinks and Cereal Bars")
plt.legend(["Orange Drinks", "Cereal Bars"])
plt.xlabel("Price (Currency Unit)")
plt.ylabel("Average daily sales")
plt.text(
3.2,
55,
"Size of marker = profit margin\n" "Color of marker = sugar content",
)
plt.show()
You’ve set the alpha value of both sets of markers to 0.5, which means they’re semitransparent. You can now see all the data points in this plot, including those that coincide:
You’ve also added a title and other labels to the plot to complete the figure with more information about what’s being displayed.
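Passing a list of strings to plt.legend() relies on the order in which the plots were created. An alternative you may prefer is to give each plt.scatter() call a label argument and then call plt.legend() without arguments, so each label stays attached to its data. A minimal sketch with the same data, trimmed to the essentials:

```python
import matplotlib.pyplot as plt
import numpy as np

price_orange = np.asarray([2.50, 1.23, 4.02, 3.25, 5.00, 4.40])
sales_per_day_orange = np.asarray([34, 62, 49, 22, 13, 19])
price_cereal = np.asarray([1.50, 2.50, 1.15, 1.95])
sales_per_day_cereal = np.asarray([67, 34, 36, 12])

# Attach each label to the data it describes
plt.scatter(price_orange, sales_per_day_orange, alpha=0.5, label="Orange Drinks")
plt.scatter(
    price_cereal, sales_per_day_cereal, marker="d", alpha=0.5, label="Cereal Bars"
)
plt.legend()  # Collects the labels from the plotted data

handles, labels = plt.gca().get_legend_handles_labels()
print(labels)
plt.show()
```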
Customizing the Colormap and Style
In the scatter plots you’ve created so far, you’ve used three colors to represent low, medium, or high sugar content for the drinks and cereal bars. You’ll now change this so that the color directly represents the actual sugar content of the items.
You first need to refactor the variables sugar_content_orange and sugar_content_cereal so that they represent the sugar content value rather than just the RGB color values:
sugar_content_orange = [15, 35, 22, 27, 38, 14]
sugar_content_cereal = [21, 49, 29, 24]
These are now lists containing the percentage of the daily recommended amount of sugar in each item. The rest of the code remains the same, but you can now choose the colormap to use, which maps values to colors:
# ...
# Share one color scale between both products so that equal sugar
# content gets the same color and the colorbar applies to both
vmin = min(sugar_content_orange + sugar_content_cereal)
vmax = max(sugar_content_orange + sugar_content_cereal)
plt.scatter(
x=price_orange,
y=sales_per_day_orange,
s=profit_margin_orange * 10,
c=sugar_content_orange,
cmap="jet",
vmin=vmin,
vmax=vmax,
alpha=0.5,
)
plt.scatter(
x=price_cereal,
y=sales_per_day_cereal,
s=profit_margin_cereal * 10,
c=sugar_content_cereal,
cmap="jet",
vmin=vmin,
vmax=vmax,
marker="d",
alpha=0.5,
)
plt.title("Sales vs Prices for Orange Drinks and Cereal Bars")
plt.legend(["Orange Drinks", "Cereal Bars"])
plt.xlabel("Price (Currency Unit)")
plt.ylabel("Average daily sales")
plt.text(
2.7,
55,
"Size of marker = profit margin\n" "Color of marker = sugar content",
)
plt.colorbar()
plt.show()
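When you pass numbers to c, Matplotlib first normalizes them to the range 0 to 1 and then looks up each normalized value in the colormap. Unless you pass vmin and vmax, the normalization uses the minimum and maximum of the values in that particular call, so two separate plt.scatter() calls can map the same sugar content to different colors. The sketch below reproduces the two steps by hand with matplotlib.colors.Normalize and plt.get_cmap(); the bounds are the smallest and largest sugar values across both products:

```python
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

sugar_content_orange = [15, 35, 22, 27, 38, 14]
sugar_content_cereal = [21, 49, 29, 24]

# One shared scale covering both products
all_values = sugar_content_orange + sugar_content_cereal
norm = Normalize(vmin=min(all_values), vmax=max(all_values))
cmap = plt.get_cmap("jet")

for value in [14, 27, 49]:
    # Step 1: map the value to [0, 1]; step 2: look up an RGBA color
    fraction = float(norm(value))
    print(value, round(fraction, 2), cmap(fraction))
```

Passing the same vmin and vmax (or the same norm object) to both plt.scatter() calls keeps the two sets of markers on one color scale, so a single colorbar describes both.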
The color of the markers is now based on a continuous scale, and you’ve also displayed the colorbar that acts as a legend for the color of the markers. Here’s the resulting scatter plot:
All the plots you’ve created so far have been displayed in the default Matplotlib style. You can change this style by choosing one of several built-in options. You can display the available styles using the following command. The exact list depends on your Matplotlib version:
>>> plt.style.available
[
"Solarize_Light2",
"_classic_test_patch",
"bmh",
"classic",
"dark_background",
"fast",
"fivethirtyeight",
"ggplot",
"grayscale",
"seaborn",
"seaborn-bright",
"seaborn-colorblind",
"seaborn-dark",
"seaborn-dark-palette",
"seaborn-darkgrid",
"seaborn-deep",
"seaborn-muted",
"seaborn-notebook",
"seaborn-paper",
"seaborn-pastel",
"seaborn-poster",
"seaborn-talk",
"seaborn-ticks",
"seaborn-white",
"seaborn-whitegrid",
"tableau-colorblind10",
]
You can now change the plot style when using Matplotlib by using the following function call before calling plt.scatter():
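import matplotlib.pyplot as plt
import numpy as np
# Matplotlib 3.6 and later name this style "seaborn-v0_8";
# older versions use "seaborn", as in the list above
plt.style.use("seaborn-v0_8")
# ...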
This changes the style to that of Seaborn, another third-party visualization package. You can see the different style by plotting the final scatter plot you displayed above using the Seaborn style:
You can read more about customizing plots in Matplotlib, and there are also further tutorials on the Matplotlib documentation pages.
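Calling plt.style.use() changes the style for everything you plot afterward in the same session. If you only want a style for one figure, plt.style.context() applies it temporarily inside a with block. A minimal sketch; "ggplot" is used here only because it’s available in both older and newer Matplotlib versions:

```python
import matplotlib.pyplot as plt

price = [2.50, 1.23, 4.02, 3.25, 5.00, 4.40]
sales_per_day = [34, 62, 49, 22, 13, 19]

default_facecolor = plt.rcParams["axes.facecolor"]

# The style applies only inside the with block
with plt.style.context("ggplot"):
    styled_facecolor = plt.rcParams["axes.facecolor"]
    plt.scatter(price, sales_per_day)
    plt.show()

# Outside the block, the previous settings are restored
print(plt.rcParams["axes.facecolor"] == default_facecolor)
```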
Using plt.scatter() to create scatter plots enables you to display more than two variables. Here are the variables being represented in this example:
Variable | Represented by |
---|---|
Price | X-axis |
Average number sold | Y-axis |
Profit margin | Marker size |
Product type | Marker shape |
Sugar content | Marker color |
The ability to represent more than two variables makes plt.scatter() a very powerful and versatile tool.
Exploring plt.scatter() Further
plt.scatter() offers even more flexibility in customizing scatter plots. In this section, you’ll work through an example that shows how to mask data using NumPy arrays and scatter plots. You’ll generate random data points and then separate them into two distinct regions within the same scatter plot.
A commuter who’s keen on collecting data has collated the arrival times for buses at her local bus stop over a six-month period. The timetabled arrival times are at 15 minutes and 45 minutes past the hour, but she noticed that the true arrival times follow a normal distribution around these times:
This plot shows the relative likelihood of a bus arriving at each minute within an hour. This probability distribution can be represented using NumPy and np.linspace():
import matplotlib.pyplot as plt
import numpy as np
mean = 15, 45
sd = 5, 7
x = np.linspace(0, 59, 60) # Represents each minute within the hour
first_distribution = np.exp(-0.5 * ((x - mean[0]) / sd[0]) ** 2)
second_distribution = 0.9 * np.exp(-0.5 * ((x - mean[1]) / sd[1]) ** 2)
y = first_distribution + second_distribution
y = y / max(y)
plt.plot(x, y)
plt.ylabel("Relative probability of bus arrivals")
plt.xlabel("Minutes past the hour")
plt.show()
You’ve created two normal distributions centered on 15 and 45 minutes past the hour, scaled the second one by 0.9, and summed them. You then set the most likely arrival time to a value of 1 by dividing by the maximum value.
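Written out as a formula, the code above computes the following for each minute x from 0 to 59. Unlike a normal probability density, each term omits the factor 1/(σ√(2π)), which is why the result is a relative probability scaled so that its peak equals 1:

$$
f(x) = e^{-\frac{1}{2}\left(\frac{x - 15}{5}\right)^2} + 0.9\, e^{-\frac{1}{2}\left(\frac{x - 45}{7}\right)^2},
\qquad
y(x) = \frac{f(x)}{\max_{0 \le x' \le 59} f(x')}
$$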
You can now simulate bus arrival times using this distribution. To do this, you can create random times and random relative probabilities using the built-in random module. In the code below, you will also use list comprehensions:
import random
import matplotlib.pyplot as plt
import numpy as np
n_buses = 40
bus_times = np.asarray([random.randint(0, 59) for _ in range(n_buses)])
bus_likelihood = np.asarray([random.random() for _ in range(n_buses)])
plt.scatter(x=bus_times, y=bus_likelihood)
plt.title("Randomly chosen bus arrival times and relative probabilities")
plt.ylabel("Relative probability of bus arrivals")
plt.xlabel("Minutes past the hour")
plt.show()
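Since the data ends up in NumPy arrays anyway, you could also generate it with NumPy’s random number generator instead of the random module and list comprehensions. A minimal sketch, assuming NumPy 1.17 or later, where np.random.default_rng() is available; the seed 42 is arbitrary and only makes the run repeatable:

```python
import numpy as np

n_buses = 40
rng = np.random.default_rng(seed=42)

# integers() excludes the upper bound, so 60 gives minutes 0 to 59,
# matching random.randint(0, 59), which includes both ends
bus_times = rng.integers(0, 60, size=n_buses)
# random() returns floats in the half-open interval [0, 1)
bus_likelihood = rng.random(n_buses)

print(bus_times[:5], bus_likelihood[:5])
```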
You’ve simulated 40 bus arrivals, which you can visualize with the following scatter plot:
Your plot will look different since the data you’re generating is random. However, not all of these points are likely to be close to the reality that the commuter observed from the data she gathered and analyzed. You can plot the distribution she obtained from the data with the simulated bus arrivals:
import random
import matplotlib.pyplot as plt
import numpy as np
mean = 15, 45
sd = 5, 7
x = np.linspace(0, 59, 60)
first_distribution = np.exp(-0.5 * ((x - mean[0]) / sd[0]) ** 2)
second_distribution = 0.9 * np.exp(-0.5 * ((x - mean[1]) / sd[1]) ** 2)
y = first_distribution + second_distribution
y = y / max(y)
n_buses = 40
bus_times = np.asarray([random.randint(0, 59) for _ in range(n_buses)])
bus_likelihood = np.asarray([random.random() for _ in range(n_buses)])
plt.scatter(x=bus_times, y=bus_likelihood)
plt.plot(x, y)
plt.title("Randomly chosen bus arrival times and relative probabilities")
plt.ylabel("Relative probability of bus arrivals")
plt.xlabel("Minutes past the hour")
plt.show()
This gives the following output:
To keep the simulation realistic, you need to make sure that the random bus arrivals match the data and the distribution obtained from those data. You can filter the randomly generated points by keeping only the ones that fall within the probability distribution. You can achieve this by creating a mask for the scatter plot:
# ...
in_region = bus_likelihood < y[bus_times]
out_region = bus_likelihood >= y[bus_times]
plt.scatter(
x=bus_times[in_region],
y=bus_likelihood[in_region],
color="green",
)
plt.scatter(
x=bus_times[out_region],
y=bus_likelihood[out_region],
color="red",
marker="x",
)
plt.plot(x, y)
plt.title("Randomly chosen bus arrival times and relative probabilities")
plt.ylabel("Relative probability of bus arrivals")
plt.xlabel("Minutes past the hour")
plt.show()
The variables in_region and out_region are NumPy arrays containing Boolean values based on whether the randomly generated likelihoods fall above or below the distribution y. You then plot two separate scatter plots, one with the points that fall within the distribution and another for the points that fall outside the distribution. The data points that fall above the distribution are not representative of the real data:
You’ve segmented the data points from the original scatter plot based on whether they fall within the distribution and used a different color and marker to identify the two sets of data.
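Plotting the rejected points is useful for seeing the mask at work, but the number of accepted arrivals varies from run to run. If you need exactly a given number of realistic arrivals, one option is to keep drawing candidates and keep only those that fall under the curve until you have enough, a technique known as rejection sampling. The sketch below illustrates this under stated assumptions: sample_bus_times() is a hypothetical helper, not part of the tutorial’s code, and it uses NumPy’s random generator rather than the random module:

```python
import numpy as np

# Same relative-probability curve as in the tutorial
mean = 15, 45
sd = 5, 7
x = np.linspace(0, 59, 60)
y = np.exp(-0.5 * ((x - mean[0]) / sd[0]) ** 2) + 0.9 * np.exp(
    -0.5 * ((x - mean[1]) / sd[1]) ** 2
)
y = y / y.max()


def sample_bus_times(n_buses, rng):
    """Hypothetical helper: rejection sampling until n_buses points are accepted."""
    accepted = []
    while len(accepted) < n_buses:
        # Draw a batch of candidate (time, likelihood) pairs
        times = rng.integers(0, 60, size=n_buses)
        likelihood = rng.random(n_buses)
        # Boolean mask: keep only candidates that fall under the curve
        in_region = likelihood < y[times]
        accepted.extend(times[in_region].tolist())
    return np.asarray(accepted[:n_buses])


rng = np.random.default_rng(0)
bus_times = sample_bus_times(40, rng)
print(bus_times)
```

Because candidates are accepted more often where y is high, a histogram of many samples should roughly trace the two peaks near 15 and 45 minutes past the hour.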
Reviewing the Key Input Parameters
You’ve learned about the main input parameters to create scatter plots in the sections above. Here’s a brief summary of key points to remember about the main input parameters:
Parameter | Description |
---|---|
x and y | These parameters represent the two main variables and can be any array-like data types, such as lists or NumPy arrays. These are required parameters. |
s | This parameter defines the size of the marker. It can be a float if all the markers have the same size or an array-like data structure if the markers have different sizes. |
c | This parameter represents the color of the markers. It will typically be either an array of colors, such as RGB values, or a sequence of values that will be mapped onto a colormap using the parameter cmap. |
marker | This parameter is used to customize the shape of the marker. |
cmap | If a sequence of values is used for the parameter c, then this parameter can be used to select the mapping between values and colors, typically by using one of the standard colormaps or a custom colormap. |
alpha | This parameter is a float that can take any value between 0 and 1 and represents the transparency of the markers, where 1 represents an opaque marker. |
These are not the only input parameters available with plt.scatter(). You can access the full list of input parameters from the documentation.
Conclusion
Now that you know how to create and customize scatter plots using plt.scatter(), you’re ready to start practicing with your own datasets and examples. This versatile function gives you the ability to explore your data and present your findings in a clear way.
In this tutorial you’ve learned how to:
- Create a scatter plot using plt.scatter()
- Use the required and optional input parameters
- Customize scatter plots for basic and more advanced plots
- Represent more than two dimensions with plt.scatter()
You can get the most out of visualization using plt.scatter() by learning more about all the features in Matplotlib and dealing with data using NumPy.