Affine Registration in 12 lines of code
Yes, you’ve read that correctly! Affine registration can be done in 12 lines of Python using PyTorch. That’s surprising, as PyTorch was originally built for deep learning, not image registration. But it’s good news for (neuro-)imaging people like me and also a fun toy problem for understanding the power of PyTorch. So, if you are interested in neuroimaging and/or deep learning, this post will tickle your fancy!
If you are in a hurry and already know what affine registration and PyTorch are, simply skip to “The 12 lines”. And for the really impatient ones, here is the
TL;DR:
- PyTorch is surprisingly effective for image registration due to its
  - automatic gradient engine (saving code lines and headaches)
  - utility functions `F.affine_grid`, `F.grid_sample` and `F.interpolate`
  - GPU support which enables faster compute (on NVIDIA GPUs or >2020 Apple Silicon)
- The core of image registration can be coded in 12 lines
- My ~100-line image registration library, installable via `pip install torchreg`, supports
  - 2D and 3D images
  - GPU computation (speedup!)
  - freezing translation, rotation, zoom and/or shear (to do e.g. rigid registration)
  - multiresolution approaches
  - using custom similarity functions/losses and optimizers
  - parallel, multimodal and coregistration
What is Affine Registration?
Affine registration is a method that aligns images (e.g. 3D brain images) by applying a combination of translation, rotation, zooming and shearing transformations.
Let’s apply it to two (pointcloudy) fishes to get a visual understanding.
In this example the blue fish - moving image - was aligned (registered) to the red fish - static image - such that each blue point matches the position of its corresponding red point.
Unfortunately, there exists no formula that takes two images and returns the desired transformation. But since we can quantify how well the two images are aligned by measuring the mean distance of the corresponding points, we can solve it iteratively - step by step. Applied to this problem, an iterative approach does something like this:
- Apply some transformation to moving image
- Calculate distance between points
- If distance > x, apply another transformation to moving image
- Calculate distance between points
- If distance > x, apply another transformation to moving image
- Calculate distance between points
- …
The most naive version of the iterative approach would always apply a random transformation and take a long time until, by chance, it finds a transformation which meets the requirement “distance < x”. As you can see in the GIF, there is a smarter way which smoothly reduces the distance until it meets the requirement. We will get there at the end of “Why PyTorch”!
The transformation can be fully described by an affine matrix \(\mathbf{A}\) (2D: 3x3 matrix, 3D: 4x4 matrix). The \(\mathbf{A}\) matrix encodes:
- translation in its last column
- scale (zoom) on its diagonal
- rotation/shear on all non-diagonal values of the first two rows/columns
You might ask “Why encode the transformations in this weird matrix?”. Because we can now transform each coordinate point \(\vec{p}\) by simply multiplying it with this matrix \(\mathbf{A}\).
\[\vec{p}_{\text{moved}} = \mathbf{A} \cdot \vec{p} = \begin{bmatrix} a & b & c\\ d & e & f\\ 0 & 0 & 1\\ \end{bmatrix} \begin{bmatrix} x \\ y \\ 1 \\ \end{bmatrix} = \begin{bmatrix} a \cdot x + b \cdot y + c \\ d \cdot x + e \cdot y + f \\ 1 \\ \end{bmatrix}\]As you can see, we needed to append an extra \(1\) after the \(y\) of the 2D point to make it work. In 3D you would do the same after the \(z\). Similarly, the last row of the affine matrix is always all zeros with a single \(1\) at the end, in both 2D and 3D.
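To make this concrete, here is a tiny example (the numbers are made up) that applies an affine combining a 1.5x zoom and a translation by (2, 1) to the point (1, 2):

```python
import torch

# 2D affine: zoom by 1.5 (diagonal) and translate by (2, 1) (last column)
A = torch.tensor([[1.5, 0.0, 2.0],
                  [0.0, 1.5, 1.0],
                  [0.0, 0.0, 1.0]])
p = torch.tensor([1.0, 2.0, 1.0])  # point (x=1, y=2) with the extra 1 appended

print(A @ p)  # tensor([3.5000, 4.0000, 1.0000]) -> moved point (3.5, 4)
```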
Nice, we have now introduced all the concepts needed to do affine registration in a (naive) iterative fashion! We could write a (slow) program that would randomly stumble towards better-aligning affine transformations. From this starting point we will now march on, mumbling the programmer’s “Make it work, make it pretty, make it fast!” mantra. Thankfully, PyTorch works fast and is sufficiently pretty, so we will reach the end of the rainbow quite quickly!
Why PyTorch?
PyTorch is primarily used for deep learning and can be thought of as NumPy with GPU support (speedup!) since:
- It uses NumPy semantics
- Besides CPU (“normal” processor) it also supports GPU (graphics card) computation, which can be 10-100x faster
For our problem (affine registration) it has two more powerful features:
- There exist utility functions (`F.affine_grid`, `F.grid_sample`) which allow you to apply affine transformations to 2D and 3D images!
- PyTorch offers automatic differentiation, which is very useful since iterative optimization is much faster when a derivative can be calculated!
To understand the second point we have to answer the following question: What is a gradient?
If you can, try to remember what a derivative is (you probably had to know it during middle school math class).
- The derivative \(\frac{df}{dx}\) of a function \(f(x)\) is its rate of change w.r.t. (with respect to) \(x\).
Let’s make it more concrete and say that \(x\) is an affine matrix and the function \(f\) is the distance of points between two fish images. Wow, what a useful concept for our problem: Now the derivative expresses how much the distance changes w.r.t. the affine matrix. Since \(x\) holds 16 elements (all values of the 4x4 affine matrix) the derivative also contains 16 elements - each simply being the derivative of the distance w.r.t. the respective matrix element. Bravo, this multivariable derivative is the gradient we wanted to understand!
- The gradient \(\nabla\) of a function \(f(x_1, x_2,...)\) is its rate of change w.r.t \(x_1\), \(x_2\),…
The beauty of the gradient is that it always points in the direction (here, affine matrix change) of the maximum increase of the function (here, maximum increase of distance). So, if we want to minimize the distance we just have to change the affine matrix in the opposite direction. Following this direction during our iterative approach will result in smooth improvement, as shown in the GIF (and sketched in code after the following list).
- Apply some affine to moving image
- Calculate distance between points
- Calculate the gradient of the distance w.r.t affine
- Change affine in descending gradient direction and apply to moving image
- Calculate distance between points
- Calculate the gradient of the distance w.r.t affine
- Change affine in descending gradient direction and apply to moving image
- …
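In PyTorch, this recipe is only a few lines. Here is a minimal, fish-free sketch (the toy function and numbers are made up) that minimizes the “distance” \((x - 1)^2\) via gradient descent:

```python
import torch

x = torch.nn.Parameter(torch.tensor(4.0))  # start far away from the optimum
optimizer = torch.optim.SGD([x], lr=0.1)

for _ in range(50):
    optimizer.zero_grad()    # reset the stored gradient
    loss = (x - 1) ** 2      # "distance", minimal at x = 1
    loss.backward()          # autograd computes d(loss)/dx
    optimizer.step()         # change x in the descending gradient direction

print(x)  # tensor very close to 1.0
```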
The 12 lines
Now that we know what affine registration is and what PyTorch offers us, let’s look at the code.
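Here is a minimal sketch of the function, reconstructed to match the line-by-line walkthrough below (line numbers count from the `def` line; the default values like `n_iterations=500` and `learning_rate=1e-2` are assumptions):

```python
import torch
import torch.nn.functional as F

def affine_registration(static, moving, n_iterations=500, learning_rate=1e-2):
    affine = torch.eye(4, device=static.device)        # identity affine
    affine = torch.nn.Parameter(affine)                # mark as optimizable
    optimizer = torch.optim.SGD([affine], lr=learning_rate)
    for _ in range(n_iterations):
        optimizer.zero_grad()
        affine_grid = F.affine_grid(affine[:3][None], size=static.shape)
        moved = F.grid_sample(moving, affine_grid)     # apply transformation
        loss = F.mse_loss(static, moved)               # measure dissimilarity
        loss.backward()                                # gradient w.r.t. affine
        optimizer.step()                               # gradient descent step
    return affine.detach()
```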
For people who have used PyTorch for deep learning, the code should look very familiar. It looks like the core component of code which trains a neural net!
Let’s run through it, line by line:
- The function takes a `static` (image), a `moving` (image), `n_iterations` (number of iterations) and a `learning_rate`
- `affine` is initialized using `torch.eye` (4x4 matrix filled with zeros + ones on the diagonal = identity affine)
- `affine` is made a `torch.nn.Parameter`, which will be optimized if passed to an optimizer
- `optimizer` is initialized using the SGD (Stochastic Gradient Descent) optimizer with the given `learning_rate`
- Starting a for-loop which will repeat/iterate lines 6-11 `n_iterations` times
- `optimizer.zero_grad()` initializes all derivatives (stored in the background) to zero
- `affine` is transformed into an `affine_grid`…
- …which is used to apply the affine transformation to the `moving` image
- The `loss` value is the mean squared error (MSE) between the `static` and `moved` image
- `loss.backward()` calculates the derivative/gradient of the loss w.r.t. the parameters using the chain rule (`loss` -> `moved` -> `affine_grid` -> `affine`)
- `optimizer.step()` changes the `affine` in the opposite of the gradient direction (gradient descent) to minimize the loss
- The optimized `affine` parameter is converted back to a plain tensor via `.detach()` and returned
The code now deals with (3D) images instead of points, which is why lines 7-9 need some extra explanation:
First, the MSE in line 9 is doing what the distance between the corresponding red and blue fish points was doing earlier in the post: it measures image alignment. Higher MSEs indicate worse alignment - also called dissimilarity - between the moving and static image.
Second, an image can be thought of as a grid of pixels/points. Applying an affine transformation to each of these pixels - i.e. multiplying its coordinates with the affine matrix, which is what happens in `F.affine_grid` - works just fine, BUT: you end up with new pixel coordinates which are no longer placed perfectly on a rectangular grid. So in 2D, each “old” pixel typically ends up somewhere within a 2x2 pixel area of the new image. The standard approach to deal with this is interpolation, which is what `F.grid_sample` does for us in the background.
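A tiny (made-up) example makes the interpolation visible: shift a small image by a fraction of a pixel, and the output values become blends of neighboring input pixels:

```python
import torch
import torch.nn.functional as F

image = torch.arange(16.).reshape(1, 1, 4, 4)      # tiny 4x4 "image"
affine = torch.tensor([[[1.0, 0.0, 0.1],           # identity + small x-shift
                        [0.0, 1.0, 0.0]]])

grid = F.affine_grid(affine, size=image.shape, align_corners=False)
moved = F.grid_sample(image, grid, align_corners=False)
print(moved)  # values are bilinear blends of neighbors, not the original integers
```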
Application to brain images
After this theoretical fugazi you might think “Talk is cheap, just show me a demo!” so here we go. Using this Colab notebook you can run the whole demo by simply clicking “Runtime” -> “Run all”!
Let’s download two brain images and plot them (using nibabel’s `.orthoview()`) to see how misaligned the brains are:
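A sketch of these two steps (the URLs are placeholders; any two misaligned scans will do):

```python
import urllib.request
import nibabel as nib

# hypothetical URLs -- replace with any two (mis)aligned brain scans
urllib.request.urlretrieve('https://example.com/static.nii.gz', 'static.nii.gz')
urllib.request.urlretrieve('https://example.com/moving.nii.gz', 'moving.nii.gz')

static_nii = nib.load('static.nii.gz')
moving_nii = nib.load('moving.nii.gz')
static_nii.orthoview()
moving_nii.orthoview()
```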
The `moving_nii` brain is not as accurately aligned with the crosshairs as the `static_nii` brain.
We will fix that by registering `moving_nii` to `static_nii`!
Therefore, we first have to convert these .nii files (called “Nifti”) into PyTorch tensors.
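For example (the `[None, None]` adds the batch and channel dimensions that `F.affine_grid`/`F.grid_sample` expect):

```python
import torch

static = torch.from_numpy(static_nii.get_fdata())[None, None].float()
moving = torch.from_numpy(moving_nii.get_fdata())[None, None].float()
```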
And then we can apply our beautiful `affine_registration` function.
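A sketch of that call (the scale factor is an assumption):

```python
import torch.nn.functional as F

small_static = F.interpolate(static, scale_factor=0.5, mode='trilinear')
small_moving = F.interpolate(moving, scale_factor=0.5, mode='trilinear')

optimal_affine = affine_registration(small_static / small_static.max(),
                                     small_moving / small_moving.max())
```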
Oops, I snuck two more operations in there:
- Reducing the images’ resolution (via `F.interpolate`) because `affine_registration` will be much faster this way
- Using `_ / _.max()` to normalize image intensities to the value range 0.0-1.0
Finally, we use the `optimal_affine` to align the `moving` tensor:
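Since `F.affine_grid` works in normalized coordinates, the affine found at half resolution applies directly to the full-resolution image:

```python
grid = F.affine_grid(optimal_affine[:3][None], size=moving.shape)
moved = F.grid_sample(moving, grid)
```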
Did we do everything right? Let’s visualize to check!
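One way to check: wrap the tensor in a Nifti and plot it again (reusing `static_nii`’s header affine is a reasonable choice here, since `moved` now lives on the `static` grid):

```python
# reuse the static image's affine header, since moved is aligned to static
moved_nii = nib.Nifti1Image(moved[0, 0].numpy(), affine=static_nii.affine)
moved_nii.orthoview()
```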
The `moved_nii` is accurately aligned to the crosshairs, so it worked fine! 🎉
Optionally, the Nifti we just created from the tensor can be saved (the filename is arbitrary):
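```python
nib.save(moved_nii, 'moved.nii.gz')
```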
torchreg: Tiny PyTorch image registration library
Welcome to the advertising bit of this post!
You did not really think I would go through the trouble of explaining all this stuff to you just to share my excitement about short code, did ya?!
I have added a few tweaks and tricks that were missing to make it a “mature” registration tool and ended up with the ~100 lines I named torchreg.
torchreg supports 2D and 3D images, can be installed via `pip install torchreg` and provides you with the `AffineRegistration` class (usage sketched after the following list), which supports
- using a multiresolution approach to save compute (by default it runs at 1/4 and then 1/2 of the original resolution for 500 + 100 iterations)
- choosing which operations (translation, rotation, zoom, shear) to optimize
- optimization with custom initial parameters
- and using custom dissimilarity functions and optimizers
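A usage sketch (the constructor arguments shown are assumptions from memory; check the README for the exact signature):

```python
from torchreg import AffineRegistration

# argument names are assumptions -- see the torchreg README for the exact API
reg = AffineRegistration(is_3d=True, scales=(4, 2), iterations=(500, 100),
                         with_zoom=True, with_shear=False)
moved = reg(moving, static)
```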
By moving the tensors to the GPU you can leverage your GPU (if you have an NVIDIA GPU) and speed up the registration:
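```python
moving, static = moving.cuda(), static.cuda()  # assumes a CUDA-capable GPU
```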
You can easily access the `affine` and the four parameters:
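Something like the following (the accessor name is an assumption; the README documents the real API):

```python
affine = reg.get_affine()  # hypothetical accessor for the optimized affine matrix
# translation, rotation, zoom and shear are kept as separate parameters on `reg`
```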
That’s it with the advertising bit!
Conclusion
I think your conclusions from this blog post highly depend on your background. Hopefully they look something like this:
- You are a neuroimaging person
- I finally really understand this “Affine Registration” my toolboxes use
- PyTorch is some neural network stuff. I didn’t fully understand that part but it enables really fast image registration!
- The next time I’ll align Niftis I’ll use torchreg!
- I’ll give the repo a star!
- You are a deep learning person
- I finally kinda understand what PyTorch does in the background of the training loops I always use…
- …but I will play with the Colab notebook and plot/print affine, moved and gradients in the loop to get a feeling for what is happening in the background!
- The word gradient does not scare me anymore!
- If I ever want to apply Affine/Rigid Registration I’ll use torchreg!
- I’ll give the repo a star!
- You are a nerd
- Interesting read, nice comprehensive + short code!
- I’ll give the repo a star because this post entertained me!
- You are a normie
- What is this weird guy talking about?!
One last closing remark: affine registration in neuroimaging can be made really robust using brainmasks. If you are interested in how to get really accurate masks, take a look at my next blog post (online soon)!