人工智能A*算法与CNN结合- CNN 增加卷积层的数量,并对卷积核大小进行调整

以下是一个增强版的将 A* 算法与卷积神经网络(CNN)结合的代码实现,其中 CNN 增加了卷积层的数量,并对卷积核大小进行了调整。整体思路依然是先利用 A* 算法生成训练数据,再用这些数据训练 CNN 模型,最后使用训练好的模型进行路径规划。

import numpy as np
import heapq
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader# A* 算法实现
class Node:def __init__(self, x, y, g=float('inf'), h=float('inf'), parent=None):self.x = xself.y = yself.g = gself.h = hself.f = g + hself.parent = parentdef __lt__(self, other):return self.f < other.fdef heuristic(current, goal):return abs(current[0] - goal[0]) + abs(current[1] - goal[1])def astar(grid, start, goal):rows, cols = grid.shapeopen_list = []closed_set = set()start_node = Node(start[0], start[1], g=0, h=heuristic(start, goal))heapq.heappush(open_list, start_node)while open_list:current_node = heapq.heappop(open_list)if (current_node.x, current_node.y) == goal:path = []while current_node:path.append((current_node.x, current_node.y))current_node = current_node.parentreturn path[::-1]closed_set.add((current_node.x, current_node.y))neighbors = [(0, 1), (0, -1), (1, 0), (-1, 0)]for dx, dy in neighbors:new_x, new_y = current_node.x + dx, current_node.y + dyif 0 <= new_x < rows and 0 <= new_y < cols and grid[new_x][new_y] == 0 and (new_x, new_y) not in closed_set:new_g = current_node.g + 1new_h = heuristic((new_x, new_y), goal)new_node = Node(new_x, new_y, g=new_g, h=new_h, parent=current_node)found = Falsefor i, node in enumerate(open_list):if node.x == new_x and node.y == new_y:if new_g < node.g:open_list[i] = new_nodeheapq.heapify(open_list)found = Truebreakif not found:heapq.heappush(open_list, new_node)return None# 生成训练数据
def generate_training_data(grid, num_samples):rows, cols = grid.shapeinputs = []outputs = []for _ in range(num_samples):start = (np.random.randint(0, rows), np.random.randint(0, cols))goal = (np.random.randint(0, rows), np.random.randint(0, cols))path = astar(grid, start, goal)if path:input_data = np.zeros((1, rows, cols))input_data[0, start[0], start[1]] = 1input_data[0, goal[0], goal[1]] = 2output_data = np.zeros((1, rows, cols))for point in path:output_data[0, point[0], point[1]] = 1inputs.append(input_data)outputs.append(output_data)return np.array(inputs), np.array(outputs)# 自定义数据集类
class PathDataset(Dataset):def __init__(self, inputs, outputs):self.inputs = torch.tensor(inputs, dtype=torch.float32)self.outputs = torch.tensor(outputs, dtype=torch.float32)def __len__(self):return len(self.inputs)def __getitem__(self, idx):return self.inputs[idx], self.outputs[idx]# 定义增强版卷积神经网络模型
class EnhancedPathCNN(nn.Module):def __init__(self):super(EnhancedPathCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)self.relu1 = nn.ReLU()self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)self.relu2 = nn.ReLU()self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)self.relu3 = nn.ReLU()self.conv4 = nn.Conv2d(128, 64, kernel_size=3, padding=1)self.relu4 = nn.ReLU()self.conv5 = nn.Conv2d(64, 1, kernel_size=3, padding=1)def forward(self, x):x = self.relu1(self.conv1(x))x = self.relu2(self.conv2(x))x = self.relu3(self.conv3(x))x = self.relu4(self.conv4(x))x = self.conv5(x)return x# 训练神经网络
def train_model(model, dataloader, criterion, optimizer, epochs):for epoch in range(epochs):running_loss = 0.0for inputs, labels in dataloader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}')# 主程序
if __name__ == "__main__":# 示例地图map_grid = np.array([[0, 0, 0, 0, 0],[0, 1, 1, 0, 0],[0, 0, 0, 0, 0],[0, 0, 1, 1, 0],[0, 0, 0, 0, 0]])# 生成训练数据num_samples = 1000inputs, outputs = generate_training_data(map_grid, num_samples)# 创建数据集和数据加载器dataset = PathDataset(inputs, outputs)dataloader = DataLoader(dataset, batch_size=32, shuffle=True)# 初始化增强版 CNN 模型model = EnhancedPathCNN()# 定义损失函数和优化器criterion = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型epochs = 10train_model(model, dataloader, criterion, optimizer, epochs)# 使用训练好的模型进行路径规划start = (0, 0)goal = (4, 4)input_data = np.zeros((1, 1, map_grid.shape[0], map_grid.shape[1]))input_data[0, 0, start[0], start[1]] = 1input_data[0, 0, goal[0], goal[1]] = 2input_tensor = torch.tensor(input_data, dtype=torch.float32)output = model(input_tensor)output_path = output.detach().numpy()[0, 0]path_points = np.argwhere(output_path > 0.5)print("增强版 CNN 预测的路径点:", path_points)

