Coding the AlphaZero search algorithm from scratch - Part 2: Implementing the MCTS algorithm

Now that we’ve seen in the previous article how the MCTS algorithm works, let us implement it for an actual game: chess! The reasons why I chose this game are:

  • I’m familiar with chess, and thus can understand what the frick the AI is doing;
  • AlphaZero is supposed to work well at chess, so we should see a clear improvement between the pure MCTS version and the AlphaZero-enhanced one.

This post is part of a series about AlphaZero. You can find the other posts here.

Implementing a generic MCTS algorithm

The game state

First of all, we have to represent what a game state is, and what we should be able to do with it. This includes:

  • Determining whether a game state is terminal, and returning the winner if it is;
  • Getting the list of all possible actions from this game state;
  • Computing the game state resulting from applying an action to the current game state.

Defining the actions

Right away, we see that we’ll also need to represent what an action is. Since we want to be as generic as possible, we require almost nothing from an action: only that it implements the __eq__ and __hash__ dunder methods, so that one can search for an Action in a list:

from abc import ABC, abstractmethod
from typing import Self


class Action(ABC):
    """Represent a generic possible action a player can do in a game."""

    @abstractmethod
    def __eq__(self: Self, other: Self) -> bool:
        """Test equality between two Actions."""
        pass

    @abstractmethod
    def __hash__(self: Self) -> int:
        """Return a unique hash value for this Action."""
        pass

Defining the game state

Now that we have our Action defined, let us move to the game state. This essentially boils down to listing everything our game state is supposed to be able to do.

A small remark here: we’ll assume that it’s possible to know which player is to play the position using internal data. For instance, for chess, you could keep count of how many moves were played: if this number is even, then it’s white’s turn, otherwise it’s black’s. A minimal sketch of this trick follows.
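
For illustration, here’s what that parity trick could look like (a purely hypothetical helper, not part of the interfaces below):

def side_to_move(n_moves_played: int) -> str:
    """Infer the side to move from the number of half-moves already played."""
    # White moves first, so an even count of half-moves means it's white's turn
    return "white" if n_moves_played % 2 == 0 else "black"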

from abc import ABC, abstractmethod
from typing import Optional, Self

from action import Action


class GameState(ABC):
    """Abstract class to represent a game state.

    It is assumed that the player playing the position can be identified using some internal data.
    For instance, the player whose turn it is could have their pieces represented by positive
    numbers, while the other one would have their pieces represented by negative numbers.
    """

    @abstractmethod
    def get_winner(self: Self) -> Optional[int]:
        """Return the winner of the game, if it exists.

        If the current game state isn't terminal, then None is returned. Otherwise, it returns:
        - 1 if the player playing this position won;
        - -1 if the other player won;
        - 0 in case of a draw.
        """
        pass

    @abstractmethod
    def get_possible_actions(self: Self) -> list[Action]:
        """Return the list of possible actions when playing this game state."""
        pass

    @abstractmethod
    def transition(self: Self, action: Action) -> None:
        """Transform the game state in place from a given action applied on the current state."""
        pass

Though they don’t seem very useful on their own, these two files will allow a user to simply define the inner workings of the game they’re interested in without worrying about the MCTS implementation.
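
To make this contract concrete, here is a minimal sketch of what implementing these interfaces could look like for a toy game (a simplified Nim, purely illustrative and not used later in this series):

from typing import Optional, Self

from action import Action
from game_state import GameState


class NimAction(Action):
    """Take `amount` sticks from the pile."""

    def __init__(self: Self, amount: int) -> None:
        self.amount = amount

    def __eq__(self: Self, other: Self) -> bool:
        return self.amount == other.amount

    def __hash__(self: Self) -> int:
        return hash(self.amount)


class NimState(GameState):
    """A pile of sticks; each player removes 1 to 3, and taking the last stick wins."""

    def __init__(self: Self, n_sticks: int) -> None:
        self.n_sticks = n_sticks

    def get_winner(self: Self) -> Optional[int]:
        # If the pile is empty, the previous player took the last stick and won,
        # which is a loss from the current player's perspective
        return -1 if self.n_sticks == 0 else None

    def get_possible_actions(self: Self) -> list[NimAction]:
        return [NimAction(k) for k in range(1, min(3, self.n_sticks) + 1)]

    def transition(self: Self, action: NimAction) -> None:
        self.n_sticks -= action.amount

Now, let us start coding the MCTS itself!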

Defining the MCTS nodes

The very first thing we want to do here is to represent a node in our tree. Using this node, we want to be able to:

  • Access and update the node’s statistics, namely its numbers of visits, wins and defeats;
  • Access its parent, in order to get its visit count;
  • Compute its value and UCB score;
  • Determine whether this node is a leaf;
  • Perform the four steps of the MCTS algorithm.

First of all, let’s initialize it:

from math import sqrt
from typing import Self, Optional

from action import Action


class _Node:
    """Represent a node in the MCTS tree."""

    # Constant that represents the exploration/exploitation trade-off
    C: float = sqrt(2)

    @classmethod
    def set_trade_off_constant(cls: type[Self], value: float) -> None:
        """Update the value of the exploration/exploitation trade-off constant."""
        cls.C = value

    def __init__(self: Self, parent: Optional[Self]) -> None:
        """Initialize the node.

        :param parent: Parent node in the MCTS tree.
        """
        self.parent: Optional[Self] = parent

        self.children: list[_Node] = []
        self.actions: list[Action] = []
        self.visits: int = 0
        self.n_wins: int = 0
        self.n_defeats: int = 0

You may have noted that the nodes don’t know which player they are associated with. Indeed, storing this value for each node would be quite wasteful in terms of space. The trick here is that we assumed we can retrieve the player who is to play the position from the game state itself. Stated differently, a game state always represents the position as seen by the player whose turn it is.

In fact, in this implementation, we’ll even assume that players take turns to play, that is, no player can play twice in a row. If you want to get rid of this assumption, you might need to store which player is associated with each node, and to modify the functions related to the rollout and backpropagation phases.

Similarly, the nodes themselves are oblivious to the state they represent: the state is instead computed on-the-fly each time. There are two reasons for that:

  • it avoids storing the game states, which can get pretty memory-heavy;
  • it allows us to include random events, as mentioned in the first part of this series. Note that for this to work, the random events should take place within the transition method of a GameState.

Of course, this comes at a computational cost, and if for a given application the game states can be stored efficiently (that is, if copying such a state is cheap), then it might be worth a shot to store them directly in the nodes, as in the sketch below.
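
For reference, a rough, hypothetical sketch of that alternative could look like this; nothing in the rest of this series relies on it:

from copy import deepcopy
from typing import Optional, Self

from game_state import GameState


class _NodeWithState(_Node):
    """Hypothetical variant of _Node that caches the state it represents."""

    def __init__(self: Self, parent: Optional[_Node], state: GameState) -> None:
        super().__init__(parent)
        # Copied once at creation, so that later in-place transitions can't corrupt it
        self.state = deepcopy(state)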

Computing a node’s value

Let us start with the computation of a Node’s value. This is a simple one: we just apply the formula mentioned earlier, taking care to check whether this node has ever been visited in order to avoid a division by zero.

    @property
    def value(self: Self) -> float:
        """Return the value of the node."""
        if self.visits == 0:
            return 0

        return (self.n_wins - self.n_defeats) / self.visits

Computing a node’s UCB
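
As a reminder, the score we implement here is the classic UCT formula, writing $\bar{v}_i$ for the node’s value as computed above, $n_i$ for its visit count and $N$ for its parent’s visit count:

$$\mathrm{UCB}_i = \bar{v}_i + C\sqrt{\frac{\ln N}{n_i}}$$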

Similarly, we can implement the ucb property. Since we’ll use a logarithm, we first import it from the math module:

from math import sqrt, log

We can now implement the ucb property:

    @property
    def ucb(self: Self) -> float:
        """Return the USB score of this Node."""
        if self.visits == 0:
            return float("inf")

        # No need to check for the parent's visit count to be positive, since it is necessarily
        # larger than that of its child
        return self.value + self.C * sqrt(log(self.parent.visits) / self.visits)

Determining whether a node is a leaf

The is_leaf property doesn’t present much difficulty either: by definition, a node is a leaf if it doesn’t have any children.

    @property
    def is_leaf(self: Self) -> bool:
        """Return True if this Node is a leaf node."""
        return len(self.children) == 0

Selecting a child using the UCB formula

We now move to the select_child method. It seems natural to code this function recursively, since recursion lends itself well to trees. The first thing we have to do is import the GameState we’ve previously created, for typing purposes:

from game_state import GameState

Now, for the recursion, the base case is simple: we stop as soon as we hit a leaf:

    def select_child(self: Self, state: GameState) -> tuple[Self, GameState]:
        """Recursively select a leaf using the exploration/exploitation trade-off formula."""
        if self.is_leaf:
            return self, state

If we’re not at a leaf, we want to find the child with the maximal UCB score, along with the corresponding action. This can simply be done using the key argument of the built-in max function:

        best_child, best_action = max(
            zip(self.children, self.actions),
            key=lambda child_action: child_action[0].ucb,
        )

Now that we have our best child, we just have to recursively call its own select_child method to go one step further. Of course, we mustn’t forget to compute the next state. Putting everything together, we arrive at the following function:

    def select_child(self: Self, state: GameState) -> tuple[Self, GameState]:
        """Recursively select a leaf using the exploration/exploitation trade-off formula."""
        if self.is_leaf:
            return self, state

        best_child, best_action = max(
            zip(self.children, self.actions),
            key=lambda child_action: child_action[0].ucb,
        )
        state.transition(best_action)

        return best_child.select_child(state)

The expanding step

Let us move on to the expand function. We know that this function is only called on leaves, and its goal is to create their children. However, we must first ensure that the node we’re at isn’t terminal, in which case we want it to stay a leaf.

In order to create its children, we simply have to list all the possible actions from this game state and create the associated nodes.

The only thing we need to take care of is that if the node is terminal, then state doesn’t change, while if we create a child and select it, it does. Since we’re going to select the new child at random, we’ll need the choice function from the random module:

from random import choice

All in all, this leads to the following function:

    def expand(self: Self, state: GameState) -> tuple[Self, GameState]:
        """Expand this Node by listing all the possible actions and randomly choose a child."""
        # If the state is terminal, we don't expand the associated node
        if state.get_winner() is not None:
            return self, state

        self.actions = state.get_possible_actions()
        self.children = [_Node(self) for _ in self.actions]
        chosen_child, chosen_action = choice(list(zip(self.children, self.actions)))
        state.transition(chosen_action)

        return chosen_child, state

Implementing the rollout

We now have to implement the rollout method. There’s not really any catch in this one: we simply play random moves until a winner (or a draw) is found. The implementation is then straightforward:

    @staticmethod
    def rollout(state: GameState) -> int:
        """Perform the rollout phase of the MCTS.

        This function randomly selects moves until a terminal game state is reached from this state.
        Once such a state has been reached, this function will return:
        - 1 if it resulted in a win for the player playing this position;
        - -1 if it resulted in a loss for the player playing this position;
        - 0 if it resulted in a draw.
        """
        # Track whose perspective we're in: 1 for the player to move in the
        # initial state, -1 for their opponent
        player = 1

        while (winner := state.get_winner()) is None:
            random_action = choice(state.get_possible_actions())
            state.transition(random_action)
            player *= -1

        # get_winner is relative to the player to move in the terminal state, so
        # multiplying by `player` converts the result back to the initial player's
        # perspective
        return winner * player

Applying the backpropagation

Finally, we only have to implement the backpropagation to finish our implementation of a node. There are three things to do on a node:

  • Increment its visit count;
  • Update its number of wins if the rollout value is -1, or its number of defeats if it is 1, a draw leaving both counters untouched (remember that a node’s statistics are seen from the perspective of the player who moves into it);
  • Recursively backpropagate if we’re not at the root.

All in all, this leads to the following implementation:

    def backpropagate(self: Self, rollout_value: int) -> None:
        """Backpropagate the rollout result through the tree.

        This function performs the backpropagation phase of the MCTS algorithm. It updates the value
        of the nodes according to the player playing in their position and the rollout result.

        :param rollout_value: The value of the rollout result. It is equal to 1 if the player
          to play at this node won the rollout, -1 if the other player won, and 0 in case of a
          draw.
        """
        self.visits += 1

        # If the current player wins the rollout, then this node's value must decrease
        # This is because intuitively, the current player is the adversary of the player that will
        #   look at this node
        if rollout_value == -1:
            self.n_wins += 1
        elif rollout_value:  # that is, rollout_value == 1
            self.n_defeats += 1

        # Unless we're at the root of the tree
        if self.parent is not None:
            self.parent.backpropagate(-rollout_value)

String representation

We’ll also add a __repr__ dunder method for good measure, so that debugging is a little bit easier:

    def __repr__(self: Self) -> str:
        value = self.value
        visits = self.visits

        return f"_Node({value=}, {visits=})"

The MCTS player

Let us now move to the MCTS class. During its initialization, we need to:

  • Set the trade-off constant to a given value;
  • Initialize the root;
  • Set the number of iterations per play.

All in all, this leads to the following implementation:

class MCTS:
    """A class used to perform an MCTS algorithm."""
    def __init__(
        self: Self,
        state: GameState,
        trade_off_constant: float,
        n_simulations: int,
    ) -> None:
        """Initialize the MCTS algorithm.

        This function sets the global parameters used by the MCTS algorithm and initializes it.

        :param state: The state the game starts in.
        :param trade_off_constant: The trade-off constant as used in the UCB formula.
        :param n_simulations: The number of simulations that are performed from the root.
        """
        _Node.set_trade_off_constant(trade_off_constant)
        self.root = _Node(None)
        self.root_state = state
        self.n_simulations = n_simulations

Transitioning from one game state to another

We then want to add a transition function that takes an Action as argument and advances in the tree according to this Action. This is also straightforward: we simply have to fetch the child corresponding to this action and update both the root and the root_state:

    def transition(self: Self, action: Action) -> None:
        """Advance the game with a given action.

        This function allows us to advance in the game tree without any computation. This
        may, for example, prove useful if the adversary is external to this class.
        """
        index = self.root.actions.index(action)

        self.root = self.root.children[index]
        # Remove parent so that the rest of the tree is garbage collected
        self.root.parent = None
        self.root_state.transition(action)

The reason why we’re setting the new Node’s parent to None is so that we don’t keep any reference to what’s above the new root in the tree. This allows Python to garbage collect those nodes, freeing the associated memory. For debugging purposes, you might want to comment out this line so that you can inspect the whole tree.

Note that this is where we need __eq__ to be defined for an Action: the index method compares elements using == in order to find the given Action.
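
As a small illustration (using the ChessAction wrapper defined later in this post):

# list.index relies on == (and thus on __eq__) to locate the matching element
actions = [ChessAction("e2e4"), ChessAction("g1f3")]
assert actions.index(ChessAction("g1f3")) == 1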

Finding the best action

Last but not least, the decide method will glue together everything we’ve done up to now to actually make a decision. First of all, we want to perform our simulations:

        for _ in range(self.n_simulations):
            node = self.root
            node, state = node.select_child(deepcopy(self.root_state))
            node, state = node.expand(state)
            rollout_value = node.rollout(state)
            node.backpropagate(rollout_value)

where we used the deepcopy function from Python’s built-in copy module:

from copy import deepcopy

The reason why we need this function is that we assume that the transition function of a GameState modifies it in place. Since the computational cost comes mostly from the rollout phase, we want to avoid creating copies there, which is why I settled for an in-place transition function, creating copies only at the start of each simulation.

We now want to fetch the index of the child with the largest visit count. Once again, it’s possible to use the built-in max function along with the enumerate generator for that:

        chosen_action = self.root.actions[
            max(enumerate(self.root.children), key=lambda x: x[1].visits)[0]
        ]

Finally, if advance is set to True, we want to change the root along with the state associated with it. In all cases, we return the action to take. All in all, this is the final implementation of the decide method:

    def decide(self: Self, advance: bool = True) -> Optional[Action]:
        """Decide the next move to be played.

        This function performs a number of simulations as specified at the algorithm's
        initialization, updating the nodes' scores along the way. Once these simulations
        have been performed, it selects the action leading to the most visited child.

        :param advance: If set to True, in addition to returning the best action, the root will be
          set to the child corresponding to this action.
        """
        for _ in range(self.n_simulations):
            node = self.root
            node, state = node.select_child(deepcopy(self.root_state))
            node, state = node.expand(state)
            rollout_value = node.rollout(state)
            node.backpropagate(rollout_value)

        chosen_action = self.root.actions[
            max(enumerate(self.root.children), key=lambda x: x[1].visits)[0]
        ]

        if advance:
            self.transition(chosen_action)

        return chosen_action
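
Before moving on to chess, here is a minimal, hypothetical sketch of how all these pieces fit together in a self-play loop, where a single tree decides for both players in turn:

from math import sqrt

from game_state import GameState
from mcts import MCTS


def self_play(initial_state: GameState, n_simulations: int = 100) -> int:
    """Play a full game where the same MCTS instance decides for both sides."""
    mcts = MCTS(initial_state, sqrt(2), n_simulations)

    # decide() both picks an action and advances the root, since advance
    # defaults to True
    while mcts.root_state.get_winner() is None:
        mcts.decide()

    # The result is relative to the player to move in the final position
    return mcts.root_state.get_winner()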

Evaluating the MCTS algorithm on chess

Let us now give a quick example of how to use this code to create an AI for chess. First, we have to represent what an action and a game state are in this case. We won’t reinvent the wheel here: we’ll simply create wrappers around the related objects defined in the python-chess library:

from typing import Self, Optional

from chess import Board, Move

from action import Action
from game_state import GameState


class ChessAction(Action):
    def __init__(self, uci: str) -> None:
        self.move = Move.from_uci(uci)

    def __hash__(self) -> int:
        return hash(self.move)

    def __eq__(self, other: Self) -> bool:
        return self.move == other.move

    def __repr__(self) -> str:
        return self.move.__repr__()


class GameStateChess(GameState):
    def __init__(self, board: Board) -> None:
        self.board = board

    def get_winner(self: Self) -> Optional[int]:
        outcome = self.board.outcome(claim_draw=True)

        if outcome is None:
            return None

        if outcome.winner is None:
            return 0

        # If there is a winner, it is necessarily the adversary of the player
        # to move (see the remark below)
        return -1

    def get_possible_actions(self: Self) -> list[ChessAction]:
        return [ChessAction(move.uci()) for move in self.board.legal_moves]

    def transition(self: Self, action: ChessAction) -> None:
        self.board.push(action.move)

    def __repr__(self):
        return self.board.__repr__()

I’ve also added a __repr__ method to better visualize the state computed by the algorithm.

A small remark on the get_winner function: you may wonder why we never check whether the current player won. The reason is that chess, along with many board games, has the nice property that a player cannot win as a result of their adversary’s move. Thus, since a game state always represents the position as seen by the player who is to play it, it isn’t possible to reach a position where the last move made us the winner! This is clearly a negligible optimization here, but on more complex games with this property, it could prove useful.

And we’re done! For instance, let us consider the following position, which is checkmate in 2 for white:

Checkmate in 2 for white. FEN: r2qkb1r/pp2nppp/3p4/2pNN1B1/2BnP3/3P4/PPP2PPP/R2bK2R w KQkq - 1 0

In order to get what our AI predicts to be the best move, here’s the code we would write:

from math import sqrt

from chess import Board

from game_state_chess import GameStateChess
from mcts import MCTS

initial_board = Board(fen="r2qkb1r/pp2nppp/3p4/2pNN1B1/2BnP3/3P4/PPP2PPP/R2bK2R w KQkq - 1 0")
mcts_player = MCTS(GameStateChess(initial_board), sqrt(2), 800)
action = mcts_player.decide(advance=False)
print(action)

By running this, we can see that our AI… sucks. It can’t even find a mate in two. In the next article, we are going to investigate why this is the case, and how to remedy this.

This post is licensed under CC BY 4.0 by the author.