文章目录
- 前言
- 一、`if self.time_step<4:`
- 控制时间步的递增
- 判断是否在配送中心
- 特定时间步的操作
- 更新
- 更新当前节点和已选择节点列表
- 更新需求和负载
- 更新访问标记
- 更新负无穷掩码
- 更新步骤状态,将更新后的状态同步到 self.step_state
- 二、使用步骤
- 总结
前言
class CVRPEnv:step(self, selected)
函数是强化学习代码实现中的核心。
精读该代码的目标:
- 熟悉每一个参数的shape。
- 熟悉每个参数之间的关系(剪切,扩展,等)。
一、if self.time_step<4:
控制时间步的递增
# 控制时间步的递增self.time_step=self.time_step+1self.selectex_count = self.selected_count+1
判断是否在配送中心
#判断是否在配送中心self.at_the_depot = (selected == 0)
特定时间步的操作
if self.time_step==3:self.last_current_node = self.current_node.clone()self.last_load = self.load.clone()if self.time_step == 4:self.last_current_node = self.current_node.clone()self.last_load = self.load.clone()self.visited_ninf_flag[:, :, self.problem_size+1][(~self.at_the_depot)&(self.last_current_node!=0)] = 0
更新
更新当前节点和已选择节点列表
#更新当前节点和已选择节点列表self.current_node = selectedself.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)
参数 | Shape |
---|---|
self.current_node | (batch, pomo) |
self.selected_node_list | (batch, pomo,0~) |
注:0~
表示第三维度逐渐增加
self.selected_node_list
的shape:
self.current_node
的shape:
self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)
,表示先将self.current_node
扩展为三维数据,再将self.current_node
沿着self.selected_node_list
的第三维度(dim=2
)进行依次剪切进去。
更新需求和负载
#更新需求和负载demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)gathering_index = selected[:, :, None]selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)self.load -= selected_demandself.load[self.at_the_depot] = 1 # refill loaded at the depot
参数 | Shape | 含义g |
---|---|---|
self.depot_node_demand | (batch, problem + 1) | 表示每个批次中,每个问题(包括配送中心)对应的节点需求 |
demand_list | (batch, pomo, problem + 1) | 包含每个节点需求的张量 |
selected | (batch, pomo) | 表示每个批次中的每个智能体所选择的节点编号(这些节点是从节点集合中选择的) |
selected_demand | (batch, pomo) | 示每个智能体所选择节点的需求。 |
demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
[:, None, :]
:先在self.depot_node_demand
的第二维(即问题维度)上增加一个新的维度,使其变为(batch_size, 1, problem_size + 1)
。.expand(self.batch_size, self.pomo_size, -1)
:将数据self.depot_node_demand扩展为(batch_size, pomo_size, problem_size + 1)
,表示每个批次中的每个 POMO 智能体都有一份相同的需求数据。
gathering_index = selected[:, :, None]
- 将
selected
进行维度扩展
selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
demand_list
的 shape 是(batch_size, pomo_size, problem_size + 1)
,包含了所有节点的需求数据。gather(dim=2, index=gathering_index)
会按照gathering_index
(即selected
中存储的节点编号)从demand_list
中选择出对应的节点需求。dim=2
表示沿着第三维(即问题维度)进行选择。gather
的结果是一个 shape 为(batch_size, pomo_size, 1)
的张量。.squeeze(dim=2)
去掉了多余的第三维,最终得到selected_demand
,其 shape 是(batch_size, pomo_size)
,表示每个智能体所选择节点的需求。
更新访问标记
#更新访问标记(防止重复选择已访问的节点)self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0 # depot is considered unvisited, unless you are AT the depot
参数 | Shape | 含义 |
---|---|---|
self.visited_ninf_flag | (batch, pomo, problem+ 1) | 记录了每 个智能体(POMO)在每个批次中已访问的节点的信息,标记某些节点是否已经被访问(用负无穷表示)。 |
self.BATCH_IDX | (batch, pomo) | 批次索引的张量 |
self.POMO_IDX | (batch, pomo) | 智能体(POMO)索引的张量 |
selected | (batch, pomo) | 表示每个批次中的每个智能体所选择的节点编号(这些节点是从节点集合中选择的) |
self.at_the_depot | (batch, pomo) | 一个布尔型张量,表示每个智能体是否处于配送中心(即该智能体是否在节点 0,通常是配送中心)。 |
self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected]
表示从 visited_ninf_flag
张量中选择出对应批次和智能体的对应位置,并设置为 float('-inf')
,表示这些节点已经被访问过。
举例:
假设我们有以下参数:
- batch_size = 2,即有 2 个批次。
- pomo_size = 3,即每个批次有 3 个智能体(POMO)。
- problem_size = 4,即有 4 个节点(包含配送中心)。
self.visited_ninf_flag = [[[ 0., 0., 0., 0., 0.], # 第一个批次(batch 0)[ 0., 0., 0., 0., 0.], # POMO 0, POMO 1, POMO 2 各自对节点的访问标志[ 0., 0., 0., 0., 0.]],[[ 0., 0., 0., 0., 0.], # 第二个批次(batch 1)[ 0., 0., 0., 0., 0.],[ 0., 0., 0., 0., 0.]]
]
self.BATCH_IDX(批次索引):
self.BATCH_IDX = [[0, 0, 0], # 第一个批次[1, 1, 1] # 第二个批次
]
self.POMO_IDX(POMO 索引):
self.POMO_IDX = [[0, 1, 2], # 每个批次中三个智能体的索引[0, 1, 2]
]
selected(每个智能体选择的节点):
selected = [[1, 2, 0], # 第一个批次中,智能体选择的节点:POMO 0 选择节点 1,POMO 1 选择节点 2,POMO 2 选择节点 0[3, 1, 2] # 第二个批次中,智能体选择的节点:POMO 0 选择节点 3,POMO 1 选择节点 1,POMO 2 选择节点 2
]
执行这一行代码 self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
。
对于第一个批次(BATCH_IDX[0]),我们有三个智能体(POMO_IDX[0]),选择了节点 [1, 2, 0],分别是:
- selected[0][0] = 1 表示 POMO 0 选择了节点 1。
- selected[0][1] = 2 表示 POMO 1 选择了节点 2。
- selected[0][2] = 0 表示 POMO 2 选择了节点 0。
对于第二个批次(BATCH_IDX[1]),我们同样有三个智能体(POMO_IDX[1]),选择了节点 [3, 1, 2],分别是: - selected[1][0] = 3 表示 POMO 0 选择了节点 3。
- selected[1][1] = 1 表示 POMO 1 选择了节点 1。
- selected[1][2] = 2 表示 POMO 2 选择了节点 2。
更新 visited_ninf_flag: 根据批次索引和 POMO 索引,我们更新了对应位置的值为负无穷 -inf:
- 对于 BATCH_IDX[0] 和 POMO_IDX[0, 1, 2],我们将 selected[0][0] = 1,selected[0][1] = 2,selected[0][2] = 0 位置标记为 -inf。
- 对于 BATCH_IDX[1] 和 POMO_IDX[0, 1, 2],我们将 selected[1][0] = 3,selected[1][1] = 1,selected[1][2] = 2 位置标记为 -inf。
self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0
,我们将所有不在配送中心的智能体的配送中心访问标志设置为 0。
-[:, :, 0]
是一个切片操作,表示我们提取张量中的第一个节点(通常是配送中心节点)。
~self.at_the_depot
是对 self.at_the_depot 张量的布尔取反操作,将 True 变为 False,将 False 变为 True。
更新负无穷掩码
#更新负无穷掩码(屏蔽需求量超过当前负载的节点)self.ninf_mask = self.visited_ninf_flag.clone()round_error_epsilon = 0.00001demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list_2=torch.full((demand_too_large.shape[0],demand_too_large.shape[1],1),False)demand_too_large = torch.cat((demand_too_large, _2), dim=2)self.ninf_mask[demand_too_large] = float('-inf')
参数 | Shape | 含义 |
---|---|---|
self.visited_ninf_flag | (batch, pomo, problem+ 1) | 记录了每 个智能体(POMO)在每个批次中已访问的节点的信息,标记某些节点是否已经被访问(用负无穷表示)。 |
self.ninf_mask | (batch, pomo, problem+ 1) | self.visited_ninf_flag.clone() |
更新步骤状态,将更新后的状态同步到 self.step_state
#更新步骤状态,将更新后的状态同步到 self.step_stateself.step_state.selected_count = self.time_stepself.step_state.load = self.loadself.step_state.current_node = self.current_nodeself.step_state.ninf_mask = self.ninf_mask
参数 | Shape |
---|