资讯

精准传达 • 有效沟通

从品牌网站建设到网络营销策划,从策略到执行的一站式服务

[c++]一个简单的NEAT机器学习寻路实验-创新互联

0x01 NEAT

NEAT(增强拓扑的进化神经网络)是一种基于遗传算法和神经网络的机器学习算法,不同于全连接神经网络,NEAT的神经网络是可以跨层相连的;它由最初始的输入层和输出层神经元连接,迭代繁衍和进化达到最终形态。
具体的介绍可以参考莫凡老师的教程。
最近因为搜索raylib的项目,发现了一个 简单的NEAT库simpleNEAT,觉得挺有趣,就拿来做个实验。

创新互联服务项目包括金塔网站建设、金塔网站制作、金塔网页制作以及金塔网络营销策划等。多年来,我们专注于互联网行业,利用自身积累的技术优势、行业经验、深度合作伙伴关系等,向广大中小型企业、政府机构等提供互联网行业的解决方案,金塔网站推广取得了明显的社会效益与经济效益。目前,我们服务的客户以成都为中心已经辐射到金塔省份的部分城市,未来相信会继续扩大服务区域并继续获得客户的支持与信任!0x02 效果

0x03 代码

simpleNEAT中的lib目录放到项目文件夹,新建main.cpp,写入以下代码,根据需要修改参数:

#include#include#include#include#include "raylib.h"
#include "lib/SimpleNEAT.hpp"

int screenWidth = 1600;  // 设置窗口宽度
int screenHeight = 900;  // 设置窗口高度
float initRotation = 0.f;  // 个体的初始化旋转角度
float objectSize = 15.f;  // 个体的半径
float sensorMax = 300.f;  // 距离传感器长度
float stepSize = 10.f;  // 移动和旋转的步长
Vector2 initPosistion = {100.f, 100.f};  // 个体出生位置,初始在窗口从左到右从上往下 100,100
Vector2 targetPosition = {float(screenWidth) - 100.f, float(screenHeight) - 100.f};  // 目标位置,初始在窗口从右到左从下往上 100,100
bool isStart = false;  // 是否开始训练
int frameLimit = 1500;  // 每一代帧数限制
const int sensorCount = 11;  // 具体传感器数量
int fps = 60;  // 训练时的fps限制
int winnerCount = 0;  // 到达目标的个体计数器

znn::SimpleNeat initNeat() {  // 初始化神经网络和种群
    znn::Opts.InputSize = sensorCount + 2;  // 设置神经网络输入节点数量=传感器数量+个体与目标的相对距离+个体朝向与目标的相对角度
    znn::Opts.OutputSize = 2;  // 设置神经网络输出节点数量
    znn::Opts.ActiveFunction = znn::Sigmoid;  // 使用的激活函数
    znn::Opts.IterationTimes = 0;  // 迭代次数,0为不限制
    znn::Opts.FitnessThreshold = 0.f;  // 个体适应值阈值,0为不限制
    znn::Opts.IterationCheckPoint = 1;  // 保存最优神经网络的迭代次数,1为每代保存
    znn::Opts.ThreadCount = 16;  // 多线程数量,不设置则为设备默认数量
    znn::Opts.MutateAddNeuronRate = 0.03f;  // 添加新神经元的概率
    znn::Opts.MutateAddConnectionRate = 0.99f;  // 添加新连接的概率
    znn::Opts.PopulationSize = 150;  // 训练个体的数量
    znn::Opts.NewSize = 0;  // 每一代新生个体的数量
    znn::Opts.ChampionToNewSize = 90;  // 冠军被复制和交配的目标数量
    znn::Opts.ChampionKeepSize = 15;  // 冠军的数量
    znn::Opts.KeepWorstSize = 0;  // 保留最差个体的数量
    znn::Opts.KeepComplexSize = 1;  // 保留最复杂神经网络个体的数量,用于交配产生更复杂的神经网络
    znn::Opts.WeightRange = 12;  // 神经连接的权重范围,-12至12
    znn::Opts.BiasRange = 6;  // 神经元的偏置范围,-6至6
    znn::Opts.MutateBiasRate = 1.f;  // 神经元的偏置变异概率
    znn::Opts.MutateWeightRate = 1.f;  // 神经元连接的权重变异概率
    znn::Opts.MutateBiasDirectOrNear = 0.5f;  // 神经元偏置随机变异和就近变异的比例
    znn::Opts.MutateWeightDirectOrNear = 0.5f;  // 神经元连接权重随机变异和就近变异的比例
    znn::Opts.Enable3dNN = false;  // 是否显示3d实时可视化神经网络,不能启用,因为用的raylib库,训练环境也用的raylib库,不能开启多窗口
    znn::Opts.CheckPointPath = "/tmp/raylib_path_findder";  // 自动保存神经网络和NEAT创新ID的路径

    srandom((unsigned) clock());  // 初始化随机种子

    znn::SimpleNeat sneat;  // 创建NEAT对象
    sneat.Start();  // 初始化NEAT神经网络和种群,如果自动保存路径存在,则导入,不存在则新建

    return sneat;
}

