Python is a dynamic language with a large and varied standard library. You can
even use Python to manipulate Python programs as data! That’s called
metaprogramming, and it’s
what we’re going to do using Python’s
ast module. By the end of this post,
we’ll have implemented constant
folding for Python. Constant
folding, stated simply, is the evaluation of constant expressions by a compiler
or interpreter before your program actually runs. (I think CPython already does
some constant folding, but that won’t stop us.)
By the way, I’ll be assuming you know some Python, but the concepts here should apply elsewhere. I’ll also assume you’re using CPython 3.7. If you don’t know what Python implementation you’re using, you’re probably using CPython.
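As a quick check of the aside about CPython already doing some folding, here is a small sketch that compiles an assignment and looks at the constants stored in the resulting code object. The `co_consts` attribute is a CPython implementation detail and the source string is only an illustration, so treat the result as a hint rather than a guarantee.

```python
# Compile a tiny assignment and inspect the constants CPython stored for it.
# If the compiler folded `2 + 2`, the value 4 appears among the constants.
code = compile("x = 2 + 2", "<example>", "exec")
folded = 4 in code.co_consts

print(code.co_consts)
print("folded at compile time:", folded)
```

If `folded` is true, the sum was computed before the program ran, which is exactly the behavior we will reimplement by hand.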
AST? What’s that?
Wait! What's a tree? A tree is a recursive data structure that is used to express hierarchical or nested information. For example, you could use a tree to represent your family tree, a decision tree, regions in your country at various levels from city to country, or an organizational chart. A tree is composed of nodes, which are cells that can hold data. The nodes are connected by edges, which indicate parent-child relationships.
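To make the idea of nodes and parent-child edges concrete, here is a minimal sketch of a tree in Python. The `TreeNode` class, the `depth_first_labels` function, and the region names are all hypothetical, made up for this illustration.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class TreeNode:
    """Hypothetical tree node: a label plus zero or more child nodes."""
    label: str
    children: List["TreeNode"] = field(default_factory=list)


def depth_first_labels(node):
    """Return the labels of `node` and its descendants, parents first."""
    labels = [node.label]
    for child in node.children:
        labels.extend(depth_first_labels(child))
    return labels


country = TreeNode("country", [
    TreeNode("region A", [TreeNode("city 1"), TreeNode("city 2")]),
    TreeNode("region B", [TreeNode("city 3")]),
])
print(depth_first_labels(country))
```

The function calls itself on each child, which is the same recursive shape we will use later when walking a syntax tree.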
Here’s an example of an AST:
This AST can be textually represented as
Module(body=[Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Str(s='hello world')], keywords=[]))]). It represents the following Python code:
print('hello world')
Let’s go through the structure of this AST. The Module within the AST represents the entire module (which is most often a file). It contains an expression statement (Expr), which contains a function call (Call). The function being called is print, and it takes a Str node as an argument.
How the Python ast module works
The ast module allows us to parse Python code at runtime and get
back ASTs. It also provides functions to help us process these ASTs.
Let’s get an AST for the code print('hello world'):
import ast
syntax_tree = ast.parse("print('hello world')")
ast.parse parses code as if it’s in its own module. The function
can parse in other modes, but we won’t cover those here.
We can get a view of what’s in the AST by calling ast.dump on it:
print(ast.dump(syntax_tree))
This will print:
Module(body=[Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Str(s='hello world')], keywords=[]))])
This is the same AST we saw earlier.
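Besides dumping the whole tree, we can poke at it directly. The sketch below collects the node types with `ast.walk` and then drills down by attribute access, following the Module, Expr, Call structure described earlier. One assumption to flag: on Python 3.8 and later the string literal shows up as a `Constant` node rather than the `Str` node that CPython 3.7 produces, so the code avoids depending on that name.

```python
import ast

tree = ast.parse("print('hello world')")

# ast.walk yields the root and every descendant in no specified order,
# so we only collect the set of node type names.
node_types = {type(node).__name__ for node in ast.walk(tree)}
print(sorted(node_types))

# Drill down by hand: Module -> Expr -> Call -> function name / first argument.
call = tree.body[0].value
print(type(call).__name__, call.func.id, type(call.args[0]).__name__)
```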
Now, let’s try executing the code represented by this AST. Python makes this easy for us; we can first compile the AST into a bytecode object using compile, and then execute that object using exec:
exec(compile(syntax_tree, '<string>', 'exec'))
When we call exec, “hello world” should be printed out.
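If you would rather check the result in code than read it off the terminal, one option is to capture standard output while the compiled object runs. This is only a sketch; `contextlib.redirect_stdout` and the variable names `code_object` and `captured` are choices made for this illustration.

```python
import ast
import contextlib
import io

syntax_tree = ast.parse("print('hello world')")
code_object = compile(syntax_tree, "<string>", "exec")

# Redirect stdout so the printed text can be inspected as a string.
buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    exec(code_object)

captured = buffer.getvalue()
print(repr(captured))
```

Keeping the compile step separate also makes it clear that `compile` accepts an AST as well as source text.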
Example: constant folding
We’re going to figure out how to do some basic constant expression evaluation. We’ll handle the following cases:
- Adding numeric and string literals (for example, 2 + 2 and 'hello ' + 'dave')
- Multiplying numeric literals and list/tuple/strings containing only literals (for example, 2.0 * 3.0, 'abc' * 8, and [1, 3, 5] * 0)
There are many cases we could handle, but we have to start somewhere. I suggest you create a new file for the program we’re about to write.
To achieve our goal, we need to recursively run a constant folding algorithm through the nodes of the AST. In other words, we need to traverse the AST. We can use the ast.NodeTransformer class to change an AST in place while traversing it recursively from the top down. To use NodeTransformer, we’ll subclass it:
class ConstantFolder(ast.NodeTransformer):
Next, we’ll override methods like visit_Str so that we can perform different actions depending on the type of node we encounter. For example, the visit_Str method is called when the node transformer sees a Str node. The transformer uses the return value of each visit_* method to replace the nodes in the AST, modifying the structure in place.
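Before writing the folder itself, here is a tiny, unrelated transformer that shows the visit_* mechanism in isolation. The `RenameVariable` class and the variable names are hypothetical; the only point is that whatever `visit_Name` returns takes the place of the node it was given.

```python
import ast


class RenameVariable(ast.NodeTransformer):
    """Hypothetical transformer: rename every use of one variable."""

    def __init__(self, old, new):
        self.old = old
        self.new = new

    def visit_Name(self, node):
        # Called once for each Name node; the return value replaces the node.
        if node.id == self.old:
            replacement = ast.Name(id=self.new, ctx=node.ctx)
            return ast.copy_location(replacement, node)
        return node


tree = ast.parse("total = price + price")
tree = RenameVariable("price", "cost").visit(tree)
names = [n.id for n in ast.walk(tree) if isinstance(n, ast.Name)]
print(sorted(names))
```

Node types without a matching visit_* method fall through to `generic_visit`, which simply recurses into the children.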
Let’s write our transformer code, organized by the kinds of expressions we want to evaluate.
Adding numeric and string literals
The type of AST node that represents addition is BinOp where the node’s op field is an Add instance. Thus, we’ll override visit_BinOp and check that the operation is addition.
Since constant folding is recursive, we should first call visit on the left and right fields of the current node. (This implies a depth-first traversal of the AST.)
def visit_BinOp(self, node):
    left = self.visit(node.left)
    right = self.visit(node.right)
    # addition case
    if isinstance(node.op, ast.Add):
If both left and right are instances of ast.Num, we return a new Num node that contains the sum of the two numbers.
if isinstance(left, ast.Num) and isinstance(right, ast.Num):
    return ast.Num(left.n + right.n)
If both left and right are instances of ast.Str, we return a new Str node that contains the concatenation of the two strings.
if isinstance(left, ast.Str) and isinstance(right, ast.Str):
    return ast.Str(left.s + right.s)
Otherwise, we should return another
BinOp node with an
Add as the operator.
return ast.BinOp(left, ast.Add(), right)
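The code in this post targets CPython 3.7. On Python 3.8 and later, number and string literals both parse as `ast.Constant`, and the `Num` and `Str` classes are deprecated. As a rough sketch of how the addition case might look there, the hypothetical `AddFolder` below checks the type of each constant's value instead of the node class. It is an adaptation under that assumption, not the transformer we are building in this post.

```python
import ast


class AddFolder(ast.NodeTransformer):
    """Hypothetical addition-only folder for Python 3.8+ (ast.Constant)."""

    def visit_BinOp(self, node):
        # Fold the operands first so nested sums collapse from the inside out.
        left = self.visit(node.left)
        right = self.visit(node.right)
        if (
            isinstance(node.op, ast.Add)
            and isinstance(left, ast.Constant)
            and isinstance(right, ast.Constant)
        ):
            a, b = left.value, right.value
            # bool is a subclass of int, so exclude it explicitly.
            both_numbers = all(
                isinstance(v, (int, float, complex)) and not isinstance(v, bool)
                for v in (a, b)
            )
            both_strings = isinstance(a, str) and isinstance(b, str)
            if both_numbers or both_strings:
                return ast.copy_location(ast.Constant(a + b), node)
        # Not foldable: rebuild the node around the (possibly folded) operands.
        return ast.copy_location(ast.BinOp(left, node.op, right), node)


tree = AddFolder().visit(ast.parse("'hello ' + 'dave' + name"))
folded_left = tree.body[0].value.left
print(type(folded_left).__name__, repr(folded_left.value))
```

Here the two string literals are merged, while the addition involving the variable `name` is left in place.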
Multiplying numeric literals and list/tuple/strings containing only literals
The type of AST node that represents multiplication is
BinOp where the node’s
op field is a
Mult instance. So let’s add another case to our if statement
to handle multiplication.
elif isinstance(node.op, ast.Mult):
First, let’s give a name to the AST node types that describe lists and tuples.
We can call them
sequence_node_types. This value will come in handy later.
sequence_node_types = (ast.List, ast.Tuple)
Now, we’ll handle multiplying numeric literals. This is very similar to what we did for adding numbers.
if isinstance(left, ast.Num) and isinstance(right, ast.Num):
    return ast.Num(left.n * right.n)
Next, we’ll handle multiplying a string by a number. We will check that the number is a non-negative integer. We also need to allow the string and integer to appear in any order.
elif (isinstance(left, ast.Num) and isinstance(right, ast.Str)
        and isinstance(left.n, int) and left.n >= 0):
    return ast.Str(left.n * right.s)
elif (isinstance(left, ast.Str) and isinstance(right, ast.Num)
        and isinstance(right.n, int) and right.n >= 0):
    return ast.Str(left.s * right.n)
The third kind of multiplication we want to handle involves lists and tuples. We
can start by checking that one of the operands is a list or tuple display, that
the list or tuple contains only literals, and that the other operand is a
non-negative integer literal. To check that a list or tuple node contains only
literal nodes, we use a
contains_only_literals function that we define later.
elif (isinstance(left, ast.Num) and isinstance(right, sequence_node_types)
        and contains_only_literals(right)
        and isinstance(left.n, int) and left.n >= 0):
    return type(right)(left.n * right.elts, ast.Load())
elif (isinstance(left, sequence_node_types) and contains_only_literals(left)
        and isinstance(right, ast.Num)
        and isinstance(right.n, int) and right.n >= 0):
    return type(left)(left.elts * right.n, ast.Load())
In other cases, we can return a BinOp node that represents the multiplication. Note that the operator must be a Mult instance, not the class itself:
else:
    return ast.BinOp(left, ast.Mult(), right)
To account for other BinOp nodes that didn’t fit into any of the cases above, we have an else block that returns a BinOp node with the same operation type as the original node and left and right as operands.
else:
    return ast.BinOp(left, node.op, right)
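The list and tuple case works by repeating the element nodes of the display. To see that the resulting node really evaluates to the repeated sequence, here is a small sketch that builds such a node by hand and evaluates it. Wrapping the node in `ast.Expression`, calling `ast.fix_missing_locations`, and using `eval` are choices made for this illustration and are not part of the transformer.

```python
import ast

# Parse a list display, then build the node that folding `[1, 3, 5] * 2`
# would produce by repeating the element nodes.
original = ast.parse("[1, 3, 5]").body[0].value
repeated = ast.List(elts=original.elts * 2, ctx=ast.Load())

# Freshly built nodes carry no line/column info; fill it in before compiling.
expression = ast.fix_missing_locations(ast.Expression(body=repeated))
result = eval(compile(expression, "<folded>", "eval"))
print(result)  # [1, 3, 5, 1, 3, 5]
```

The repeated list contains the same element node objects several times over, which is fine for compiling because the compiler only reads them.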
Loose end: the contains_only_literals function
Earlier, we used a helper function called contains_only_literals to tell if a List or Tuple node contained only AST nodes for literals in its elements. Now, we define it:
def contains_only_literals(sequence_ast):
We retrieve the elements of the sequence_ast node through its elts field.
elements = sequence_ast.elts
And, let’s say that a ‘literal’ is either a Num, Str, or Bytes node.
literal_types = (ast.Num, ast.Str, ast.Bytes)
Finally, we check if each node in
elements is an instance of one of the
literal_types, and return
True if that’s the case.
return all(isinstance(element, literal_types) for element in elements)
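For readers on Python 3.8 or later, where `Num`, `Str`, and `Bytes` are all replaced by `ast.Constant`, a comparable helper could look at the type of each constant's value instead. This is a sketch under that assumption; `first_expression` is a hypothetical convenience function used only to build test inputs.

```python
import ast


def contains_only_literals(sequence_ast):
    """Return True if every element of a List/Tuple node is a plain literal.

    Assumes Python 3.8+, where number, string and bytes literals all parse
    as ast.Constant; the type of the value tells them apart.
    """
    literal_value_types = (int, float, complex, str, bytes)
    return all(
        isinstance(element, ast.Constant)
        and isinstance(element.value, literal_value_types)
        and not isinstance(element.value, bool)  # True/False were not Num in 3.7
        for element in sequence_ast.elts
    )


def first_expression(source):
    # Hypothetical helper: the expression of the first statement in `source`.
    return ast.parse(source).body[0].value


print(contains_only_literals(first_expression("[1, 'two', b'three']")))
print(contains_only_literals(first_expression("(1, x, 3)")))
print(contains_only_literals(first_expression("[1, [2, 3]]")))
```

Nested lists are rejected here, matching the 3.7 version, which only accepts elements that are themselves literal nodes.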
Let’s try to constant-fold the expression
'success' * (1 + 1 + (1 * 2) + 1).
We can do this by first parsing this expression to get an AST.
syntax_tree = ast.parse("'success' * (1 + 1 + (1 * 2) + 1)")
Next, we can use the visit method (which is inherited from NodeTransformer) of our ConstantFolder class to get the folded AST. Note that a NodeTransformer (the class ConstantFolder inherits from) modifies the AST it’s given, meaning we don’t need the return value of visit. However, the Python docs say that using the return value is the usual way of using a NodeTransformer:
syntax_tree = ConstantFolder().visit(syntax_tree)
Now, we can dump the value of
syntax_tree and see that the constant
expression has been evaluated!
print('after folding:')
print(ast.dump(syntax_tree))
The output should look like this:
after folding:
Module(body=[Expr(value=Str(s='successsuccesssuccesssuccesssuccess'))])
[Figure: graphical depiction of this AST]
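Dumping the tree is enough to see the folding, but if you also want to run a transformed tree, keep in mind that nodes created during a transformation have no line or column information, and `compile` expects it. A minimal sketch of one way to handle that, using `ast.fix_missing_locations`, follows. The hand-made replacement stands in for a folding pass, and `ast.Constant` assumes Python 3.8 or later (on 3.7 the node would be `ast.Num`).

```python
import ast

tree = ast.parse("answer = 1 + 1")

# Stand-in for a folding pass: swap the `1 + 1` expression for a fresh node.
tree.body[0].value = ast.Constant(2)

# The fresh node has no lineno/col_offset yet; copy them in from its parent.
ast.fix_missing_locations(tree)

namespace = {}
exec(compile(tree, "<folded>", "exec"), namespace)
print(namespace["answer"])
```

Another option is to call `ast.copy_location(new_node, old_node)` each time a node is replaced, which keeps more precise positions.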
We’ve successfully used Python abstract syntax trees to implement constant folding! We went over concepts like what abstract syntax trees are, how to traverse and modify them, and some of the node types in Python ASTs.
I encourage going deeper into this topic. Here are some things you can try:
- Constant-folding the subtraction and division of numeric literals (for example, 6 - 6, 1 - 2j, and 0.0 - -1)
- Generating interesting visualizations of Python code from ASTs
- Parsing Python code to compute statistics about the code
  - Cyclomatic complexity and number of lines per function, for example
- Writing a simple type checker for Python, or exploring other properties of programs you can compute
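As a possible starting point for the statistics idea, here is a sketch that counts the number of source lines each function spans. The `lines_per_function` name and the sample source are made up for this example, and the approach relies on the `end_lineno` attribute, which is only filled in on Python 3.8 and later.

```python
import ast

SOURCE = '''
def short():
    return 1

def longer(x):
    if x:
        return x
    return 0
'''


def lines_per_function(source):
    """Map each function name to the number of source lines it spans."""
    tree = ast.parse(source)
    return {
        node.name: node.end_lineno - node.lineno + 1
        for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    }


stats = lines_per_function(SOURCE)
print(stats)
```

Because this only reads the tree, `ast.walk` is enough and no NodeTransformer is needed.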