# Caveat on Scatter Plot

The issue presented in this post is discussed in the book *The Data Science Design Manual*.

Suppose we have two integer arrays x and y, and we want to know whether there is any relationship between these two variables. The first step is to plot the data. For example:

```python
import matplotlib.pyplot as plt

# x and y are the two integer arrays we want to compare
plt.scatter(x, y)
```


Suppose we get the figure below:

The figure suggests a linear relationship between x and y. It also shows noise in the data: some points are spread roughly uniformly over the whole plotting area.

Now the question is: are there other patterns? As discussed in the book, marker size matters: large markers can hide finer structure in the data. Another problem is that a scatter plot of integer data does not show how often each value occurs, because points with identical coordinates are drawn on top of each other and look like a single point. To address this, we can plot the data with a small marker size and a low opacity (the alpha parameter). For example:

```python
# s is the marker area in points^2; alpha is the opacity of each marker (0 = invisible, 1 = opaque)
plt.scatter(x, y, s=1, alpha=0.1)
```


This time, the noise disappears from the plot. This suggests that there are not many uniformly distributed points, or at least that they make up only a small portion of the data. The linear relationship seems to be the dominant pattern in the data set. Is that true?

The idea is to use color transparency to represent the frequency (or density) of values: regions where many points overlap look darker. This is why the less frequent, uniformly distributed points fade out of the plot. However, it is hard to read the actual frequency of values from the shading alone.
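A short calculation shows why the shading is hard to read. The sketch below assumes that markers of one color are blended with standard "over" alpha compositing, so each marker lets a fraction (1 - alpha) of whatever lies beneath it show through; the function name `stacked_opacity` is hypothetical.

```python
def stacked_opacity(alpha, n):
    """Opacity after drawing n identical markers with transparency alpha on one spot.

    Assumes simple "over" compositing: each layer lets (1 - alpha) of what is
    beneath it show through, so n layers leave (1 - alpha) ** n of the background.
    """
    return 1.0 - (1.0 - alpha) ** n


# Compare small and large piles of overlapping points at alpha = 0.1
for n in (1, 10, 50, 500, 3000):
    print(f"alpha=0.1, {n:4d} stacked points -> opacity {stacked_opacity(0.1, n):.3f}")
```

Under this assumption, about 50 overlapping markers at alpha = 0.1 already exceed 99% opacity, so a pile of 50 points and a pile of 3000 points look almost identical. The shading saturates long before it can convey large counts.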

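As proposed in the book, the trick is to add jitter, i.e. a small amount of random noise, to each data point so that points with identical values no longer land on exactly the same spot. For example:

```python
import random
import matplotlib.pyplot as plt

std = 10  # standard deviation of the Gaussian jitter

# Add independent Gaussian noise to every coordinate
xx = [item + random.gauss(0, std) for item in x]
yy = [item + random.gauss(0, std) for item in y]

# Original data in blue, jittered data in red
plt.scatter(x, y, s=1, alpha=0.02, color="blue")
plt.scatter(xx, yy, s=1, alpha=0.05, color="red")
```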


Now we can clearly see two clusters in the data set: a dense blob around the origin and the linear band.
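To see why jitter exposes a pile that the plain scatter plot hides, here is a minimal sketch using a hypothetical `jitter` helper that wraps the same Gaussian noise as the code above, with a seeded generator (an addition for reproducibility, not part of the original). It jitters 3000 copies of the value 0, mimicking the pile at the origin.

```python
import random


def jitter(values, std, rng=None):
    """Return a copy of values with independent Gaussian noise N(0, std) added to each item."""
    if rng is None:
        rng = random.Random()
    return [v + rng.gauss(0, std) for v in values]


rng = random.Random(42)             # fixed seed so the result is reproducible
pile = [0] * 3000                   # 3000 identical values: drawn as a single dot without jitter
spread = jitter(pile, std=10, rng=rng)

print(len(set(pile)), "distinct value(s) before jitter")
print(len(set(spread)), "distinct values after jitter")
```

Before jitter, all 3000 values collapse to one position; after jitter, each point gets its own position within a few standard deviations of 0, so the pile shows up as a visible blob whose size reflects how many points it contains.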

The data in x and y are generated using the code below:

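```python
import matplotlib.pyplot as plt
import random

numOfZeros = 3000          # points stacked exactly at (0, 0)
xMax = 300
yMax = 300

numOfRandomPoints = 500    # uniform noise over [0, xMax] x [0, yMax]
numOfCenterPoints = 5000   # points scattered around the line y = x for 100 <= x <= 200
sigma = 5                  # spread of y around that line

x = []
y = []

# Cluster 1: many identical points at the origin
for i in range(numOfZeros):
    x.append(0)
    y.append(0)

# Uniformly distributed integer noise
for i in range(numOfRandomPoints):
    x.append(random.randint(0, xMax))
    y.append(random.randint(0, yMax))

# Cluster 2: a noisy linear band, y drawn from a normal distribution centered at x
for i in range(numOfCenterPoints):
    xx = random.randint(100, 200)
    yy = random.gauss(xx, sigma)
    x.append(xx)
    y.append(yy)
```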
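A scatter plot cannot show how many points share one coordinate, but counting exact duplicates can. The sketch below uses `collections.Counter` from the standard library; the helper name `most_common_points` is hypothetical, and the tiny data set is made up for illustration.

```python
from collections import Counter


def most_common_points(x, y, k=3):
    """Return the k most frequent (x, y) pairs together with their counts."""
    return Counter(zip(x, y)).most_common(k)


# Tiny hand-made example: four points, three of which coincide at (0, 0)
sample_x = [0, 0, 0, 5]
sample_y = [0, 0, 0, 7]
print(most_common_points(sample_x, sample_y, k=2))  # [((0, 0), 3), ((5, 7), 1)]
```

Calling most_common_points(x, y, k=1) on the data generated above should put (0, 0) first with at least 3000 occurrences, even though the plain scatter plot draws that whole pile as a single dot.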