Vector2 getXY(float angle, float distance) {  // 通过角度和距离计算坐标
    Vector2 result;
    float radians = angle * PI / 180.f;
    result.x = distance * std::cos(radians);
    result.y = distance * std::sin(radians);
    return result;
}

struct myWall {  // 障碍物
    std::vectorpath;  // 存储坐标的容器

    void add() {  // 添加坐标
        Vector2 mousePos = GetMousePosition();
        if (path.empty() || (!path.empty() && path[path.size() - 1].x != mousePos.x && path[path.size() - 1].y != mousePos.y)) {  // 判断是否和上一个坐标重复
            path.push_back(mousePos);
        }
    }

    void draw() {  // 绘制障碍
        if (!path.empty()) {
            for (int i = 1; i< path.size(); ++i) {
                DrawLineEx(path[i - 1], path[i], 1.f, WHITE);
            }
        }
    }
};

std::vectorwalls;  // 存储多个障碍

myWall createScreenWall() {  // 创建窗口四周的障碍
    myWall screenWall;
    screenWall.path.push_back({0, 0});
    screenWall.path.push_back({0, float(screenHeight)});
    screenWall.path.push_back({float(screenWidth), float(screenHeight)});
    screenWall.path.push_back({float(screenWidth), 0});
    screenWall.path.push_back({0, 0});
    return screenWall;
}

bool getCollion(Vector2 center, Vector2 sensorTail, Vector2 &collisionPoint, float &sensorDistance) {  // 根据两条线的起止坐标判断是否相交
    std::mapdis2Pos;
    bool isCollision = false;

    for (auto &w: walls) {
        for (int i = 1; i< w.path.size(); ++i) {
            Vector2 collisionPos;
            if (CheckCollisionLines(center, sensorTail, w.path[i - 1], w.path[i], &collisionPos)) {
                float distance = std::sqrt(std::pow(collisionPos.x - center.x, 2.f) + std::pow(collisionPos.y - center.y, 2.f));
                dis2Pos[distance] = collisionPos;
                isCollision = true;
            }
        }
    }

    collisionPoint = dis2Pos.begin()->second;
    sensorDistance = dis2Pos.begin()->first;
    return isCollision;
}

struct object {  // 训练个体
    float rotation = initRotation;  // 初始旋转角度
    Vector2 position = initPosistion;  // 出生位置
    bool isDead = false;  // 是否死亡
    float speed = 0.f;  // 速度
    std::vectorpath;  // 走过的路径
    std::vectorpathWidth;  // 走过路径对应的路宽,根据速度判断
    Vector2 sensorsPos[sensorCount]{};  // 距离传感器的相对末端坐标
    Vector2 sensorCol[sensorCount]{};  // 距离传感器的相对探测到障碍的交叉坐标
    float sensorDis[sensorCount];  // 距离传感器到障碍的长度
    float score = 0.f;  // 记录得分
    float targetAngle = std::atan2((targetPosition.y - position.y), (targetPosition.x - position.x)) / PI * 180.f - float(int(rotation + 180.f) * 10 % 3600 - 1800) / 10.f;  // 个体朝向和目标的相对角度
    float targetDistance = std::sqrt(std::pow(position.x - targetPosition.x, 2.f) + std::pow(position.y - targetPosition.y, 2.f));  // 个体和目标的距离
    float beginDistance = 0.f;  // 个体和出生位置的距离

    void setSensors() {  // 放置具体传感器
        for (int i = 0; i< sensorCount; ++i) {
            Vector2 sensorTail = getXY(float(i) * 30.f + rotation - 150.f, sensorMax);
            sensorsPos[i].x = sensorTail.x + position.x;
            sensorsPos[i].y = sensorTail.y + position.y;
        }
    }

