Design an Easy-to-Use Deep Learning Framework
The three software design principles I learned as an open-source contributor
April 2024
Deep learning frameworks are extremely transitory. If you compare the deep learning frameworks people use today with those of eight years ago, you will find the landscape is completely different. Theano, Caffe2, and MXNet were popular then, and all have since gone obsolete. Today's most popular frameworks, like TensorFlow and PyTorch, had only just been released to the public.
Through all these years, Keras has survived as a high-level, user-facing library supporting multiple backends, including TensorFlow, PyTorch, and JAX. As a contributor to Keras, I learned how much the team cares about the user experience of the software, and how they ensure a good user experience by following a few simple yet powerful principles in their design process.
In this article, I will share the three most important software design principles I learned by contributing to Keras over the past few years. They may generalize to all types of software and help you make an impact in the open-source community with your own projects.
Why user experience is important for open-source software
Before we dive into the main content, let’s quickly discuss why user experience is so important. We can learn this through the PyTorch vs. TensorFlow case.
They were developed by two tech giants, Meta and Google, which have quite different cultural strengths: Meta is good at product, while Google is good at engineering. As a result, Google's frameworks like TensorFlow and JAX are fast and technically superior to PyTorch, with good support for sparse tensors and distributed training. Yet PyTorch still took away half of the market share from TensorFlow because it prioritizes user experience over other aspects of the software.
A better user experience wins over the research scientists who build the models, and the models then propagate to the engineers, who do not always want to convert what they receive from the research scientists into another framework. The engineers build new software around PyTorch to smooth their workflow, which establishes a software ecosystem around PyTorch.
TensorFlow also made a few blunders that cost it users. Its general user experience is good. However, its installation guide for GPU support was broken for years before being fixed in 2022, and TensorFlow 2 broke backward compatibility, which cost its users millions of dollars to migrate.
So, the lesson here is that, despite technical superiority, user experience decides which software open-source users will choose.
All deep learning frameworks invest heavily in user experience
All the deep learning frameworks—TensorFlow, PyTorch, and JAX—invest heavily in user experience. Good evidence is that they all have a relatively high Python percentage in their codebases.
The core logic of deep learning frameworks, including tensor operations, automatic differentiation, compilation, and distribution, is implemented in C++. So why expose a set of Python APIs to the users? Simply because users love Python, and the frameworks want to polish the user experience.
Investing in user experience is of high ROI
Imagine how much engineering effort it requires to make your deep learning framework a little bit faster than others. A lot.
However, a better user experience can be achieved simply by following a certain design process and a few principles. For attracting more users, your user experience is as important as the computing efficiency of your framework. So, investing in user experience has a high return on investment (ROI).
The three principles
I will share the three important software design principles I learned by contributing to Keras, each with good and bad code examples from different frameworks.
Principle 1: Design end-to-end workflows
When designing the APIs of a piece of software, you may start with something like this.
class Model:
    def __call__(self, input):
        """The forward call of the model.

        Args:
            input: A tensor. The input to the model.
        """
        pass
Define the class and add the documentation. Now, we know all the class names, method names, and arguments. However, this would not help us understand much about the user experience.
What we should do is something like this.
input = keras.Input(shape=(10,))
x = layers.Dense(32, activation='relu')(input)
output = layers.Dense(10, activation='softmax')(x)
model = keras.models.Model(inputs=input, outputs=output)
model.compile(
    optimizer='adam', loss='categorical_crossentropy'
)
We want to write out the entire user workflow of using the software. Ideally, it should be a tutorial on how to use the software. It provides much more information about the user experience. It may help us spot many more UX problems during the design phase compared with just writing out the class and methods.
Let’s look at another example. This is how I discovered a user experience problem by following this principle when implementing KerasTuner.
When using KerasTuner, users can use the RandomSearch class to select the best model. It accepts metrics and an objective among its arguments. By default, objective is the validation loss, so it helps us find the model with the smallest validation loss.
class RandomSearch:
    def __init__(self, ..., metrics, objective="val_loss", ...):
        """The initializer.

        Args:
            metrics: A list of Keras metrics.
            objective: String or a custom metric function. The
                name of the metric we want to minimize.
        """
        pass
Again, it doesn't provide much information about the user experience, and everything looks OK so far.
However, if we write an end-to-end workflow like the following, it exposes many more problems. The user is trying to define a custom metric function named custom_metric. The objective is no longer straightforward to use: what should we pass to the objective argument now?
tuner = RandomSearch(
    ...,
    metrics=[custom_metric],
    objective="val_???",
)
It should be "val_custom_metric": just the prefix "val_" plus the name of the metric function. This is not intuitive, and we want to make it better instead of forcing the user to learn it. We easily spotted a user experience problem by writing out this workflow.
If you write the design even more comprehensively by including the implementation of the custom_metric function, you will find you also need to learn how to write a Keras custom metric: you have to follow the expected function signature to make it work, as shown in the following code snippet.
from keras import ops

def custom_metric(y_true, y_pred):
    squared_diff = ops.square(y_true - y_pred)
    return ops.mean(squared_diff, axis=-1)
After discovering this problem, we designed a better workflow for custom metrics. You only need to override HyperModel.fit() to compute your custom metric and return it. No strings to name the objective, no function signature to follow, just a return value. The user experience is much better now.
class MyHyperModel(HyperModel):
    def fit(self, trial, model, validation_data):
        x_val, y_true = validation_data
        y_pred = model(x_val)
        return custom_metric(y_true, y_pred)

tuner = RandomSearch(MyHyperModel(), max_trials=20)
One more thing to remember: always start from the user experience. The designed workflows backpropagate to the implementation.
Principle 2: Minimize cognitive load
Do not force the user to learn anything unless it is really necessary. Let’s see some good examples.
The Keras modeling API, shown in the following code snippet, is a good example. Model builders already have these concepts in mind: a model is a stack of layers, it needs a loss function, and we can fit it with data or make it predict on data.
model = keras.Sequential([
    layers.Dense(10, activation="relu"),
    layers.Dense(num_classes, activation="softmax"),
])
model.compile(loss='categorical_crossentropy')
model.fit(...)
model.predict(...)
So, basically, no new concepts need to be learned to use Keras.
Another good example is PyTorch modeling. The code executes just like ordinary Python, and all tensors are real tensors with real values, so you can branch on the value of a tensor with plain Python code.
class MyModel(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return self.path_a(x)
        return self.path_b(x)
You can do the same with Keras on the TensorFlow or JAX backend, but it needs to be written differently: all if conditions have to be expressed with the ops.cond function, as shown in the following code snippet.
class MyModel(keras.Model):
    def call(self, inputs):
        return ops.cond(
            ops.sum(inputs) > 0,
            lambda: self.path_a(inputs),
            lambda: self.path_b(inputs),
        )
This teaches the user a new op instead of letting them use the if-else clause they are familiar with, which is bad. In compensation, it brings a significant improvement in training speed.
Here is the catch with PyTorch's flexibility: if you ever need to optimize the memory and speed of your model, you have to do it yourself using the following APIs and learn new concepts to do so, including the inplace argument of the ops, the parallel foreach op APIs, and explicit device placement. This introduces a rather steep learning curve for the users.
torch.nn.functional.relu(x, inplace=True)  # in-place op via the inplace argument
x = torch._foreach_add(x, y)               # parallel foreach op
torch._foreach_add_(x, y)                  # parallel foreach op, in-place
x = x.cuda()                               # explicit device placement
Some other good examples are keras.ops, tf.experimental.numpy, and jax.numpy. They are simply reimplementations of the NumPy API. When you have to introduce some cognitive load, reuse what people already know. Every framework has to provide some low-level ops. Instead of asking people to learn a new set of APIs that may contain a hundred functions, they reuse the most popular existing API. NumPy is well documented and has tons of Stack Overflow questions and answers related to it.
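As a quick illustration of that reuse, the NumPy calls below use exactly the names and signatures mirrored by keras.ops and jax.numpy. This is a sketch using NumPy itself; the mirrored modules are drop-in lookalikes for these calls.

```python
import numpy as np

# These names and signatures are mirrored by keras.ops and jax.numpy,
# e.g. keras.ops.mean(x, axis=-1) and jax.numpy.mean(x, axis=-1).
x = np.array([[1.0, 3.0], [2.0, 6.0]])

row_means = np.mean(x, axis=-1)  # mean over the last axis
squared = np.square(x)           # element-wise square

print(row_means)      # [2. 4.]
print(squared.shape)  # (2, 2)
```

A user who already knows NumPy can guess all of these calls without opening the framework's documentation.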
The worst thing you can do with user experience is to trick the users into believing your API is something they are familiar with when it is not. I will give two examples, one from PyTorch and one from TensorFlow.
What should we pass as the pad argument of the F.pad() function if we want to pad an input tensor of shape (100, 3, 32, 32) to (100, 3, 1+32+1, 2+32+2), i.e., (100, 3, 34, 36)?
import torch
import torch.nn.functional as F

# Pad the 32x32 images to (1+32+1)x(2+32+2),
# i.e., (100, 3, 32, 32) to (100, 3, 34, 36).
out = F.pad(
    torch.empty(100, 3, 32, 32),
    pad=???,
)
My first intuition is that it should be ((0, 0), (0, 0), (1, 1), (2, 2)), where each sub-tuple corresponds to one of the four dimensions, and the two numbers are the padding sizes before and after the existing values. My guess originated from the NumPy API.

However, the correct answer is (2, 2, 1, 1): no sub-tuples, just one flat tuple. Moreover, the dimensions are reversed, with the last dimension coming first.
The following is a bad example from TensorFlow. Can you guess the output of the following code snippet?
value = True

@tf.function
def get_value():
    return value

value = False
print(get_value())
Without the tf.function decorator, the output would be False, which is pretty simple. However, with the decorator, the output is True. This is because TensorFlow compiles the function, and any Python variable captured during compilation becomes a constant. Changing the old variable's value afterward does not affect the created constant.
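For contrast, here is the same snippet as plain Python without the decorator. The closure looks value up at call time, so the latest assignment wins:

```python
value = True

def get_value():
    # Plain Python: `value` is resolved when the function is called,
    # not when it is defined.
    return value

value = False
print(get_value())  # False
```

This is exactly the behavior a Python user expects, which is what makes the compiled version so surprising.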
It tricks users into believing they are reading the plain Python code they are familiar with, but actually, they are not.
Principle 3: Interaction over documentation
No one likes to read long documentation if they can figure things out just by running some example code and tweaking it themselves. So, we should make the user workflows of the software follow a consistent logic that users can extrapolate from.
Here is a good example, shown in the following code snippet. In PyTorch, all methods whose names end with an underscore are in-place ops, while the ones without are not. From an interaction perspective, this is good: the convention is easy to follow, and users do not need to check the docs whenever they want the in-place version of a method. Of course, it introduces some cognitive load: users need to know what in-place means and when to use it.
x = x.add(y)  # returns a new tensor
x.add_(y)     # in-place
x = x.mul(y)  # returns a new tensor
x.mul_(y)     # in-place
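Python's own standard library uses a similar convention, which is part of why the PyTorch pattern feels learnable: sorted() returns a new list, while list.sort() works in place.

```python
nums = [3, 1, 2]

# sorted() returns a new list and leaves the original untouched.
new_nums = sorted(nums)
print(nums)      # [3, 1, 2]
print(new_nums)  # [1, 2, 3]

# list.sort() modifies the list in place and returns None.
nums.sort()
print(nums)      # [1, 2, 3]
```

Once users internalize one such convention, every new method that follows it comes for free.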
Another good example is the Keras layers. They strictly follow the same naming convention, as shown in the following code snippet. With a clear naming convention, users can easily infer the layer names without checking the documentation.
from keras import layers
layers.MaxPooling2D()
layers.GlobalMaxPooling1D()
layers.GlobalAveragePooling3D()
Another important part of the interaction between the user and the software is the error message. You cannot expect users to write everything correctly the first time. We should always do the necessary checks in the code and raise helpful error messages.
Let's look at the two examples in the following code snippet. The first carries little information; it just says the tensor shapes mismatch. The second contains much more useful information for finding the bug: it not only tells you the error is a tensor shape mismatch, but also shows the expected shape and the wrong shape it received. If you did not mean to pass that shape, you now have a much better idea of where the bug is.
# Bad example:
raise ValueError("Tensor shape mismatch.")

# Good example:
raise ValueError(
    "Tensor shape mismatch. "
    "Expected: (batch, num_features). "
    f"Received: {x.shape}"
)
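In practice, such a message usually lives inside a small validation helper. The check_input_shape function below is a hypothetical sketch using NumPy arrays, not an API from any of these frameworks:

```python
import numpy as np

def check_input_shape(x):
    # Hypothetical helper: accept only (batch, num_features) inputs
    # and raise an informative error otherwise.
    if x.ndim != 2:
        raise ValueError(
            "Tensor shape mismatch. "
            "Expected: (batch, num_features). "
            f"Received: {x.shape}"
        )

check_input_shape(np.zeros((32, 10)))  # a 2-D input passes silently

try:
    check_input_shape(np.zeros((32, 3, 10)))
except ValueError as e:
    print(e)  # names both the expected and the received shape
```

Running the check early, at the API boundary, means the user sees the informative message instead of a confusing failure deep inside the library.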
The best error message points the user directly to the fix. The following code snippet shows a built-in Python error message: it guesses what is wrong with the code and points the user directly to the fix.
import math
math.sqr(4)
AttributeError: module 'math' has no attribute 'sqr'.
Did you mean: 'sqrt'?
Final words
So far, we have covered the three most valuable software design principles I have learned while contributing to deep learning frameworks. First, write end-to-end workflows to discover more user experience problems. Second, reduce cognitive load and do not teach the user anything unless it is truly necessary. Third, follow consistent logic in your API design and raise meaningful error messages, so that users can learn your software by interacting with it instead of constantly checking the documentation.
However, there are many more principles to follow if you want to make your software even better. You can refer to the Keras API design guidelines as a complete API design guide.