@frankzhu

A Geometric Interpretation of Neural Networks

February 20, 2024 / 8 min read

Last Updated: February 20, 2024

The Canonical Neural Network

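Let's start by reviewing the basic multi-layer perceptron (MLP) neural network. Given a feature matrix $X \in \mathbb{R}^{m \times n}$ ($m$ is the number of examples, $n$ is the dimensionality of each input, i.e. the number of features) and class labels $y \in \mathbb{R}^m$, for any input $x \in \mathbb{R}^n$ the output is:

$$h = f(W_h x + b_h), \qquad \hat{y} = \sigma(W_y h + b_y)$$

where $h$ is the hidden layer and $\hat{y}$ is the output.

$\sigma$ is the sigmoid function, which converts the output to a probability, and $f$ is another non-linear function.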
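To make the two equations concrete, here is a minimal sketch of the forward pass in PyTorch. A few choices are illustrative assumptions, not values from this article: the sizes (5 examples, 2 features, 100 hidden units), the choice of `tanh` for $f$, and the helper name `forward`. Inputs are stored one example per row, so $W_h x$ becomes `x @ W_h.T`.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions): 5 examples, 2 features, 100 hidden units
m, n_features, n_hidden = 5, 2, 100
X = torch.randn(m, n_features)            # feature matrix, one example per row

# Parameters of the two layers
W_h = torch.randn(n_hidden, n_features)   # hidden-layer weights
b_h = torch.zeros(n_hidden)               # hidden-layer bias
W_y = torch.randn(1, n_hidden)            # output-layer weights
b_y = torch.zeros(1)                      # output-layer bias

def forward(x: torch.Tensor) -> torch.Tensor:
    """Compute y_hat = sigmoid(W_y f(W_h x + b_h) + b_y) for a batch of row vectors."""
    h = torch.tanh(x @ W_h.T + b_h)        # hidden layer, with f = tanh (assumed)
    return torch.sigmoid(h @ W_y.T + b_y)  # output probability

y_hat = forward(X)
print(y_hat.shape)  # torch.Size([5, 1])
```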

Geometric Interpretation

But what do those matrix multiplications actually do?

A neural network essentially does two things: rotating and twisting in high-dimensional space. "Rotating" is loose jargon for linear transformations ($Wx$), and "twisting" refers to non-linear transformations ($f(\cdot)$). Note that with a bias $b$, the transformation is not just a rotation but also a translation, i.e. an affine transformation ($Wx + b$).
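An affine map $x \mapsto Wx + b$ moves the origin, which a linear map never does. It becomes a purely linear map in one extra dimension if we append a constant 1 to every point (homogeneous coordinates). The sketch below checks both facts numerically. The particular `W` and `b` are arbitrary examples, and `affine` is a hypothetical helper name.

```python
import torch

W = torch.tensor([[0.0, -1.0],
                  [1.0,  0.0]])  # example linear part: 90-degree rotation
b = torch.tensor([2.0, 1.0])     # example bias: a translation

def affine(x):
    """Apply x -> W x + b to a batch of row vectors x."""
    return x @ W.T + b

# A linear map always sends the origin to the origin; an affine map sends it to b
origin = torch.zeros(1, 2)
print(affine(origin))  # tensor([[2., 1.]])

# Homogeneous coordinates: the same affine map as ONE linear map in 3D
A = torch.eye(3)
A[:2, :2] = W  # linear part in the top-left block
A[:2, 2] = b   # translation in the last column

x = torch.randn(4, 2)
x_h = torch.cat([x, torch.ones(4, 1)], dim=1)  # append a constant 1 to every point
y_h = x_h @ A.T                                 # purely linear in 3D
print(torch.allclose(y_h[:, :2], affine(x)))    # True
```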

Linear Transformations

So what is a linear transformation, and why does it matter for neural networks? When we multiply our input vector $x$ by a matrix $W$, we transform it into a new vector $h$ (the hidden layer), i.e. $h = Wx$. This matrix multiplication is, in essence, a linear transformation.
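"Linear" has a precise meaning: $W(a x_1 + c x_2) = a W x_1 + c W x_2$ for any vectors $x_1, x_2$ and scalars $a, c$. The sketch below checks this numerically on random inputs. The matrix, vectors, and scalars are arbitrary, and the helper `T` is a hypothetical name. This is a sanity check, not a proof.

```python
import torch

torch.manual_seed(0)
W = torch.randn(3, 2)  # arbitrary linear map from 2D to 3D
x1, x2 = torch.randn(2), torch.randn(2)
a, c = 2.5, -0.7       # arbitrary scalars

def T(x):
    return W @ x       # the linear transformation h = W x

# Linearity: T(a*x1 + c*x2) == a*T(x1) + c*T(x2), up to floating-point error
lhs = T(a * x1 + c * x2)
rhs = a * T(x1) + c * T(x2)
print(torch.allclose(lhs, rhs, atol=1e-6))  # True
```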

This transformation can be thought of as a combination of rotation, scaling, and/or reflection of the input vector $x$. Without loss of generality, consider a 2×2 transformation matrix $W$. Depending on its characteristics, it can do the following:

  • Rotation (when $W$ is orthonormal, i.e. $W^\top W = I$, with $\det W = 1$).

    E.g., $W = \begin{bmatrix} 0 & -1 \\ 1 & 0 \end{bmatrix}$ rotates by 90 degrees in the counter-clockwise direction.

  • Scaling (when $W$ is diagonal).

    E.g., $W = \begin{bmatrix} 2 & 0 \\ 0 & 3 \end{bmatrix}$ scales by 2 along the x-axis and 3 along the y-axis.

  • Reflection (when $\det W$ is negative).

    E.g., $W = \begin{bmatrix} 1 & 0 \\ 0 & -1 \end{bmatrix}$ reflects across the x-axis.

  • Shearing (when $W$ is triangular with ones on the diagonal and a non-zero off-diagonal entry).

    E.g., $W = \begin{bmatrix} 1 & 1 \\ 0 & 1 \end{bmatrix}$ shears along the x-axis.
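The four example matrices can be checked directly. The sketch below builds each one and applies it to the point $(1, 1)$, which is an arbitrary choice. It then confirms the properties listed: the rotation is orthonormal with determinant $+1$, and the reflection has a negative determinant.

```python
import torch

rotation = torch.tensor([[0.0, -1.0], [1.0, 0.0]])    # 90 degrees counter-clockwise
scaling = torch.tensor([[2.0, 0.0], [0.0, 3.0]])      # x2 along x, x3 along y
reflection = torch.tensor([[1.0, 0.0], [0.0, -1.0]])  # reflect across the x-axis
shear = torch.tensor([[1.0, 1.0], [0.0, 1.0]])        # shear along the x-axis

p = torch.tensor([1.0, 1.0])
for name, W in [('rotation', rotation), ('scaling', scaling),
                ('reflection', reflection), ('shear', shear)]:
    # Where does (1, 1) land, and does the map preserve orientation (det > 0)?
    print(f'{name:10s} W @ p = {(W @ p).tolist()}, det = {torch.linalg.det(W).item():+.1f}')

# An orthonormal matrix satisfies W^T W = I
print(torch.allclose(rotation.T @ rotation, torch.eye(2)))  # True
```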

Let's visualize these linear transformations on an arbitrary 2D cloud of points $X$. You may ask, "why points?" Because each point is simply a vector (the points-as-vectors view is common in physics), and texts, images, etc. can be represented as very long vectors in high-dimensional spaces.
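A grayscale image is just a grid of pixel intensities, so flattening it gives a single point in a space with one dimension per pixel. The 28×28 size below is an illustrative assumption (the size of an MNIST digit). The pixel values are random placeholders, not real images.

```python
import torch

image = torch.rand(28, 28)  # placeholder grayscale image, values in [0, 1)
point = image.flatten()     # the same data as a single vector
print(point.shape)          # torch.Size([784]): one point in 784-D space

# A batch of images becomes a cloud of points: one row per image
batch = torch.rand(1_000, 28, 28)
cloud = batch.reshape(len(batch), -1)
print(cloud.shape)          # torch.Size([1000, 784])
```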

Creating Random 2D Cloud

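import torch
import matplotlib.pyplot as plt
from res.plot_lib import set_default, show_scatterplot, plot_bases

set_default()  # apply the default plotting style from res.plot_lib

# 'gpu' is not a valid PyTorch device string: use CUDA if available, else the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# generate some points in 2-D space
n_points = 1_000
X = torch.randn(n_points, 2).to(device)
colors = X[:, 0]  # color each point by its original x-coordinate
show_scatterplot(X, colors, title='X')

# OI holds the origins (first two rows) and tips (last two rows) of the basis vectors
OI = torch.cat((torch.zeros(2, 2), torch.eye(2))).to(device)
plot_bases(OI)  # basis vectors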
Setup: a random 2D cloud; the green and red arrows are the basis vectors of the input space.

Now let's apply some linear transformations to this cloud of points. We first generate a random matrix $W$. Using Singular Value Decomposition (SVD), we can decompose $W$ into three matrices such that $W = U \Sigma V^\top$. Don't worry about the strict definition for now; just notice how our original cloud changes:

Applying Linear Transformations

import torch
import matplotlib.pyplot as plt

# set the seed *before* drawing W so the example is reproducible
torch.manual_seed(2024)

# create a random matrix on the same device as X
W = torch.randn(2, 2).to(X.device)
# singular value decomposition: W = U @ diag(S) @ V.T
# Note: torch.svd returns V itself, NOT its transpose (torch.linalg.svd returns Vh = V.T instead)
U, S, V = torch.svd(W)

# Define original basis vectors
OI = torch.cat((torch.zeros(2, 2), torch.eye(2))).to(X.device)

# Apply transformations sequentially. The points are the rows of X, so applying W
# to every point is X @ W.T = X @ V @ diag(S) @ U.T
Y1 = X @ V               # rotation (possibly with a reflection)
Y2 = Y1 @ torch.diag(S)  # axis-aligned scaling by the singular values
Y3 = Y2 @ U.T            # second rotation (possibly with a reflection)

# Transform the basis vectors for each step
new_OI_Y1 = OI @ V
new_OI_Y2 = new_OI_Y1 @ torch.diag(S)
new_OI_Y3 = new_OI_Y2 @ U.T

# Titles and data for plots
titles = [
    'X (Original)',
    'Y1 = X V\nV = [{:.3f}, {:.3f}], [{:.3f}, {:.3f}]'.format(V[0, 0].item(), V[0, 1].item(), V[1, 0].item(), V[1, 1].item()),
    'Y2 = Y1 S\nS = [{:.3f}, {:.3f}]'.format(S[0].item(), S[1].item()),
    'Y3 = Y2 U^T\nU = [{:.3f}, {:.3f}], [{:.3f}, {:.3f}]'.format(U[0, 0].item(), U[0, 1].item(), U[1, 0].item(), U[1, 1].item()),
]
Ys = [X, Y1, Y2, Y3]
new_OIs = [OI, new_OI_Y1, new_OI_Y2, new_OI_Y3]

# Plot the sequential transformations side by side
plt.figure(figsize=(15, 5))

for i in range(4):
    plt.subplot(1, 4, i + 1)
    show_scatterplot(Ys[i], colors, title=titles[i], axis=True)
    plot_bases(new_OIs[i])  # where the original basis vectors end up

plt.show()
From left to right: Original Cloud, Rotated Cloud, Scaled Cloud, Rotated Cloud

Did you notice anything? The cloud of points was rotated, scaled, and rotated again. Let's analyze the transformations:

  • The first transformation rotated the cloud by the matrix $V$.

    To find how much the cloud was rotated, we first note that, in general, a 2D rotation matrix can be written as:

    $$R(\theta) = \begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix}$$

    This means we can extract the angle of rotation from the matrix using $\theta = \operatorname{atan2}(R_{21}, R_{11})$ radians. Applying this to $V$ shows that the cloud was rotated by 75 degrees in the clockwise direction.

    However, a 75-degree clockwise rotation would put the blue cloud in the 4th quadrant, yet the cloud is in the 3rd quadrant. This is because $V$ also contains a reflection! We can check that $\det V = -1$, which means that $V$ reflects the cloud, so the blue cloud lands in the 3rd quadrant. (It is outside the scope of this article, but the axis of reflection can be found via eigendecomposition.)

  • The second transformation scaled the cloud by the diagonal matrix $\Sigma$: by a factor of 2.3 along the x-axis and 0.5933 along the y-axis. Indeed, the cloud has become both longer and thinner.

  • Finally, the third transformation rotated the cloud again, this time by the matrix $U$: a rotation of -137.59 degrees, i.e. a clockwise rotation.
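The angle-and-reflection analysis above can be automated. Below is a minimal sketch, assuming a 2×2 orthonormal matrix `Q` such as $U$ or $V$ from the SVD; the helper name `describe_orthogonal` is hypothetical. If $\det Q = +1$, $Q$ is a pure rotation. If $\det Q = -1$, $Q = R(\theta)\,\operatorname{diag}(1, -1)$: a reflection across the x-axis followed by a rotation by $\theta$. In both cases, $\theta$ can be read from the first column of $Q$. The sketch uses `torch.linalg.svd`, which returns $V^\top$ as `Vh`.

```python
import math
import torch

def describe_orthogonal(Q: torch.Tensor):
    """Hypothetical helper: return (angle in degrees, is_reflection) for a 2x2 orthonormal Q.

    det(Q) = +1: Q is the rotation R(theta).
    det(Q) = -1: Q = R(theta) @ diag(1, -1), a reflection across the x-axis, then a rotation.
    Either way, the first column of Q is (cos theta, sin theta).
    """
    theta = math.degrees(math.atan2(Q[1, 0].item(), Q[0, 0].item()))
    is_reflection = torch.linalg.det(Q).item() < 0
    return theta, is_reflection

torch.manual_seed(0)
U, S, Vh = torch.linalg.svd(torch.randn(2, 2))  # W = U @ diag(S) @ Vh, with Vh = V^T
for name, Q in [('U', U), ('V', Vh.T)]:
    angle, reflected = describe_orthogonal(Q)
    print(f'{name}: angle = {angle:.2f} degrees, reflection = {reflected}')
```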

Singular Value Decomposition (SVD)

What we just did is called the Singular Value Decomposition (SVD) of the matrix $W$:

$$W = U \Sigma V^\top$$

  • $U$ and $V$ are the rotation-reflection matrices we just saw.
  • $\Sigma = \operatorname{diag}(\sigma_1, \sigma_2)$ is the scaling matrix: $\sigma_1$ is the scaling factor for the $x$ dimension, and $\sigma_2$ is the one for the $y$ dimension. The larger the scaling factor, the more the space is stretched in that direction; the smaller it is, the more the space is compressed.
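The decomposition is easy to verify numerically. This sketch uses `torch.linalg.svd`, which returns $V^\top$ directly as `Vh`; the older `torch.svd` used above returns $V$ instead. The random matrix and point cloud are arbitrary. The last check confirms that rotating, scaling, and rotating a row-vector cloud step by step equals applying $W$ to every point.

```python
import torch

torch.manual_seed(0)
W = torch.randn(2, 2)

U, S, Vh = torch.linalg.svd(W)  # W = U @ diag(S) @ Vh, with Vh = V^T
W_rebuilt = U @ torch.diag(S) @ Vh
print(torch.allclose(W, W_rebuilt, atol=1e-6))  # True

# U and V are orthonormal (rotation-reflection) matrices: Q^T Q = I
print(torch.allclose(U.T @ U, torch.eye(2), atol=1e-6))    # True
print(torch.allclose(Vh @ Vh.T, torch.eye(2), atol=1e-6))  # True

# Singular values are non-negative and sorted in descending order
print(S)

# For points stored as rows, the three steps compose to X @ W.T
X = torch.randn(1_000, 2)
steps = X @ Vh.T @ torch.diag(S) @ U.T
print(torch.allclose(steps, X @ W.T, atol=1e-5))  # True
```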

Non-linear Transformations

Linear transformations can rotate, reflect, stretch, and compress, but they cannot squash or curve space; for that we need non-linearities. To visualize this, we can use the tanh function, which squashes its input into the range $(-1, 1)$. Applied element-wise to a 2D space, it squashes the points into the square $(-1, 1)^2$. We will scale the input by a factor $s$ to make the effect more visible:

There are a couple of famous non-linear functions that are used in neural networks:

  • Hyperbolic Tangent (tanh)
    • Squashes the input (the x coordinate) to be between $-1$ and $1$
    • Squashes the output (the y coordinate) to be between $-1$ and $1$
    • Coordinates that are already small, well inside $(-1, 1)$, barely change, since $\tanh(z) \approx z$ near 0
    • 2 kinks, one on each side of the origin, where the curve bends from nearly linear to saturated
  • Sigmoid
    • Squashes the input (the x coordinate) to be between $0$ and $1$
    • Squashes the output (the y coordinate) to be between $0$ and $1$
    • Commonly used in the output layer of binary classifiers
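Here is a small numerical check of the ranges listed above. It also verifies the standard identity $\sigma(z) = \tfrac{1}{2}(\tanh(z/2) + 1)$, which shows that sigmoid is a shifted and rescaled tanh. The sample inputs are arbitrary.

```python
import torch

z = torch.linspace(-6, 6, steps=13)  # arbitrary sample inputs
t = torch.tanh(z)
s = torch.sigmoid(z)

print(t.min().item() > -1, t.max().item() < 1)  # True True: tanh maps into (-1, 1)
print(s.min().item() > 0, s.max().item() < 1)   # True True: sigmoid maps into (0, 1)

# Near 0, tanh(z) is approximately z, so small values barely change
print(torch.tanh(torch.tensor(0.1)).item())      # about 0.0997

# Sigmoid is a shifted, rescaled tanh
print(torch.allclose(s, 0.5 * (torch.tanh(z / 2) + 1), atol=1e-6))  # True
```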

Applying Non-linear Transformations

import torch
import matplotlib.pyplot as plt

plt.figure(figsize=(15, 5))

# Left panel: output of the previous linear transformation
plt.subplot(1, 4, 1)
show_scatterplot(Y3, colors, title='h=Y3')
plot_bases(OI)

# Loop through scaling factors: scale, then apply the tanh non-linearity element-wise
for s in range(1, 4):
    plt.subplot(1, 4, s + 1)
    Y = torch.tanh(s * Y3)  # every coordinate ends up between -1 and 1
    show_scatterplot(Y, colors, title=f'Y=tanh({s}*h)')
    plot_bases(OI, width=0.01)

plt.show()
From left to right: the result of the previous linear transformation, then tanh applied to the input scaled by s=1, s=2, and s=3.
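To quantify what the plots show, the sketch below measures how much of the cloud ends up near the border of the square as $s$ grows. Two choices are assumptions: a fresh standard-normal cloud stands in for `Y3`, and "near the border" means a coordinate above 0.9 in magnitude.

```python
import torch

torch.manual_seed(0)
h = torch.randn(1_000, 2)  # stand-in for Y3 (assumption): any 2D point cloud works

fractions = []
for s in range(1, 4):
    Y = torch.tanh(s * h)
    # Every coordinate stays within [-1, 1] (in float32, large inputs round to exactly +/-1)
    assert Y.abs().max().item() <= 1.0
    # Share of points with at least one coordinate beyond 0.9 in magnitude
    near_border = (Y.abs().max(dim=1).values > 0.9).float().mean().item()
    fractions.append(near_border)
    print(f's={s}: fraction of points near the border = {near_border:.2f}')
```

Larger $s$ pushes more points toward the edges and corners of the square, which matches the increasingly "boxy" clouds in the figure.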

And that's it! We have seen how a neural network can be thought of as a series of linear and non-linear transformations in high-dimensional space. This flexibility allows neural networks to model complex relationships in the data, and it is a key reason why they are so powerful.
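Here is why the non-linear step matters. Without it, stacking linear layers buys nothing, because a product of matrices is itself a single matrix. The sketch below compares a two-layer network with and without tanh in between. The weights are arbitrary random values, the layer sizes are illustrative, and biases are omitted for brevity. The least-squares fit shows that even the best single matrix cannot reproduce the non-linear version.

```python
import torch

torch.manual_seed(0)
W1 = torch.randn(16, 2)    # layer 1: 2D -> 16D (illustrative size)
W2 = torch.randn(2, 16)    # layer 2: 16D -> 2D
X = torch.randn(1_000, 2)  # a 2D cloud, one point per row

# Two stacked linear layers...
deep_linear = (X @ W1.T) @ W2.T
# ...equal ONE linear layer with weight W2 @ W1
single = X @ (W2 @ W1).T
print(torch.allclose(deep_linear, single, atol=1e-4))  # True

# With tanh in between, the map is no longer linear
nonlinear = torch.tanh(X @ W1.T) @ W2.T
# Best least-squares linear fit from X to the non-linear outputs
W_fit = torch.linalg.lstsq(X, nonlinear).solution
residual = (X @ W_fit - nonlinear).abs().max().item()
print(residual)  # clearly non-zero: no single matrix reproduces the map
```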

Have a wonderful day.

– Frank

It's not a black box, it's a high-dimensional space transformer!