    void getSensorsInfo() {  // 更新传感器数据
        for (int i = 0; i< sensorCount; ++i) {
            if (!getCollion(position, sensorsPos[i], sensorCol[i], sensorDis[i])) {
                sensorCol[i] = sensorsPos[i];  // 传感器与障碍物交叉的位置,没有交叉则为传感器目标位置
                sensorDis[i] = sensorMax;  // 传感器与障碍之间的距离,没有交叉则为预设大值
            }

            if (sensorDis[i]< objectSize) {
                isDead = true;  // 如果传感器道障碍的距离小于个体半径,则判断为死亡
            }
        }

        if (std::abs(position.x - targetPosition.x)< 10.f && std::abs(position.y - targetPosition.y)< 10.f) {  // 判断个体是否到达目标坐标
            isDead = true;  // 到达坐标则死亡
            ++winnerCount;  // 达到目标的个体计数器更新
            score += 10000.f;  // 到达目标加分
        }

        if (path.size() >2) {  // 判断个体是否走了老路,通过路径记录和碰撞判断
            for (int i = 1; i< path.size() - 2; ++i) {
                if (CheckCollisionLines(path[path.size() - 1], path[path.size() - 2], path[i], path[i - 1], nullptr)) {
                    isDead = true;  // 如果和自己的运动轨迹碰撞则死亡
                    break;
                }
            }
        }

        targetDistance = std::sqrt(std::pow(position.x - targetPosition.x, 2.f) + std::pow(position.y - targetPosition.y, 2.f));  // 更新个体和目标的距离
        if (targetDistance< 1.f) {  // 为便于分数判定,需要将目标距离作为被除数
            targetDistance = 1.f;  // 如果距离小于1则为1,避免分数特别大
        }
        targetAngle = std::atan2((targetPosition.y - position.y), (targetPosition.x - position.x)) / PI * 180.f - float(int(rotation + 180.f) * 10 % 3600 - 1800) / 10.f;  // 更新个体朝向与目标的相对角度
        if (std::abs(targetAngle)< 10.f) {  // 如果相对角度小于10,则加分
            score += 1.f;
        }
    }

    object() {  // 创建个体时的初始化操作
        setSensors();  // 更新传感器位置
        getSensorsInfo();  // 更新传感器数据
    }

    void rotate(float angle) {  // 个体旋转操作
        rotation = float(int((rotation + angle * stepSize) * 10.f) % 3600) / 10.f;
        setSensors();
        getSensorsInfo();
    }

    void move(float distance) {  // 个体移动操作
        speed = distance * 30.f;  // 更新速度,用于可视化尾喷长度
        score += speed;  // 更新分数,叠加速度
        Vector2 movePos = getXY(rotation, distance * stepSize);  // 获取需要移动的相对坐标
        position.x += movePos.x;  // 更新个体坐标x
        position.y += movePos.y;  // 更新个体坐标y
        if (distance >0.01f) {  // 如果个体移动距离太小,则判断为死亡
            path.push_back(position);
            if (distance >0.3f) {  // 为防止可视化路径的时候宽度太小,则设置最小宽度0.3
                pathWidth.push_back(distance);
            } else {
                pathWidth.push_back(0.3f);
            }
        } else {
            isDead = true;
        }
        setSensors();  // 更新传感器位置
        getSensorsInfo();  // 更新传感器数据
    }

    void draw() {  // 绘制个体
        if (path.size() >1) {  // 绘制个体移动路径,线条需要两个坐标
            for (int i = 1; i< path.size(); ++i) {
                DrawLineEx(path[i - 1], path[i], 1., ColorAlpha(GREEN, pathWidth[i] * 0.3f));  // 路径宽度改为路径透明度由宽度判定
            }
        }

        if (!isDead) {  // 如果个体存活则绘制传感器
            for (auto sc: sensorCol) {
                DrawLineEx(position, sc, 1, ColorAlpha(BLUE, 0.5f));
                DrawCircleV(sc, objectSize / 10.f * 3.f, ColorAlpha(RED, 0.5f));
            }
        }

        auto objColor = WHITE;  // 如果个体存活,则本体为白色,死亡为红色
        if (isDead) {
            objColor = RED;
        }

        DrawPolyLinesEx(position, 3, objectSize, rotation + 30.f, objectSize / 5.f, objColor);  // 绘制个体本体,三角形
        Vector2 headPos = getXY(rotation, objectSize);  // 获取头部坐标用于给个体头部画一根线
        DrawLineEx(position, {headPos.x + position.x, headPos.y + position.y}, 1, objColor);  // 给个体头部画一根线分辨方向

        if (!isDead) {  // 如果存活则绘制尾喷
            Vector2 tailPos = getXY(rotation + 180.f, speed);  // 获取尾喷相对位置用于绘制,长度由速度决定
            DrawLineEx(position, {tailPos.x + position.x, tailPos.y + position.y}, objectSize / 10.f * 3.f, YELLOW);
        }
    }
};

