Introduction to Reinforcement Learning



#reinforcement learning  #value iteration  #policy iteration 

Core Concepts

Agent

The agent is the decision maker. It makes decisions based on its current state and the information provided by the environment, such as rewards.

In reinforcement learning, the agent is also the solution to the problem.

Environment

The environment is the agent's surroundings. It interacts with the agent by providing information and rewards. An environment can change on its own, or the change can be caused by an action from the agent.

An environment has its own state, and changes in the environment can be viewed as state transitions. However, this state may not be fully accessible to the agent. For example, suppose we put a video camera in a meeting room; in this case the environment is the meeting room. If we have an agent inside the camera that tracks the movement of people in the room, the agent can only access visual information, while the audio information in the environment is not accessible to the camera.

The problem that a reinforcement learning algorithm aims to solve can be formulated in terms of the environment. Conceptually, an environment provides two services: it carries out the agent's actions, transitioning to a new state and emitting a reward, and it describes its own dynamics.

The second service is in the form of \(P(s', r|s, a)\), the probability of moving to state \(s'\) and receiving reward \(r\) when action \(a\) is taken in state \(s\). This probability is not always available to the agent, though.
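
The value-iteration code later in this post consumes this second service through an EnvironmentProxy that enumerates possible transitions. That interface is not shown there, so the following is only a minimal sketch of what it could look like; the method and accessor names mirror the ones used by the code below, but their exact shapes here are assumptions.

import java.util.List;

/**
 * Sketch of the environment's second service: exposing the dynamics
 * P(s', r | s, a) as an enumeration of possible transitions.
 * The shape of this interface is an assumption for illustration.
 */
public interface EnvironmentProxy<STATE, ACTION> {
    /**
     * Lists every possible outcome (s', r) of taking action a in state s,
     * together with its probability.
     */
    List<StateTransitionRecord<STATE>> generateScenarios(STATE s, ACTION a);
}

/** One possible outcome (s', r) with its probability P(s', r | s, a). */
interface StateTransitionRecord<STATE> {
    STATE getToState();
    double getReward();
    double getProbability();
}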

Interaction between Environment and Agent

[Figure: interaction between the environment and the agent (environment_and_agent.jpg)]

Reward

Rewards are feedback provided by the environment to the agent. Overall, a reward is a hint of how good the action taken by the agent was.

Discount Rate

Future rewards are less attractive than near-term rewards, so we apply a discount rate \(\gamma\) to calculate their present value. This is similar to the discount rate used in cash flow calculations in finance.
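
For example, if the agent receives the reward sequence \(r_{t+1}, r_{t+2}, r_{t+3}, \dots\) after time step \(t\), its present value (the discounted return) under a discount rate \(\gamma \in [0, 1]\) is

$$ G_t = r_{t+1} + \gamma r_{t+2} + \gamma^2 r_{t+3} + \dots = \sum_{k=0}^{\infty} \gamma^k r_{t+k+1} $$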

State Value Function

The state value function is also called the V-function. It gives the value of a given agent state when the agent follows a policy \(\pi\).

/**
 * Interface of the state-value function (V-function).
 *
 * @param <STATE> The state of the agent.
 */
public interface StateValueFunction<STATE> {
    /**
     * Gets the value of the state.
     *
     * @param s The agent state.
     * @return The value of the state.
     */
    double getValue(STATE s);
}

Mathematically,

$$ V_{\pi}(s) = \sum_{a} \Pi(a|s) \sum_{s', r} P(s', r|s, a) [r + \gamma V_{\pi}(s')] $$

Action Value Function

The action value function is also called the Q-function. It gives the value of taking action \(a\) given the current state \(s\).

/**
 * Interface for the action value function (Q-function).
 *
 * @param <STATE> The state of the agent.
 * @param <ACTION> The action of the agent.
 */
public interface ActionValueFunction<STATE, ACTION> {
    /**
     * Gets the value of taking the specific action given the state.
     *
     * @param s The state of the agent.
     * @param a The action to take.
     * @return The value of the action given the state.
     */
    double getValue(STATE s, ACTION a);
}

Mathematically,

$$ Q_{\pi} (s, a) = \sum_{s',r} P(s', r|s, a) [r + \gamma V_{\pi}(s')] $$

The state value function and the action value function are equivalent: each can be derived from the other.

If we have the state value function \(V\), we can calculate the action value function \(Q\) directly from its definition above. Conversely, if we have the action value function, we can rewrite the state value function in terms of \(Q\):

$$ \begin{eqnarray} V_{\pi}(s) & = & \sum_{a} \Pi(a|s) \sum_{s', r} P(s', r|s, a) [r + \gamma V_{\pi}(s')] \\ & = & \sum_{a} \Pi(a|s) Q_{\pi}(s,a) \end{eqnarray} $$
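
For instance, given the state value function and the environment dynamics, the action value function can be computed directly from its definition. Below is a minimal sketch that reuses the StateValueFunction interface above and the hypothetical EnvironmentProxy shape sketched earlier; the class and method names are illustrative only.

public final class ValueFunctionConversion {

    /**
     * Computes Q(s, a) = sum over (s', r) of P(s', r | s, a) * (r + gamma * V(s')).
     */
    public static <STATE, ACTION> double actionValueFromStateValue(
            final EnvironmentProxy<STATE, ACTION> envProxy,
            final StateValueFunction<STATE> v,
            final double discountRate,
            final STATE s,
            final ACTION a) {

        double q = 0.0;
        for (final StateTransitionRecord<STATE> record : envProxy.generateScenarios(s, a)) {
            q += record.getProbability()
                    * (record.getReward() + discountRate * v.getValue(record.getToState()));
        }
        return q;
    }
}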

Policy

A policy prescribes the action to take for a given non-terminal state.

/**
 * Interface for policies.
 *
 * @param <STATE> The state of the agent.
 * @param <ACTION> The action of the agent.
 */
public interface Policy<STATE, ACTION> {

    /**
     * Selects an action to take.
     *
     * @param s The state of the agent.
     * @return The action to take.
     */
    ACTION selectAction(STATE s);

}

Algorithms

In this section, we present two algorithms for computing the state value function: policy iteration and value iteration.

Throughout this section, we assume that the probability \(P(s', r|s, a)\), which describes the behavior of the environment, is accessible to the agent.

Policy Iteration

There are two steps in policy iteration: policy evaluation and policy improvement.

Policy evaluation computes the state value function of the current policy. As we mentioned earlier, the state value function and the action value function are equivalent, so after the policy evaluation step we also have the action value function. Policy improvement then builds a new greedy policy from the action value function: for each state, we select the action with the highest value. The two steps are repeated until the policy stops changing.

The math involved in policy evaluation is given as follows:

$$ V_{k+1}(s) = \sum_{a} \Pi(a|s) \sum_{s', r} P(s', r|s, a) [r + \gamma V_{k}(s')] $$
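
As a concrete illustration, here is a minimal policy-evaluation sketch. It assumes a deterministic policy (as modeled by the Policy interface above, so \(\Pi(a|s)\) puts all probability on the single selected action) and reuses the hypothetical EnvironmentProxy shape sketched earlier; the value table is a plain Map rather than the StateValueFunction interface so that the example stays self-contained.

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public final class PolicyEvaluationSketch {

    /**
     * Iterative policy evaluation: repeatedly applies
     * V_{k+1}(s) = sum_{s', r} P(s', r | s, a) [r + gamma * V_k(s')]
     * with a = policy.selectAction(s), until the largest change across
     * states drops below the convergence threshold.
     */
    public static <STATE, ACTION> Map<STATE, Double> evaluate(
            final List<STATE> states,
            final Policy<STATE, ACTION> policy,
            final EnvironmentProxy<STATE, ACTION> envProxy,
            final double discountRate,
            final double convergenceThreshold) {

        final Map<STATE, Double> values = new HashMap<>();
        states.forEach(s -> values.put(s, 0.0));

        double maxChange;
        do {
            maxChange = 0.0;
            final Map<STATE, Double> nextValues = new HashMap<>();

            for (final STATE state : states) {
                final ACTION action = policy.selectAction(state);
                double newValue = 0.0;

                // Expected one-step return: sum of P(s', r | s, a) * (r + gamma * V_k(s')).
                for (final StateTransitionRecord<STATE> record : envProxy.generateScenarios(state, action)) {
                    newValue += record.getProbability()
                            * (record.getReward() + discountRate * values.getOrDefault(record.getToState(), 0.0));
                }

                nextValues.put(state, newValue);
                maxChange = Math.max(maxChange, Math.abs(newValue - values.get(state)));
            }

            values.putAll(nextValues);
        } while (maxChange >= convergenceThreshold);

        return values;
    }
}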

Value Iteration

The value iteration method is based on the following formula:

$$ V_{k+1}(s) = \max_{a} \left\{ \sum_{s', r} P(s', r|s, a) [r + \gamma V_{k}(s')] \right\} $$

Here is a simple implementation:

import java.util.List;

public class ValueIteration<STATE, ACTION> {

    private final List<STATE> states;
    private final List<ACTION> actions;

    public ValueIteration(final List<STATE> states, final List<ACTION> actions) {
        this.states = states;
        this.actions = actions;
    }

    public Pair<IStateValueFunction<STATE>, IPolicy<STATE, ACTION>> run(
            final EnvironmentProxy<STATE, ACTION> envProxy,
            double discountRate,
            double convergenceThreshold) {

        // Initialize the state value function to zero for every state.
        final IStateValueFunction<STATE> stateValueFunction = new StateValueFunction<>();
        this.states.forEach(s -> stateValueFunction.setValue(s, 0));

        final IActionValueFunction<STATE, ACTION> actionValueFunction = new ActionValueFunction<>();

        // Repeat until the value function converges.
        for (;;) {
            actionValueFunction.reset();

            for (var state : this.states) {
                for (var action : this.actions) {

                    // Q(s, a) += P(s', r | s, a) * (r + gamma * V_k(s')), summed over all outcomes (s', r).
                    for (final StateTransitionRecord<STATE> record : envProxy.generateScenarios(state, action)) {
                        final double delta = record.getProbability()
                                * (record.getReward() + discountRate * stateValueFunction.getValue(record.getToState()));

                        actionValueFunction.setValue(state, action,
                                actionValueFunction.getValue(state, action) + delta);
                    }
                }
            }

            if (computeDifference(stateValueFunction, actionValueFunction) < convergenceThreshold) {
                break;
            }

            // V_{k+1}(s) = max_a Q(s, a): fold the freshly computed Q back into V.
            stateValueFunction.copyFrom(deriveStateValueFunction(actionValueFunction));
        }


        return Pair.of(stateValueFunction, derivePolicy(actionValueFunction));
    }

    private static <STATE, ACTION> IStateValueFunction<STATE> deriveStateValueFunction(
            final IActionValueFunction<STATE, ACTION> q) {

        final IStateValueFunction<STATE> v = new StateValueFunction<>();

        for (final var state : q.getStates()) {
            double maxValue = Double.NEGATIVE_INFINITY; // Double.MIN_VALUE is the smallest positive double, not the most negative value.
            for (final var action : q.getActions()) {
                maxValue = Math.max(maxValue, q.getValue(state, action));
            }
            v.setValue(state, maxValue);
        }

        return v;
    }

    private static <STATE, ACTION> IPolicy<STATE, ACTION> derivePolicy(
            final IActionValueFunction<STATE, ACTION> q) {
        final Policy<STATE, ACTION> policy = new Policy<>();

        for (final var state : q.getStates()) {
            double maxValue = Double.NEGATIVE_INFINITY;
            ACTION optimalAction = null;

            for (final var action : q.getActions()) {
                final double value = q.getValue(state, action);

                if (value > maxValue) {
                    maxValue = value;
                    optimalAction = action;
                }
            }

            policy.setAction(state, optimalAction);
        }

        return policy;
    }

    private static <STATE, ACTION> double computeDifference(
            final IStateValueFunction<STATE> v,
            final IActionValueFunction<STATE, ACTION> q) {

        final var w = deriveStateValueFunction(q);
        double maxValue = 0; // largest absolute difference seen so far

        for (final var state : v.getStates()) {
            maxValue = Math.max(maxValue, Math.abs(v.getValue(state) - w.getValue(state)));
        }
        return maxValue;
    }
}
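
Assuming concrete StateValueFunction, ActionValueFunction, and Policy implementations plus an EnvironmentProxy for a specific problem, usage could look like the sketch below. GridCell, Move, allStates, allActions, and gridWorldProxy are hypothetical placeholders, and Pair is assumed to be a pair type with getLeft/getRight accessors, consistent with the Pair.of call above.

// Hypothetical state/action types and environment, for illustration only.
final ValueIteration<GridCell, Move> solver = new ValueIteration<>(allStates, allActions);
final Pair<IStateValueFunction<GridCell>, IPolicy<GridCell, Move>> result =
        solver.run(gridWorldProxy, 0.9, 1e-6);

final IStateValueFunction<GridCell> optimalValues = result.getLeft(); // converged V-function
final IPolicy<GridCell, Move> greedyPolicy = result.getRight();       // greedy policy derived from Q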

----- END -----

If you have questions about this post, you can find me on Discord.
