一、概述
本代码实现了在网格世界环境中使用 TD (0)(Temporal Difference (0))算法进行策略评估,并对评估结果进行可视化展示。通过模拟智能体在网格世界中的移动,不断更新状态值函数,最终得到每个状态的价值估计。
二、依赖库
numpy
:用于进行数值计算,如数组操作、随机数生成等。matplotlib.pyplot
:用于绘制图形,将状态值函数的评估结果进行可视化。
三、代码结构与详细说明
1. 导入库
收起
python
import numpy as np
import matplotlib.pyplot as plt
导入 numpy
库并将其别名为 np
,导入 matplotlib.pyplot
库并将其别名为 plt
,以便后续使用。
2. 定义网格世界环境类 GridWorld
收起
python
# 定义网格世界环境
class GridWorld:def __init__(self, size=10):self.size = sizeself.terminal = (3, 3) # 终止状态
- 功能:初始化网格世界环境。
- 参数:
size
:网格世界的大小,默认为 10x10 的网格,即网格的边长为 10。self.terminal
:终止状态的坐标,这里设置为(3, 3)
,当智能体到达该状态时,一个回合结束。
收起
python
def is_terminal(self, state):return state == self.terminal
- 功能:判断给定的状态是否为终止状态。
- 参数:
state
:要判断的状态,以坐标元组(x, y)
的形式表示。
- 返回值:如果
state
等于self.terminal
,返回True
;否则返回False
。
收起
python
def step(self, state, action):x, y = stateif self.is_terminal(state):return state, 0 # 终止状态不再变化
- 功能:模拟智能体在当前状态下采取指定动作后的状态转移和奖励获取。
- 参数:
state
:当前状态,以坐标元组(x, y)
的形式表示。action
:智能体采取的动作,用整数表示,0 表示向上,1 表示向下,2 表示向左,3 表示向右。
- 处理逻辑:如果当前状态是终止状态,则直接返回当前状态和奖励 0,因为终止状态不会再发生变化。
收起
python
# 定义动作:0=上, 1=下, 2=左, 3=右dx, dy = [(-1,0), (1,0), (0,-1), (0,1)][action]new_x = max(0, min(self.size - 1, x + dx))new_y = max(0, min(self.size - 1, y + dy))new_state = (new_x, new_y)
- 处理逻辑:
- 根据动作编号
action
从列表[(-1,0), (1,0), (0,-1), (0,1)]
中选取对应的偏移量(dx, dy)
。 - 计算新的
x
和y
坐标,使用max
和min
函数确保新坐标在网格世界的边界内(范围是 0 到self.size - 1
)。 - 将新的坐标组合成新的状态
new_state
。
- 根据动作编号
收起
python
reward = -1 # 每步固定奖励return new_state, reward
- 处理逻辑:每走一步给予固定奖励 -1,然后返回新的状态和奖励。
3. TD (0) 策略评估函数 td0_policy_evaluation
收起
python
# TD(0) 策略评估
def td0_policy_evaluation(env, episodes=5000, alpha=0.1, gamma=1.0):V = np.zeros((env.size, env.size)) # 初始化状态值函数
- 功能:使用 TD (0) 算法对策略进行评估,得到每个状态的价值估计。
- 参数:
env
:网格世界环境对象,用于获取环境信息和执行状态转移。episodes
:训练的回合数,默认为 5000。alpha
:学习率,控制每次更新状态值函数时的步长,默认为 0.1。gamma
:折扣因子,用于权衡当前奖励和未来奖励的重要性,默认为 1.0。
- 处理逻辑:初始化一个大小为
(env.size, env.size)
的二维数组V
,用于存储每个状态的价值估计,初始值都为 0。
收起
python
for _ in range(episodes):state = (0, 0) # 初始状态while True:if env.is_terminal(state):break
- 处理逻辑:
- 进行
episodes
个回合的训练,每个回合从初始状态(0, 0)
开始。 - 使用
while
循环不断执行动作,直到智能体到达终止状态。
- 进行
收起
python
# 随机策略:均匀选择动作(上下左右各25%概率)action = np.random.randint(0, 4)next_state, reward = env.step(state, action)
- 处理逻辑:
- 使用
np.random.randint(0, 4)
随机选择一个动作,每个动作的选择概率为 25%。 - 调用环境的
step
方法,执行选择的动作,得到下一个状态next_state
和奖励reward
。
- 使用
收起
python
# TD(0) 更新公式td_target = reward + gamma * V[next_state]td_error = td_target - V[state]V[state] += alpha * td_error
- 处理逻辑:
- 根据 TD (0) 算法的更新公式,计算目标值
td_target
,即当前奖励加上折扣后的下一个状态的价值估计。 - 计算 TD 误差
td_error
,即目标值与当前状态的价值估计之差。 - 使用学习率
alpha
乘以 TD 误差,更新当前状态的价值估计。
- 根据 TD (0) 算法的更新公式,计算目标值
收起
python
state = next_state # 转移到下一状态
- 处理逻辑:将当前状态更新为下一个状态,继续下一次循环。
收起
python
return V
- 返回值:返回经过
episodes
个回合训练后得到的状态值函数V
。
4. 运行算法
收起
python
# 运行算法
env = GridWorld()
V = td0_policy_evaluation(env, episodes=1000)
- 处理逻辑:
- 创建一个
GridWorld
类的实例env
,初始化网格世界环境。 - 调用
td0_policy_evaluation
函数,对环境进行 1000 个回合的策略评估,得到状态值函数V
。
- 创建一个
5. 可视化结果函数 plot_value_function
收起
python
# 可视化结果
def plot_value_function(V):fig, ax = plt.subplots()im = ax.imshow(V, cmap='coolwarm')
- 功能:将状态值函数
V
进行可视化展示。 - 处理逻辑:
- 创建一个图形对象
fig
和一个坐标轴对象ax
。 - 使用
ax.imshow
函数将状态值函数V
以图像的形式显示出来,使用coolwarm
颜色映射。
- 创建一个图形对象
收起
python
for i in range(V.shape[0]):for j in range(V.shape[1]):text = ax.text(j, i, f"{V[i, j]:.1f}",ha="center", va="center", color="black")
- 处理逻辑:遍历状态值函数
V
的每个元素,在对应的图像位置上添加文本标签,显示该状态的价值估计,保留一位小数。
收起
python
ax.set_title("TD(0) Estimated State Value Function")plt.axis('off')plt.colorbar(im)plt.show()
- 处理逻辑:
- 设置图形的标题为 “TD (0) Estimated State Value Function”。
- 关闭坐标轴显示。
- 添加颜色条,用于显示颜色与数值的对应关系。
- 显示绘制好的图形。
6. 调用可视化函数
收起
python
plot_value_function(V)
调用 plot_value_function
函数,将经过 TD (0) 算法评估得到的状态值函数 V
进行可视化展示。
四、注意事项
- 可以根据需要调整
GridWorld
类的size
参数和terminal
属性,改变网格世界的大小和终止状态的位置。 - 可以调整
td0_policy_evaluation
函数的episodes
、alpha
和gamma
参数,优化策略评估的效果。 - 代码中的随机策略是简单的均匀随机选择动作,可根据实际需求修改为更复杂的策略。
完整代码
import numpy as np
import matplotlib.pyplot as plt# 定义网格世界环境
class GridWorld:def __init__(self, size=10):self.size = sizeself.terminal = (3, 3) # 终止状态def is_terminal(self, state):return state == self.terminaldef step(self, state, action):x, y = stateif self.is_terminal(state):return state, 0 # 终止状态不再变化# 定义动作:0=上, 1=下, 2=左, 3=右dx, dy = [(-1,0), (1,0), (0,-1), (0,1)][action]new_x = max(0, min(self.size - 1, x + dx))new_y = max(0, min(self.size - 1, y + dy))new_state = (new_x, new_y)reward = -1 # 每步固定奖励return new_state, reward# TD(0) 策略评估
def td0_policy_evaluation(env, episodes=5000, alpha=0.1, gamma=1.0):V = np.zeros((env.size, env.size)) # 初始化状态值函数for _ in range(episodes):state = (0, 0) # 初始状态while True:if env.is_terminal(state):break# 随机策略:均匀选择动作(上下左右各25%概率)action = np.random.randint(0, 4)next_state, reward = env.step(state, action)# TD(0) 更新公式td_target = reward + gamma * V[next_state]td_error = td_target - V[state]V[state] += alpha * td_errorstate = next_state # 转移到下一状态return V# 运行算法
env = GridWorld()
V = td0_policy_evaluation(env, episodes=1000)# 可视化结果
def plot_value_function(V):fig, ax = plt.subplots()im = ax.imshow(V, cmap='coolwarm')for i in range(V.shape[0]):for j in range(V.shape[1]):text = ax.text(j, i, f"{V[i, j]:.1f}",ha="center", va="center", color="black")ax.set_title("TD(0) Estimated State Value Function")plt.axis('off')plt.colorbar(im)plt.show()plot_value_function(V)