void keyControl() {  // 用户输入控制
    if (IsMouseButtonDown(0)) {  // 鼠标左键绘制障碍
        walls[walls.size() - 1].add();
    }

    if (IsMouseButtonReleased(0)) {  // 鼠标左键抬起用于添加新的障碍物列表,避免只绘制一条线
        walls.push_back(myWall{});
    }

    if (IsMouseButtonDown(1)) {  // 鼠标右键清除障碍
        walls.clear();
        walls.push_back(createScreenWall());  // 清除障碍以后先添加窗口四周的障碍
        walls.push_back(myWall{});
    }

    if (IsKeyPressed('B')) {  // B键用于设置个体出生位置
        initPosistion = GetMousePosition();
        initRotation = std::atan2(targetPosition.y - initPosistion.y, targetPosition.x - initPosistion.x) / PI * 180.f;  // 设置完出生位置后更新个体初始朝向
    }

    if (IsKeyPressed('T')) {  // T键用于设置目标位置
        targetPosition = GetMousePosition();
        initRotation = std::atan2(targetPosition.y - initPosistion.y, targetPosition.x - initPosistion.x) / PI * 180.f;  // 设置完目标位置后更新个体初始朝向
    }

    if (IsKeyPressed(KEY_SPACE)) {  // 空格键用于控制是否开始训练
        if (isStart) {
            isStart = false;
            SetTargetFPS(30);  // 没开始训练时帧率限制为30
        } else {
            isStart = true;
            SetTargetFPS(fps);
        }
    }
}

bool isBreakFunc() {  // 用于NEAT训练循环中判断是否中断
    return !isStart;
};

