Build an ML Framework from Scratch

An introduction to the basics of ML frameworks

April 2025

I have always wondered how TensorFlow, PyTorch, and JAX work. However, their codebases are super complex and hard to read. To find out, I built a machine learning framework from scratch myself to learn the basics of ML frameworks.

It only took about 800 lines of Python and C++ code altogether to build the framework. The following is a basic example of what it is capable of, but you are also welcome to check out the end-to-end classification example.

import numpy as np

import framework

x = framework.Tensor.from_numpy(np.array([[2.0, 3.0]], dtype=np.float32))
y = framework.Tensor.from_numpy(np.array([[4.0], [5.0]], dtype=np.float32))
z = framework.ops.matmul(x, y)  # Expected: [[2*4 + 3*5]] = [[23.0]]
loss = framework.ops.sum(z)

loss.backward()  # Compute the gradients of loss w.r.t. tensors x, y, and z.

print("x.grad:", x.grad.numpy())  # Expected: [[4.0, 5.0]]
print("y.grad:", y.grad.numpy())  # Expected: [[2.0], [3.0]]

The codebase is thoroughly commented, which brings it up to around 2000 lines of code. Anyone can read the source code to understand how an ML framework works and feel more confident about building a new one. This project is named the "Readable ML Framework". ⭐ Star it on GitHub now so that you don't lose it.

This article dives into the core concepts of ML frameworks, citing the implementations in the repo. We will also discuss other important elements of ML frameworks beyond this project.

Why care about ML frameworks?

The ML framework lies at the center of the ML infrastructure. Understanding ML frameworks gives you the bigger picture by connecting all the dots, including CUDA/Triton kernels, GPUs/TPUs, data pipelining, ML compilers, distributed training, and so on.

How did I come up with this idea?

It all started with a question from my PhD advisor. After my first internship on the TensorFlow/Keras team, he asked me if I knew how to build a project like TensorFlow from scratch. The answer was absolutely no at the time. It planted a seed in my mind that grew my curiosity for ML frameworks through the years. Five years later, I built this project to learn more about ML frameworks myself.

I am a top contributor to Keras and helped integrate Keras with TensorFlow, PyTorch, and JAX. Many people would think that I knew ML frameworks very well, but I did not. I then realized that if I did not know much about them, others would not either, unless they had worked on building an ML framework directly.

The next question I asked myself was how hard it could be to build a basic ML framework. To find out, I read the source code of PyTorch. Here are my notes. I found it was not that hard if we got rid of all the complex features.

So, I decided to build one with almost no features except tensor operations and automatic gradient computation. This "bare-bones" style for educational projects is largely inspired by the super tiny compiler project, which helped me learn what a compiler is.

Python C++ interfacing

Let's start with the number one issue that turns code readers away from ML frameworks: Python C++ interfacing. Whenever I wanted to track down a tensor operation (we will call them ops for short in the rest of the article) from its Python interface to its actual implementation in C++, I got lost. It can take a super complex system to bind all the C++ functions to Python functions. We call this one-to-one mapping between Python functions and C++ functions bindings.

PyTorch indeed developed a system called torchgen to generate the Python C++ bindings. TensorFlow and JAX use Pybind11, a third-party library that simplifies the code for Python C++ bindings. For simplicity and readability of the code, we use Pybind11 for our framework.

It works as follows.

// "core" is the Python submodule name.
// Its Python import path is framework.core
// "m" is the module object.
PYBIND11_MODULE(core, m) {
    ...

    // Create a Python submodule named ops under framework.core.
    auto ops_module = m.def_submodule("ops");

    // Bind C++ function ops::matmul to
    // Python function framework.core.ops.matmul().
    ops_module.def(
        "matmul",                  // Function name in Python.
        &ops::matmul,              // Pointer to the C++ function.
        "(m, k), (k, n) -> (m, n)" // Docstring of the function in Python.
    );
    ...
}
import glob

import pybind11
import setuptools
from pybind11.setup_helpers import Pybind11Extension

# Define the extension module.
ext_modules = [
    Pybind11Extension(
        # The Python import path to the C++ module.
        "framework.core",
        # Pass in all the C++ files needed to compile the module. We put all
        # the .cpp and .h files under framework/core dir.
        glob.glob("framework/core/**/*.cpp", recursive=True),
        # Tell setuptools where to find the pybind11 header files while
        # compiling the C++ extension module.
        include_dirs=[pybind11.get_include()],
    ),
]

# Setup function
setuptools.setup(
    ...
    # Add the C++ extension module to the package.
    ext_modules=ext_modules,
    ...
)

That is it! With the above, we can call the C++ function as a Python function from Python code.

When we build the project, for example with pip install, setuptools compiles the C++ code into a .so file on Linux, which is a binary file, also known as a shared object or shared library. When we call the corresponding Python function, Pybind11 automatically converts the arguments from their Python types to the C++ parameter types and calls the C++ function.

In our framework, it is easy to track down an op. Just look at python_bind.py and see which C++ function it is bound to.
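If the package has been built (for example, via pip install), you can also confirm from a Python shell that an op resolves to the compiled C++ extension module. A quick check might look like this (the exact file path will differ on your machine):

import framework.core

# The compiled extension module is a shared library on disk.
print(framework.core.__file__)  # e.g. .../framework/core.cpython-<version>.so

# The docstring passed to ops_module.def() in C++ is visible from Python.
print(framework.core.ops.matmul.__doc__)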

Why mix Python with C++?

The simple answer is we want the speed of C++ and the simplicity of Python programming. There are parts of an ML framework that have to be Python and parts that have to be C++/CUDA.

The user-facing APIs have to be in Python. Python is the programming language of machine learning and AI. All the research scientists and machine learning engineers love Python. It has an amazing ecosystem of machine learning libraries. For the optimal user experience, the frameworks have to have Python APIs.

The underlying data structures of tensors and the tensor ops have to be implemented in C++/CUDA for performance. Python is slow, especially because its Global Interpreter Lock (GIL) prevents truly parallel execution of Python threads. So we need C++/CUDA to write efficient tensor data structures and parallel tensor ops.

Note there are also ongoing efforts to include the merits of both languages in one programming language, like the 🔥 Mojo programming language.

Tensor

The tensor is the most essential data structure in an ML framework. Tensors are usually implemented in C++/CUDA to be more efficient and to manage memory allocation explicitly.

In modern ML frameworks, tensors are split into the data and the view. The data contains the actual values in the tensor occupying a large amount of memory. The view is how these data are presented to the ops or users.

The data can be a contiguous chunk of memory of floats or integers. The view can contain the shape of the tensor, the order of the axes, the strides (step sizes) used to iterate over the data, and so on.

This data-view split allows many ops to be super efficient. For example, if you want to transpose a tensor or matrix, you only need to create a new view indexing the values differently. No copy or modification to the data is needed.

For simplicity, we chose a simple split strategy to implement the Tensor class in our framework. The data part is always a flattened C++ vector. The view only contains the shape of the tensor. Data and shape are the two members of the Tensor class.

For example, the 2x2 matrix $\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}$ can be represented as shape = [2, 2] and data = [1, 2, 3, 4].
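To make the row-major layout concrete, here is a small sketch in plain Python (not the repo's C++ code) of how an element (i, j) is located in the flat data buffer:

shape = [2, 2]
data = [1.0, 2.0, 3.0, 4.0]  # row-major: row 0 first, then row 1

def at(i, j):
    # Element (i, j) of a 2D tensor lives at flat offset i * num_cols + j.
    num_cols = shape[1]
    return data[i * num_cols + j]

print(at(0, 1))  # 2.0
print(at(1, 0))  # 3.0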

Here is our Tensor class.

class Tensor {
public:
    /**
     * @brief The shape of the tensor, represented as a vector of size_t.
     *
     * Each element in the vector corresponds to the size of a dimension.
     * For example, a tensor with shape {2, 3} would be a 2x3 matrix.
     */
    std::vector<size_t> shape;

    /**
     * @brief The underlying data of the tensor, stored as a contiguous vector
     * of floats.
     *
     * The elements are typically ordered according to a row-major (C-style)
     * layout, though this class itself doesn't enforce a specific layout.
     */
    std::vector<float> data;

    ...
};

Read tensor.h and tensor.cpp for more details.

Tensor operations

Implementing the tensor operations is the most challenging part of building our ML framework. We had to make sure they were parallelized correctly. Let's take a look at the two most important ops we implemented: matrix multiplication and the backward function of softmax.

Matrix multiplication

Matrix multiplication is important because it is so widely used in neural networks. Optimizing it has a high ROI. The technique we used here is called tiling.

Let's say we have two tensors (matrices) x and y, and we have allocated memory for the output tensor. Then, output[i,j] is the inner product of the ith row of x and the jth column of y.

To do this, we implemented a strategy simple enough to show the basic idea. This implementation employs row-wise parallelization: the rows of the first matrix (x) are divided into chunks (implicitly, by the parallel_for construct), and each chunk is processed concurrently by invoking the matmul_task function.

Specifically, for a given sub-range of rows [start_row, end_row) of x, matmul_task computes the corresponding sub-block of the output matrix by multiplying these rows with the entirety of the second matrix (y). If a task is assigned the row slice x[start_row:end_row,:], it calculates output[start_row:end_row,:].

Following is the code of our implementation.

void matmul_task(size_t start_row, size_t end_row, size_t num_x_cols,
                 size_t num_y_cols, const std::vector<float> &x_data,
                 const std::vector<float> &y_data,
                 std::vector<float> &output_data) {
    // Iterate through a specified range of rows of the first matrix (x).
    for (size_t i = start_row; i < end_row; ++i) {
        // For each row of x in the assigned range, iterate through all columns
        // of the second matrix (y).
        for (size_t j = 0; j < num_y_cols; ++j) {
            // This inner loop calculates the dot product (inner product) of
            // the i-th row of x and the j-th column of y. This is the
            // fundamental operation of matrix multiplication.
            float inner_product = 0.0f;
            for (size_t l = 0; l < num_x_cols; ++l) {
                inner_product +=
                    x_data[i * num_x_cols + l] * y_data[l * num_y_cols + j];
            }

            // The computed inner product is the element at the i-th row and
            // j-th column of the output matrix.
            output_data[i * num_y_cols + j] = inner_product;
        }
    }
}

The choice of row-wise sharding simplifies the parallelization logic. However, it's worth noting that more advanced parallel matrix multiplication techniques often involve sharding both input matrices (x and y) to achieve finer-grained parallelism and potentially better data locality, especially on distributed memory systems.

For instance, a common approach involves assigning tasks to compute sub-blocks of the output matrix resulting from the multiplication of a row block of x and a column block of y (e.g., if a thread processes x[a:b,:] and y[:,c:d], it would produce output[a:b,c:d]). These sub-blocks are then combined (tiled) to form the final output matrix.

If the rows and columns are too long, you can further chop them into shorter pieces, compute partial inner products, and add them up later. For example, if a thread processes x[a:b,c1:c2] and y[c1:c2,c:d], it produces only a partial result for output[a:b,c:d], using just the slice of the shared dimension between c1 and c2. We then sum up all such partial results to get output[a:b,c:d].
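As a concrete NumPy sketch of this partial-sum idea (the block boundaries a, b, c, d and the list of k-splits here are illustrative, not from the repo):

import numpy as np

def block_matmul(x, y, a, b, c, d, k_splits):
    # Accumulate partial inner products over slices of the shared k dimension.
    # k_splits, e.g. [(0, 64), (64, 128)], must cover the k dimension exactly.
    out_block = np.zeros((b - a, d - c), dtype=x.dtype)
    for k1, k2 in k_splits:
        out_block += x[a:b, k1:k2] @ y[k1:k2, c:d]
    return out_block  # equals (x @ y)[a:b, c:d]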

There is even hardware fully optimized just for matrix multiplication, usually referred to as application-specific integrated circuits (ASICs), for example, the Tensor Cores in NVIDIA GPUs and the MXUs in Google's TPUs. To use them, you chop the matrices into blocks, feed the blocks to the ASICs, and tile the outputs back together.

Backward functions

A quick recap of forward and backward functions. The forward function is the function of the op itself, for example, softmax, matmul, and relu. It computes the output of the op given the input to the op. The backward function is to compute the gradients w.r.t. the input given the gradients w.r.t. the output. The backward function may also access the input and output of the forward function.

The backward function is used by the neural network optimizers during backward propagation to compute the gradients of the loss w.r.t. all the weights of the neural network.

The following figure shows how gradients are computed. In the forward function, we compute z using x and y. In the backward function, we compute x.grad and y.grad using x, y, and z.grad, where x.grad, y.grad, and z.grad are the gradients of the loss w.r.t. x, y, and z.

To put it in math, if the forward function is $z = f(x, y)$, the backward function computes $\partial L / \partial x$ and $\partial L / \partial y$, where $L$ is the loss. We can compute them using the chain rule as follows.

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial x} \qquad \frac{\partial L}{\partial y} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial y}$$

In practice, $x$, $y$, $z$, and $\partial L / \partial z$ are known at the time we call the backward function. We need to compute $\partial L / \partial x$ and $\partial L / \partial y$ using them. As a straightforward way to do this, we only need to compute $\partial z / \partial x$ and $\partial z / \partial y$, which are computed using $x$ and $y$.

For a simplified example, suppose $z = f(x, y) = x \times y$, and $\partial L / \partial z = 3$, $x = 4$, $y = 5$, $z = 20$ are scalars. We have $\partial z / \partial x = y$ and $\partial z / \partial y = x$. So, we can compute the gradients as follows.

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial x} = \frac{\partial L}{\partial z} \cdot y = 3 \times 5 = 15$$

$$\frac{\partial L}{\partial y} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial y} = \frac{\partial L}{\partial z} \cdot x = 3 \times 4 = 12$$
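As a sketch in plain Python (using the multiply_backward naming that appears in the autograd example later; the exact signature in the repo may differ), this backward function could be written as:

def multiply_backward(output_grad, inputs):
    x, y = inputs
    # d(x*y)/dx = y and d(x*y)/dy = x, each scaled by the incoming gradient.
    return output_grad * y, output_grad * x

print(multiply_backward(output_grad=3.0, inputs=(4.0, 5.0)))  # (15.0, 12.0)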

Note that in the example above, the output of the forward function, z, is not directly involved. However, we may still need the output in some cases as a trick for performance (running speed) optimization. The backward function of softmax is an example.

The backward function of softmax

Another interesting op to dive into is the backward function of softmax. We will use it to showcase the recompute technique in op implementations.

Here is what the forward function of softmax does. Given the logits, a vector with values in the range $(-\infty, \infty)$, the softmax function computes the probabilities of the different classes, which is also a vector, with all non-negative values summing up to one. To put it in math, it looks as follows.

$$y = \mathrm{softmax}(x)$$

$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{C} e^{x_j}}$$

where $x$ is the logit vector, $e$ is the base of the natural logarithm, and $C$ is the number of classes. For numerical stability, we usually normalize $x$ by subtracting $\max(x)$ from every element. Note that we use $y$ as the output of the softmax function since it is a unary operator with no second operand involved other than $x$.
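As a quick NumPy reference (not the framework's C++ code), a numerically stable softmax looks like this:

import numpy as np

def softmax(x):
    # Subtracting max(x) does not change the result, because the shift cancels
    # between the numerator and the denominator, but it keeps exp() from
    # overflowing for large logits.
    exps = np.exp(x - np.max(x))
    return exps / np.sum(exps)

print(softmax(np.array([1000.0, 1001.0])))  # approx. [0.269, 0.731], no overflow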

Now, let's see what the backward function of softmax looks like. According to matrix calculus, the gradients of loss w.r.t. x, which is the input to the softmax function, are as follows.

$$\frac{\partial L}{\partial x} = \left(\frac{\partial y}{\partial x}\right)^{T} \frac{\partial L}{\partial y} = \left(J_{\mathrm{softmax}}(x)\right)^{T} \frac{\partial L}{\partial y},$$

$$J_{\mathrm{softmax}}(x)_{ij} = \frac{\partial y_i}{\partial x_j} =
\begin{cases}
y_i (1 - y_i) & \text{if } i = j, \\
-y_i y_j & \text{if } i \neq j,
\end{cases}$$

where $J_{\mathrm{softmax}}$ is the Jacobian matrix of the softmax function, $x$ is the input to the softmax function, and $y$ is the output of the softmax function.

A quick recap of what a Jacobian matrix is. It is the matrix of the derivatives of the output of a function w.r.t. the input of the function. Since both the output $y$ and the input $x$ are vectors, the derivative is not a single value but a matrix of shape len(y) by len(x), whose element at the $i$th row and $j$th column is the derivative of $y_i$ w.r.t. $x_j$.

As you can see, we only need $y$ and $\partial L / \partial y$, which we should already have at the time we call the backward function, to compute $\partial L / \partial x$. The value of $x$ is not used.
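To make that concrete, here is a small NumPy reference (a dense, unoptimized check rather than the parallel C++ implementation below) that computes the gradient from $y$ alone:

import numpy as np

def softmax_backward_reference(y, output_grad):
    # J[i][j] = y[i] * ((1 if i == j else 0) - y[j]), built only from y.
    jacobian = np.diag(y) - np.outer(y, y)
    # dL/dx = J^T @ dL/dy (J is symmetric here, so the transpose is optional).
    return jacobian.T @ output_grad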

Now, we are facing two choices of how to implement the backward function:

1. Save the output y from the forward pass and read it back from memory in the backward function.
2. Recompute y from the input x inside the backward function.

The second option is called recomputing. It is widely used in performance optimizations for models. It is typically useful when the model is memory-bound, which means the memory access time is the bottleneck of the model's performance, and retrieving y from memory is more expensive than recomputing y from x. We chose the second option for our implementation to illustrate how this technique is applied.

In the following implementation, we deal with a range of rows, where each row is a logit vector $x$. When dealing with a single row, we compute the common denominator of $y$, which is $\sum_{j=1}^{C} e^{x_j}$, so that we can easily compute any element of $y$ on the fly as we build the Jacobian matrix.

We compute the Jacobian matrix element by element using a single C++ variable, jacobian_jk, without storing the entire matrix or any vector. Each element is used right away as it is produced and discarded afterward. No large memory access is needed.

void softmax_backward_task(size_t start_row, size_t end_row, size_t n,
                           const std::vector<float> &output_grad_data,
                           const std::vector<float> &x_data,
                           std::vector<float> &x_grad_data) {
    for (size_t i = start_row; i < end_row; ++i) {
        size_t row_start = i * n;
        float max_val = -std::numeric_limits<float>::infinity();

        // We first compute the softmax forward pass on the fly.

        // Compute the max value in x.
        for (size_t j = 0; j < n; ++j) {
            if (x_data[row_start + j] > max_val) {
                max_val = x_data[row_start + j];
            }
        }

        float sum_exp = 0.0f;
        // Calculate the exponential of each element and the sum of exponentials
        for (size_t j = 0; j < n; ++j) {
            sum_exp += std::exp(x_data[row_start + j] -
                                max_val); // Subtract max_val for stability
        }

        // Now, we have max_val, and sum_exp, which will be used to compute any
        // s[i] on the fly.

        // Iterate the elements in one row of x_grad.
        for (size_t j = 0; j < n; ++j) {
            float inner_product = 0.0f;
            // Compute s[j]
            float softmax_j =
                std::exp(x_data[row_start + j] - max_val) / sum_exp;
            // Iterate the elements in one row of Jacobian and output_grad.
            for (size_t k = 0; k < n; ++k) {
                // Compute s[k]
                float softmax_k =
                    std::exp(x_data[row_start + k] - max_val) / sum_exp;
                // Compute J[j][k]
                float jacobian_jk =
                    softmax_j * ((j == k ? 1.0f : 0.0f) - softmax_k);
                // Add J[j][k] * output_grad_row[k] to the inner product.
                inner_product += jacobian_jk * output_grad_data[row_start + k];
            }
            // Assign the inner product to x_grad[i][j].
            x_grad_data[row_start + j] = inner_product;
        }
    }
}

Please refer to ops.h and ops.cpp for the ops implementations.

Automatic differentiation

Automatic differentiation, also known as autodiff or autograd, is a mechanism to compute the gradients of the loss w.r.t. the weights of the neural network automatically. The user only defines the forward pass. The framework records the ops performed on each tensor and computes the gradients. This is a core feature of an ML framework.

Before we dive into the implementation, there are a few things we need to clarify.

Compute graph

The following figure shows the compute graphs of the forward pass and the backward pass of a super simple neural network, whose inputs are a, b, d, and the output is e.

In the graph, each node is a tensor, and each edge, or pair of edges pointing to the same output tensor, is an op.

To put the computation shown in the figure into code, it would look like the following.

def forward(a, b, d):
    c = ops.add(a, b)
    e = ops.multiply(c, d)
    return e

# Run if we call e.backward().
def backward(a, b, c, d, e):
    c.grad, d.grad = ops.multiply_backward(output_grad=e.grad, inputs=(c, d))
    a.grad, b.grad = ops.add_backward(output_grad=c.grad, inputs=(a, b))

In the forward pass, we compute the output e from the inputs. In the backward pass, we compute the gradients of e w.r.t. each tensor (i.e. a.grad, b.grad, c.grad, d.grad).

The user defines the forward pass of the neural network. We need to automatically figure out the backward pass when the user calls e.backward(), or, in most cases, loss.backward().

Tracing

Tracing refers to the process of extracting the compute graph from the user-defined forward pass. It is called tracing because it is the process of tracing the paths of tensors flowing in the compute graph.

There are many different ways to do this, but we implemented the simplest one. We just record the op history on every tensor, namely OpRecord. We use this information to reconstruct the compute graph from the very last output tensor, which, in most cases, is the loss.

What information do we store in an OpRecord? In the example above, we just record c and d in e.op_record, and record a and b in c.op_record. So, we record the input tensors of the op in the OpRecord of the output tensor. This information should be good enough for us to reconstruct the entire compute graph by starting a breadth-first search (BFS) from the last output tensor.

Backward propagation

During backward propagation, we want to compute the gradients w.r.t. all the tensors. Each time, we call the backward function of an op with the output gradients and the input tensors of the op. So, we need one more piece of information in the OpRecord: the backward function of the op to call.

Following is our implementation of the OpRecord.

class OpRecord:
    def __init__(self, func_backward, input_tensors, output_tensor):
        self.func_backward = func_backward
        self.input_tensors = input_tensors
        self.output_tensor = output_tensor

So, it only has 3 attributes:

- func_backward: the backward function of the op.
- input_tensors: the input tensors to the op.
- output_tensor: the output tensor produced by the op.

Here are the steps to do it when loss.backward() is called:

1. Start from the final output tensor (the loss) and traverse the op records to reconstruct the compute graph.
2. Topologically sort the tensors so that an op's backward function runs only after the gradient w.r.t. its output is ready.
3. Call each op's backward function in that order, passing the output gradient and the input tensors, and store the results in the .grad attributes of the input tensors.
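Putting these steps together, a minimal sketch of the idea could look like the following. It is simplified for illustration: it uses a depth-first topological ordering for brevity, seeds the loss gradient with a plain float, and assumes every tensor starts with grad set to None; the repo's actual code differs in these details.

def backward(loss):
    # Seed the chain rule: d(loss)/d(loss) = 1.
    loss.grad = 1.0  # a full implementation seeds with a tensor of ones

    # 1. Order every tensor reachable from the loss via its op records.
    order, visited = [], set()

    def visit(tensor):
        if id(tensor) in visited or tensor.op_record is None:
            return
        visited.add(id(tensor))
        for parent in tensor.op_record.input_tensors:
            visit(parent)
        order.append(tensor)

    visit(loss)

    # 2. Walk from the loss back toward the inputs and call each op's backward
    #    function, accumulating the gradients on its input tensors.
    for tensor in reversed(order):
        record = tensor.op_record
        grads = record.func_backward(
            output_grad=tensor.grad, inputs=record.input_tensors
        )
        for parent, grad in zip(record.input_tensors, grads):
            parent.grad = grad if parent.grad is None else parent.grad + grad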

Please refer to op_record.py, forward.py, backward.py, autograd.py for more details, for example, why we need topological sorting.

End-to-end workflow

So far, we have reviewed all the major components of the framework. Let's connect the dots by describing an end-to-end workflow.

Let's look at the code example at the beginning of this article.

import numpy as np

import framework

x = framework.Tensor.from_numpy(np.array([[2.0, 3.0]], dtype=np.float32))
y = framework.Tensor.from_numpy(np.array([[4.0], [5.0]], dtype=np.float32))
z = framework.ops.matmul(x, y)  # Expected: [[2*4 + 3*5]] = [[23.0]]
loss = framework.ops.sum(z)

loss.backward()  # Compute the gradients of loss w.r.t. tensors x, y, and z.

print("x.grad:", x.grad.numpy())  # Expected: [[4.0, 5.0]]
print("y.grad:", y.grad.numpy())  # Expected: [[2.0], [3.0]]

First, we construct two tensors, x and y, from NumPy arrays. framework.Tensor is a Python class; it uses the C++ implementation of the tensor, framework.core.Tensor, as an attribute. The NumPy array is passed to the constructor of the C++ class. The argument type conversion from Python types to C++ types is handled by Pybind11.

Then, we call two ops, matmul and sum, on them to get a single-value tensor, the loss. This defines the forward pass of the model. framework.ops.matmul and framework.ops.sum are just Python wrappers around the C++ implementations in framework.core.ops. Besides calling the C++ code, the matmul wrapper also stores x, y, and matmul_backward() in z.op_record, and the sum wrapper does something similar for loss.op_record.
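As a rough sketch (not the repo's exact code; the core_tensor attribute name and the Tensor(...) constructor form are hypothetical), such a wrapper could look like this:

def matmul(x, y):
    # Run the fast C++ kernel through the Pybind11 binding.
    output = Tensor(framework.core.ops.matmul(x.core_tensor, y.core_tensor))
    # Record how the output was produced so that autograd can trace it later.
    output.op_record = OpRecord(
        func_backward=matmul_backward,
        input_tensors=(x, y),
        output_tensor=output,
    )
    return output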

Finally, when loss.backward() is called, the autograd module traces and builds the compute graph using the op records, and computes the gradients by calling the backward functions of the ops.

So far, we have gone through all the features we have implemented in our ML framework. The main challenges we faced were the Python C++ interfacing and ops implementations.

If you followed everything we introduced, you can explore the heavily commented codebase and understand all the code by yourself. You will then feel confident enough to implement a framework of your own. I encourage you to do so to better understand all the details.

Other features

In the rest of the article, we will explore the important features that are not included in our framework. We will explain what they are, why they are important, and how they can be added to our framework.

GPU backend

A GPU backend of an ML framework allows you to run your models on GPUs for faster speed. The implementation requires adding a new C++ extension module to the Python library. This new module should implement a new Tensor class and all the ops using CUDA, an extension of C++ for programming GPUs specifically. The Python code of the framework should dynamically route to different backends based on the user configuration.

It is not hard to implement the GPU backend, but extremely challenging if you want it to run fast. There are a lot of tricks in CUDA kernel implementation. A CUDA kernel is a function that is executed on a GPU making use of its parallel processing capabilities.

More precisions

In our framework, we only support float32 as the data type for all the tensors. However, people may want to use a different precision like bfloat16, int8, and so on. This allows the user to control the compute precision on a finer-grained level to better address their speed and power requirements.

The implementation of this feature is mainly about abstraction. To reduce code duplication, we need to make code for the ops general enough for more data types and also leave space for custom optimization for certain data types.

More tensor types

In many applications, people may find sparse tensors and ragged tensors useful. For example, sparse tensors are widely used in the embedding lookup tables of recommendation systems. Ragged tensors allow a dimension of a tensor to have varying lengths, for example, the lengths of the input sentences to a natural language processing model.

To implement these, we need to make the Tensor class of the framework extensible and make the ops support more tensor types. We may even need to introduce support for new hardware features to better process sparse tensors, for example, making use of the SparseCores in Google's TPUs.

Data pipelining

If we are not careful, the data pipeline can become the performance bottleneck of the overall ML workflow. The data pipelining module needs to load the data in advance so it is ready to feed the model.

The implementation is rather orthogonal to what we have introduced above. It should prefetch the data onto the compute device, for example, into the HBM of the GPUs, before the model actually requests that data for computation.
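As an illustration of the idea (a generic Python sketch, not tied to the project's code), a background thread can keep a small buffer of batches ready ahead of the training loop:

import queue
import threading

def prefetch(batches, buffer_size=2):
    # Load up to buffer_size batches ahead of the consumer on a background
    # thread, so the accelerator is not left waiting for input data. A real
    # pipeline would also copy each batch to device memory (e.g., GPU HBM).
    buffer = queue.Queue(maxsize=buffer_size)
    sentinel = object()

    def producer():
        for batch in batches:
            buffer.put(batch)  # blocks when the buffer is full
        buffer.put(sentinel)

    threading.Thread(target=producer, daemon=True).start()
    while (batch := buffer.get()) is not sentinel:
        yield batch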

ML compiler

The ML compiler is a big topic. We will dig deeper into ML compilers in a future blog post. An ML compiler is different from a traditional programming language compiler: it does not compile a programming language, but instead optimizes the compute graph and lowers it down to execute on a GPU/TPU.

Here is an example of an optimization that an ML compiler can do. We use softmax to convert the logits $z$ into probabilities $p$ and use categorical cross-entropy as the loss function to get the loss value $L$, shown as follows.

The softmax function:

$$p = \mathrm{softmax}(z), \qquad p_i = \mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}$$

The categorical cross-entropy loss:

$$L = -\sum_{i=1}^{K} y_i \log(p_i)$$

where $y = (y_1, y_2, \ldots, y_K)$ is the one-hot encoded label.

It would be time-consuming to call the backward functions of all the ops involved (exp, log, divide, multiply, sum) one by one to get the gradient of the loss $L$ w.r.t. the logits $z$, denoted as $\partial L / \partial z$.

However, if you derive it by hand, you will find that $\partial L / \partial z$ can be simplified to $p - y$.

$$\frac{\partial L}{\partial z} = p - y$$

The ML compiler should have a built-in rule to capture this pattern, a long chain of backward functions in the compute graph, and optimize it to $p - y$.
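A quick NumPy check (independent of our framework) confirms the simplified formula by comparing it against a numerical gradient:

import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / np.sum(e)

z = np.array([1.0, 2.0, 0.5])
y = np.array([0.0, 1.0, 0.0])  # one-hot label
loss = lambda logits: -np.sum(y * np.log(softmax(logits)))

# Central-difference numerical gradient of the loss w.r.t. the logits.
eps = 1e-5
numerical = np.array([
    (loss(z + eps * np.eye(3)[i]) - loss(z - eps * np.eye(3)[i])) / (2 * eps)
    for i in range(3)
])

print(numerical)        # matches softmax(z) - y up to numerical error
print(softmax(z) - y)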

Regarding the implementation of an ML compiler, it takes tremendous effort to build a good one; it is far beyond what a solo project can do. It should capture the compute graph through tracing and serialize it into an intermediate representation (IR). The compiler takes the IR as input and optimizes it into another IR. After multiple levels of such optimization, the IR is lowered down to the hardware to execute the computation.

Distributed training

Finally, here is the killer feature of an ML framework. It is almost impossible to do distributed training without an ML framework. It distributes the compute and memory usage across multiple machines and multiple GPUs. It allows the users to have an experience similar to running the model on a single machine.

There are many different distribution techniques in machine learning. We will also have a future post dedicated to this topic.

Conclusions

So far, we have gone through the basics of building an ML framework from scratch including Python and C++ interfacing, tensor data structures, tensor operations, and automatic differentiation.

As you can see, it is not difficult to build a simple working ML framework. The challenge is to make it efficient enough for all the different scenarios. It may include running on different hardware, on different numbers of machines, with different data types and tensor types. To support these scenarios, it may require a good ML compiler, distribution features, optimized op implementations, and so on.

After reading this article, I hope you are more confident about your knowledge of ML frameworks.