代码解释

1. A* 算法部分

这部分和之前的实现基本一致,Node 类用于表示地图中的节点,heuristic 函数计算启发式代价,astar 函数通过维护开放列表和关闭列表来搜索从起点到终点的最短路径。

2. 数据生成部分

generate_training_data 函数随机生成起点和终点,调用 A* 算法计算路径,将起点、终点信息作为输入,路径信息作为输出,生成训练数据。

3. 数据集和数据加载器部分

PathDataset 类继承自 torch.utils.data.Dataset,用于封装训练数据,DataLoader 用于批量加载数据并支持随机打乱。

4. 增强版卷积神经网络部分
  • EnhancedPathCNN 类定义了一个更复杂的 CNN 模型,增加了卷积层的数量,同时调整了卷积核的大小。具体来说,前两层使用 5x5 的卷积核,后三层使用 3x3 的卷积核,每层卷积后都接一个 ReLU 激活函数,最后一层卷积输出单通道的预测路径。
  • train_model 函数使用均方误差损失函数(nn.MSELoss)和 Adam 优化器对模型进行训练,在多个 epoch 中不断更新模型参数。
5. 主程序部分

定义示例地图,生成训练数据,创建数据集和数据加载器,初始化增强版 CNN 模型,定义损失函数和优化器,训练模型,最后使用训练好的模型进行路径规划并输出预测的路径点。

注意事项

  • 增加卷积层和调整卷积核大小可能会增加模型的复杂度和训练时间,需要确保有足够的计算资源和训练数据。
  • 可以根据实际情况进一步调整模型结构、卷积核大小、学习率等超参数,以获得更好的性能。
  • CNN 模型预测的路径不一定是最优路径,A* 算法能保证最优路径,但 CNN 可以在效率上有一定提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/14765.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于Java的在线购物系统的设计与实现

引言 课题背景 随着Internet国际互联网的发展&#xff0c;越来越多的企业开始建造自己的网站。基于Internet的信息服务&#xff0c;商务服务已经成为现代企业一项不可缺少的内容。很多企业都已不满足于建立一个简单的仅仅能够发布信息的静态网站。现代企业需要的是一个功能强…

cefsharp131升级132测试(WinForms.NETCore)

一、升级&#xff08;Nuget&#xff09; 版本说明&#xff08;readme&#xff09;:最低.NET Core3.1 (NET5.0) Visual C 2019 Redist 二、试运行、兼容性测试 三、后记说明 支持H264版本推荐版本63,79,84,88,100,111,125&#xff08;支持h264和pdf预览&#xff09; 其他H264版…

C#中深度解析BinaryFormatter序列化生成的二进制文件

C#中深度解析BinaryFormatter序列化生成的二进制文件 BinaryFormatter序列化时,对象必须有 可序列化特性[Serializable] 一.新建窗体测试程序BinaryDeepAnalysisDemo,将默认的Form1重命名为FormBinaryDeepAnalysis 二.新建测试类Test Test.cs源程序如下: using System; us…

【实用教程】在 Android Studio 中连接 MuMu 模拟器

MuMu 模拟器是一个非常流行的安卓模拟器&#xff0c;特别适合开发人员进行应用测试&#xff0c;我使用它的根本原因在于Android Studio自带的AVM实现是太难用了&#xff0c;但是Mumu模拟器启动以后不会自动被Android Studio识别到&#xff0c;但是其他模拟器都是能够正常被Andr…

LLAMA-Factory安装教程(解决报错cannot allocate memory in static TLS block的问题)

步骤一&#xff1a; 下载基础镜像 # 配置docker DNS vi /etc/docker/daemon.json # daemon.json文件中 { "insecure-registries": ["https://swr.cn-east-317.qdrgznjszx.com"], "registry-mirrors": ["https://docker.mirrors.ustc.edu.c…

Ollama 部署 DeepSeek-R1 及Open-WebUI

Ollama 部署 DeepSeek-R1 及Open-WebUI 文章目录 Ollama 部署 DeepSeek-R1 及Open-WebUI〇、说明为什么使用本方案 一、 安装Ollama1、主要特点&#xff1a;2、安装3、验证 二、Ollama 部署 DeepSeek1、部署2、模型选用3、Ollama 常用命令4、Ollama模型默认存储路径 安装open-w…

基于微信小程序的医院预约挂号系统的设计与实现

hello hello~ &#xff0c;这里是 code袁~&#x1f496;&#x1f496; &#xff0c;欢迎大家点赞&#x1f973;&#x1f973;关注&#x1f4a5;&#x1f4a5;收藏&#x1f339;&#x1f339;&#x1f339; &#x1f981;作者简介&#xff1a;一名喜欢分享和记录学习的在校大学生…

redis项目

短信登录 这一块我们会使用redis共享session来实现 商户查询缓存 通过本章节&#xff0c;我们会理解缓存击穿&#xff0c;缓存穿透&#xff0c;缓存雪崩等问题&#xff0c;让小伙伴的对于这些概念的理解不仅仅是停留在概念上&#xff0c;更是能在代码中看到对应的内容 优惠…