int main() {
    SetConfigFlags(FLAG_MSAA_4X_HINT);  // 设置抗锯齿

    InitWindow(screenWidth, screenHeight, "寻路实验");  // 初始化raylib窗口

    SetTargetFPS(30);  // 设置帧率

    walls.push_back(createScreenWall());  // 创建窗口四周障碍
    walls.push_back(myWall{});

    initRotation = std::atan2(targetPosition.y - initPosistion.y, targetPosition.x - initPosistion.x) / PI * 180.f;  // 初始化个体朝向

    auto sneat = initNeat();  // 初始化神经网络和种群

    int stepCount = 0;  // 训练迭代计数器
    std::function()>interactiveFunc = [&]() {
        ++stepCount;

        std::vectorobjs;  // 新建个体集容器
        for (int i = 0; i< znn::Opts.PopulationSize; ++i) {  // 塞满个体
            objs.emplace_back();
        }

        std::mappopulationFitness;  // 创建神经网络地址对应的适应度map

        for (int step = 0; step< frameLimit; ++step) {  // 每一代训练,基于帧数限制的循环
            keyControl();  // 用户输入控制

            if (stepCount % znn::Opts.IterationCheckPoint == 0) {  // 如果达到自动保存次数,则可视化显示
                BeginDrawing();  // 开始绘制
                ClearBackground(BLACK);  // 清空背景

                DrawCircleV(initPosistion, 10.f, GRAY);  // 绘制出生点
                DrawCircleV(targetPosition, 10.f, RED);  // 绘制目标点

                for (auto &w: walls) {  // 绘制障碍
                    w.draw();
                }
            }

            int deadCount = 0;  // 死亡个体计数器

            for (int i = 0; i< znn::Opts.PopulationSize; ++i) {  // 每个训练个体的神经网络判断输入和输出
                if (!objs[i].isDead) {  // 如果个体存活则继续
                    if ((step >100 && objs[i].path.size()< 50) || (objs[i].path.size() >100 && std::abs(objs[i].path[objs[i].path.size() - 1].y - objs[i].path[objs[i].path.size() - 100].y)< 3 &&
                                                                     std::abs(objs[i].path[objs[i].path.size() - 1].x - objs[i].path[objs[i].path.size() - 100].x)< 3) || objs[i].position.x< 0 ||
                        objs[i].position.x >float(screenWidth) || objs[i].position.y< 0 || objs[i].position.y >float(screenHeight)) {  // 简单判断死亡
                        objs[i].isDead = true;
                    } else {
                        std::vectorperInputs;  // 准备神经网络输入数据

                        for (auto &sd: objs[i].sensorDis) {  // 输入数据放入传感器到障碍物的距离
                            perInputs.push_back(1.f - ((sd-objectSize) / (sensorMax-objectSize)));  // 距离除以传感器大值,离得越近数值越大,同时排除个体自身尺寸,使得输入值在0-1范围
                        }
                        perInputs.push_back(objs[i].targetAngle);  // 输入数据放入个体朝向到目标的相对角度
                        perInputs.push_back(objs[i].targetDistance);  // 输入数据放入个体和目标的距离

                        std::vectornextMove = sneat.population.generation.neuralNetwork.FeedForwardPredict(&sneat.population.NeuralNetworks[i], perInputs);  // 根据输入数据计算每个神经网络的输出
                        objs[i].rotate((nextMove[0] - 0.5f) * 2.f);  // 执行输出结果的旋转操作
                        objs[i].move(nextMove[1]);  // 执行输出结果的移动操作

                        if (perInputs[(sensorCount - 1) / 2] >0.f && nextMove[1]< 1.f) {  // 简单判断个体前方有障碍则减速的加分
                            objs[i].score += 1;
                        }
                    }
                } else {  // 如果个体死亡则更新死亡计数器
                    ++deadCount;
                }

                if (stepCount % znn::Opts.IterationCheckPoint == 0) {  // 如果达到自动保存次数,则可视化显示
                    objs[i].draw();  // 绘制个体
                }
            }

            if (stepCount % znn::Opts.IterationCheckPoint == 0) {  // 如果达到自动保存次数,则可视化显示
                DrawFPS(10, 10);  // 绘制fps
                EndDrawing();  // 单帧绘制完毕
            }

            if (deadCount == znn::Opts.PopulationSize) {  // 如果全部个体死亡则终止本代
                break;
            }
        }

        for (int i = 0; i< znn::Opts.PopulationSize; ++i) {  // 更新每个个体的适应度(得分)
            objs[i].beginDistance = std::sqrt(std::pow(objs[i].position.x - initPosistion.x, 2.f) + std::pow(objs[i].position.y - initPosistion.y, 2.f));
            populationFitness[&sneat.population.NeuralNetworks[i]] = (objs[i].beginDistance + objs[i].score * 5.f) / objs[i].targetDistance;
        }

        std::cout<< "Winner: "<< winnerCount<< "\n";
        winnerCount = 0;  // 重置达到目标的个体计数

        return populationFitness;  // NEAT训练函数的格式
    };

    while (!WindowShouldClose()) {  // 判断窗口是否关闭
        while (!isStart) {  // 如果没开始训练,则不更新和绘制个体
            keyControl();

            BeginDrawing();

            ClearBackground(BLACK);

            DrawCircleV(initPosistion, 10.f, GRAY);
            DrawCircleV(targetPosition, 10.f, RED);

            for (auto &w: walls) {
                w.draw();
            }

            DrawFPS(10, 10);
            EndDrawing();
        }

        stepCount = 0;  // 重置训练迭代计数器

        auto best = sneat.TrainByInteractive(interactiveFunc, isBreakFunc);  // NEAT训练函数,开始训练

        printf("Pause\n");  // 如果训练循环终止,则重新开始训练
    }

    CloseWindow();   // 关闭窗口

    return 0;
}0x04 编译
g++ -lraylib -std=c++17 -O2 main.cpp
0x05 运行

执行编译后生成的a.out:

./a.out
0x05 用法
  1. 鼠标左键 绘制障碍
  2. 鼠标右键 清除障碍
  3. B 键设置个体出生位置
  4. T 键设置目标位置
  5. 空格键 开始/暂停训练

训练开始以后可以实时用添加障碍

你是否还在寻找稳定的海外服务器提供商?创新互联www.cdcxhl.cn海外机房具备T级流量清洗系统配攻击溯源,准确流量调度确保服务器高可用性,企业级服务器适合批量采购,新人活动首月15元起,快前往官网查看详情吧


网页名称:[c++]一个简单的NEAT机器学习寻路实验-创新互联
网页链接:http://cdkjz.cn/article/dcssch.html
多年建站经验

多一份参考,总有益处

联系快上网,免费获得专属《策划方案》及报价

咨询相关问题或预约面谈,可以通过以下方式与我们联系

大客户专线   成都:13518219792   座机:028-86922220