【嵌入式 Linux 音视频+ AI 实战项目】瑞芯微 Rockchip 系列 RK3588-基于深度学习的人脸门禁+ IPC 智能安防监控系统

前言 本文主要介绍我最近开发的一个个人实战项目&#xff0c;“基于深度学习的人脸门禁 IPC 智能安防监控系统”&#xff0c;全程满帧流畅运行。这个项目我目前全网搜了一圈&#xff0c;还没发现有相关类型的开源项目。这个项目只要稍微改进下&#xff0c;就可以变成市面上目前…

RabbitMQ 从入门到精通:从工作模式到集群部署实战(四)

#作者&#xff1a;闫乾苓 系列前几篇&#xff1a; 《RabbitMQ 从入门到精通&#xff1a;从工作模式到集群部署实战&#xff08;一&#xff09;》&#xff1a;link 《RabbitMQ 从入门到精通&#xff1a;从工作模式到集群部署实战&#xff08;二&#xff09;》&#xff1a; lin…

RabbitMQ 从入门到精通:从工作模式到集群部署实战(五)

#作者&#xff1a;闫乾苓 系列前几篇&#xff1a; 《RabbitMQ 从入门到精通&#xff1a;从工作模式到集群部署实战&#xff08;一&#xff09;》&#xff1a;link 《RabbitMQ 从入门到精通&#xff1a;从工作模式到集群部署实战&#xff08;二&#xff09;》&#xff1a; lin…

mysql 学习11 事务,事务简介,事务操作,事务四大特性,并发事务问题,事务隔离级别

一 事务简介&#xff0c; 数据库准备&#xff1a; create table account(id int auto_increment primary key comment 主键ID,name varchar(128) not null comment 姓名,backaccountnumber char(18) unique comment 银行账号,money float comment 余额 )comment 银行账号表;…

C语言的灵魂——指针(3)

前言&#xff1a;上期我们介绍了const修饰指针&#xff0c;saaert断言都是针对指针本身的&#xff0c;文章后面我们用指针与数组建立了联系&#xff0c;这种联系或者是关系就是这篇文章所要介绍的。上一篇文章的传送门&#xff1a;指针2 指针3 一&#xff0c;数组名的含义及理解…

企业FTP替代升级,实现传输大文件提升100倍!

随着信息技术的飞速发展&#xff0c;网络安全环境也变得越来越复杂。在这种背景下&#xff0c;传统的FTP&#xff08;文件传输协议&#xff09;已经很难满足现代企业对文件传输的需求了。FTP虽然用起来简单&#xff0c;但它的局限性和安全漏洞让它在面对高效、安全的数据交换时…

树和二叉树_7

树和二叉树_7 一、leetcode-102二、题解1.引库2.代码 一、leetcode-102 二叉树的层序遍历 给你二叉树的根节点 root &#xff0c;返回其节点值的 层序遍历 。 &#xff08;即逐层地&#xff0c;从左到右访问所有节点&#xff09;。 样例输入&#xff1a;root [3,9,20,null,nu…

2.8作业

作业 优化登录框&#xff1a; 当用户点击取消按钮&#xff0c;弹出问题对话框&#xff0c;询问是否要确定退出登录&#xff0c;并提供两个按钮&#xff0c;yes|No&#xff0c;如果用户点击的Yes&#xff0c;则关闭对话框&#xff0c;如果用户点击的No&#xff0c;则继续登录 当…

【WB 深度学习实验管理】使用 PyTorch Lightning 实现高效的图像分类实验跟踪

本文使用到的 Jupyter Notebook 可在GitHub仓库002文件夹找到&#xff0c;别忘了给仓库点个小心心~~~ https://github.com/LFF8888/FF-Studio-Resources 在机器学习项目中&#xff0c;实验跟踪和结果可视化是至关重要的环节。无论是调整超参数、优化模型架构&#xff0c;还是监…

人工智能入门 数学基础 线性代数 笔记

必备的数学知识是理解人工智能不可或缺的要素&#xff0c;今天的种种人工智能技术归根到底都建立在数学模型之上&#xff0c;而这些数学模型又都离不开线性代数&#xff08;linear algebra&#xff09;的理论框架。 线性代数的核心意义&#xff1a;世间万事万物都可以被抽象成某…

5 计算机网络

5 计算机网络 5.1 OSI/RM七层模型 5.2 TCP/IP协议簇 5.2.1:常见协议基础 一、 TCP是可靠的&#xff0c;效率低的&#xff1b; 1.HTTP协议端口默认80&#xff0c;HTTPSSL之后成为HTTPS协议默认端口443。 2.对于0~1023一般是默认的公共端口不需要注册&#xff0c;1024以后的则需…

unity碰撞的监测和监听

1.创建一个地面 2.去资源商店下载一个火焰素材 3.把procedural fire导入到自己的项目包管理器中 4.给magic fire 0 挂在碰撞组件Rigidbody , Sphere Collider 5.创建脚本test 并挂在magic fire 0 脚本代码 using System.Collections; using System.Collections.Generic; usi…