Skip to content

Commit

Permalink
more ai work
Browse files Browse the repository at this point in the history
  • Loading branch information
Robosturm committed Nov 1, 2023
1 parent bd3cb61 commit b1ff542
Show file tree
Hide file tree
Showing 14 changed files with 383 additions and 39 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,7 @@ set(${PROJECT_NAME}_SRCS
ai/dummyai.h ai/dummyai.cpp
ai/capturebuildingselector.h ai/capturebuildingselector.cpp
ai/aiprocesspipe.h ai/aiprocesspipe.cpp
ai/trainingdatagenerator.h ai/trainingdatagenerator.cpp
# production system
ai/productionSystem/simpleproductionsystem.h ai/productionSystem/simpleproductionsystem.cpp

Expand Down
85 changes: 71 additions & 14 deletions ai/heavyai/heavyAiSharedData.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,79 @@ using spUnitTargetedPathFindingSystem = std::shared_ptr<UnitTargetedPathFindingS
namespace HeavyAiSharedData
{

static constexpr qint32 UNIT_COUNT = 40;
static constexpr qint32 SEARCH_RANGE = 30;
static constexpr qint32 UNIT_COUNT = 50;
static constexpr qint32 UNIT_SEARCH_RANGE = 40; // search range for the situation evaluator
static constexpr qint32 TILE_SEARCH_RANGE = 20; // search range for selecting on which tile
static constexpr qint32 HQ_IMPORTANCE = 1024;

enum SituationFeatures
{
Distance,
HP,
HpDamage,
FundsDamage,
MovementPoints,
HasMoved,
Defense,
RepairsOnPosition,
CapturePoints,
BuildingImportance,
Stealthed,
MaxFeatures,
SituationFeatures_Distance,
SituationFeatures_HP,
SituationFeatures_HpDamage,
SituationFeatures_FundsDamage,
SituationFeatures_MovementPoints,
SituationFeatures_HasMoved,
SituationFeatures_Defense,
SituationFeatures_RepairsOnPosition,
SituationFeatures_CapturePoints,
SituationFeatures_BuildingImportance,
SituationFeatures_Stealthed,
SituationFeatures_MinFireRange,
SituationFeatures_MaxFireRange,
SituationFeatures_MaxFeatures,
};

enum BuildingFeatures
{
BuildingFeatures_RemainingGroundProductions,
BuildingFeatures_RemainingSeaProductions,
BuildingFeatures_RemainingAirProductions,
BuildingFeatures_BuildUnitsThisTurn,
BuildingFeatures_Funds,
BuildingFeatures_Costs,
BuildingFeatures_AverageDealingDamage,
BuildingFeatures_AverageReceivingDamage,
BuildingFeatures_AverageDistance,
BuildingFeatures_CounteringUnits,
BuildingFeatures_CoBonus,
BuildingFeatures_MovementPoints,
BuildingFeatures_CanMoveAndFire,
BuildingFeatures_MinFireRange,
BuildingFeatures_MaxFireRange,
BuildingFeatures_CurrentDay,
BuildingFeatures_CanCapture,
BuildingFeatures_OwnUnitCount,
BuildingFeatures_EnemyUnitCount,
};

enum UnitSelectFeatures
{
UnitSelectFeatures_MovementPoints,
UnitSelectFeatures_Costs,
UnitSelectFeatures_MinFireRange,
UnitSelectFeatures_MaxFireRange,
UnitSelectFeatures_CanMoveAndFire,
UnitSelectFeatures_EnemiesInRange,
UnitSelectFeatures_AlliesNearby,
};

enum TileSelectFeatures
{
TileSelectFeatures_TerrainDefense,
TileSelectFeatures_CanReach,
TileSelectFeatures_Movecost,
TileSelectFeatures_AverageReceivingDamage,
TileSelectFeatures_EnemiesInRange,
TileSelectFeatures_NeededPowerCharge,
TileSelectFeatures_Building
};

enum SituationOutput
{
SituationOutput_Lost = -1,
SituationOutput_Draw = 0,
SituationOutput_Win = 1,
};

enum AiCache
Expand All @@ -52,6 +107,8 @@ struct UnitInfo
spUnitTargetedPathFindingSystem pUnitTargetedPathFindingSystem;
};

static constexpr qint32 INPUT_VECTOR_SIZE = HeavyAiSharedData::UNIT_COUNT * HeavyAiSharedData::UNIT_COUNT * HeavyAiSharedData::SituationFeatures::SituationFeatures_MaxFeatures;

using spUnitInfo = std::shared_ptr<UnitInfo>;

}
7 changes: 4 additions & 3 deletions ai/heavyai/simulationmap.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ class SimulationMap final : public GameMap
void loadStartState(GameMap* pMap);
void restoreInitialState();

void moveUnit(QPoint currentPos, QPoint newPos);
void dealDamage(QPoint position, float damage);
void increaseCapturePoints(QPoint position, qint32 points);
Q_INVOKABLE void moveUnit(QPoint currentPos, QPoint newPos);
Q_INVOKABLE void dealDamage(QPoint position, float damage);
Q_INVOKABLE void increaseCapturePoints(QPoint position, qint32 points);
private:
void resetMoveUnit(SimulationStep & step);
void resetDealDamage(SimulationStep & step);
Expand All @@ -35,3 +35,4 @@ class SimulationMap final : public GameMap
std::vector<SimulationStep> m_steps;
};

Q_DECLARE_INTERFACE(SimulationMap, "SimulationMap");
54 changes: 48 additions & 6 deletions ai/heavyai/situationevaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
#include "game/unit.h"

SituationEvaluator::SituationEvaluator(Player* pOwner)
: m_inputVector(1, HeavyAiSharedData::UNIT_COUNT * HeavyAiSharedData::UNIT_COUNT * HeavyAiSharedData::SituationFeatures::MaxFeatures),
m_searchRange(GlobalUtils::getSpCircle(0, HeavyAiSharedData::SEARCH_RANGE)),
: m_inputVector(1, HeavyAiSharedData::INPUT_VECTOR_SIZE),
m_searchRange(GlobalUtils::getSpCircle(0, HeavyAiSharedData::UNIT_SEARCH_RANGE)),
m_pOwner(pOwner)
{
for (qint32 i = 0; i < HeavyAiSharedData::UNIT_COUNT; ++i)
Expand Down Expand Up @@ -63,6 +63,16 @@ void SituationEvaluator::updateInputVector(GameMap* pMap, const QPoint & searchP
}
}

QString SituationEvaluator::getInputVector() const
{
QString input;
for (qint32 i = 0; i < HeavyAiSharedData::INPUT_VECTOR_SIZE; ++i)
{
input += QString::number(m_inputVector(0, i)) + ";";
}
return input;
}

float SituationEvaluator::getOutput()
{
auto inputDimension = opennn::get_dimensions(m_inputVector);
Expand All @@ -72,7 +82,7 @@ float SituationEvaluator::getOutput()

void SituationEvaluator::clearUnitInput(qint32 index)
{
for (qint32 feature = 0; feature < HeavyAiSharedData::SituationFeatures::MaxFeatures; ++feature)
for (qint32 feature = 0; feature < HeavyAiSharedData::SituationFeatures::SituationFeatures_MaxFeatures; ++feature)
{
qint32 basePosition = HeavyAiSharedData::UNIT_COUNT * HeavyAiSharedData::UNIT_COUNT * feature + index * HeavyAiSharedData::UNIT_COUNT;
for (qint32 enemyUnit = 0; enemyUnit < HeavyAiSharedData::UNIT_COUNT; ++enemyUnit)
Expand All @@ -84,11 +94,11 @@ void SituationEvaluator::clearUnitInput(qint32 index)

void SituationEvaluator::fillUnitInput(qint32 index)
{
for (qint32 feature = 0; feature < HeavyAiSharedData::SituationFeatures::MaxFeatures; ++feature)
for (qint32 feature = 0; feature < HeavyAiSharedData::SituationFeatures::SituationFeatures_MaxFeatures; ++feature)
{
qint32 basePosition = HeavyAiSharedData::UNIT_COUNT * HeavyAiSharedData::UNIT_COUNT * feature + index * HeavyAiSharedData::UNIT_COUNT;
using updateFeature = void (SituationEvaluator::*)(qint32 basePosition, const HeavyAiSharedData::spUnitInfo & unitInfo);
constexpr std::array<updateFeature, HeavyAiSharedData::SituationFeatures::MaxFeatures> featureCb{
constexpr std::array<updateFeature, HeavyAiSharedData::SituationFeatures::SituationFeatures_MaxFeatures> featureCb{
&SituationEvaluator::updateDistance,
&SituationEvaluator::updateHp,
&SituationEvaluator::updateHpDamage,
Expand All @@ -100,6 +110,8 @@ void SituationEvaluator::fillUnitInput(qint32 index)
&SituationEvaluator::updateCapturePoints,
&SituationEvaluator::updateBuildingImportance,
&SituationEvaluator::updateStealthed,
&SituationEvaluator::updateMinFirerange,
&SituationEvaluator::updateMaxFirerange,
};
(this->*featureCb[feature])(basePosition, m_unitsInfo[index]);
}
Expand Down Expand Up @@ -155,7 +167,7 @@ void SituationEvaluator::updateHpDamage(qint32 basePosition, const HeavyAiShared
void SituationEvaluator::updateFundsDamage(qint32 basePosition, const HeavyAiSharedData::spUnitInfo & unitInfo)
{

qint32 hpOffset = HeavyAiSharedData::UNIT_COUNT * HeavyAiSharedData::UNIT_COUNT * (HeavyAiSharedData::SituationFeatures::HpDamage - HeavyAiSharedData::SituationFeatures::HP);
qint32 hpOffset = HeavyAiSharedData::UNIT_COUNT * HeavyAiSharedData::UNIT_COUNT * (HeavyAiSharedData::SituationFeatures::SituationFeatures_HpDamage - HeavyAiSharedData::SituationFeatures::SituationFeatures_HP);
for (qint32 enemyUnit = 0; enemyUnit < HeavyAiSharedData::UNIT_COUNT; ++enemyUnit)
{
if (shouldFillInfo(unitInfo, enemyUnit))
Expand Down Expand Up @@ -247,6 +259,36 @@ void SituationEvaluator::updateRepairsOnPosition(qint32 basePosition, const Heav
}
}

void SituationEvaluator::updateMinFirerange(qint32 basePosition, const HeavyAiSharedData::spUnitInfo & unitInfo)
{
for (qint32 enemyUnit = 0; enemyUnit < HeavyAiSharedData::UNIT_COUNT; ++enemyUnit)
{
if (shouldFillInfo(unitInfo, enemyUnit))
{
m_inputVector(0, basePosition + enemyUnit) = m_unitsInfo[enemyUnit]->pUnit->getAiCache()[HeavyAiSharedData::AiCache::MinFirerange];
}
else
{
m_inputVector(0, basePosition + enemyUnit) = 0;
}
}
}

void SituationEvaluator::updateMaxFirerange(qint32 basePosition, const HeavyAiSharedData::spUnitInfo & unitInfo)
{
for (qint32 enemyUnit = 0; enemyUnit < HeavyAiSharedData::UNIT_COUNT; ++enemyUnit)
{
if (shouldFillInfo(unitInfo, enemyUnit))
{
m_inputVector(0, basePosition + enemyUnit) = m_unitsInfo[enemyUnit]->pUnit->getAiCache()[HeavyAiSharedData::AiCache::MaxFirerange];
}
else
{
m_inputVector(0, basePosition + enemyUnit) = 0;
}
}
}

void SituationEvaluator::updateCapturePoints(qint32 basePosition, const HeavyAiSharedData::spUnitInfo & unitInfo)
{
for (qint32 enemyUnit = 0; enemyUnit < HeavyAiSharedData::UNIT_COUNT; ++enemyUnit)
Expand Down
8 changes: 8 additions & 0 deletions ai/heavyai/situationevaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ class SituationEvaluator : public QObject
* @return
*/
float getOutput();
/**
* @brief getInputVector
* @return
*/
QString getInputVector() const;
private:
void getUnitsInRange(GameMap* pMap, const QPoint & searchPoint);
void createPathFindingSystems(GameMap* pMap);
Expand All @@ -56,6 +61,8 @@ class SituationEvaluator : public QObject
void updateStealthed(qint32 basePosition, const HeavyAiSharedData::spUnitInfo & unitInfo);
void updateBuildingImportance(qint32 unitPosition);
void updateStealthInfo(GameMap* pMap, qint32 unitPosition);
void updateMinFirerange(qint32 basePosition, const HeavyAiSharedData::spUnitInfo & unitInfo);
void updateMaxFirerange(qint32 basePosition, const HeavyAiSharedData::spUnitInfo & unitInfo);
private:
opennn::Tensor<opennn::type, 2> m_inputVector;
opennn::NeuralNetwork m_neuralNetwork;
Expand All @@ -64,3 +71,4 @@ class SituationEvaluator : public QObject
Player* m_pOwner{nullptr};
};

Q_DECLARE_INTERFACE(SituationEvaluator, "SituationEvaluator");
127 changes: 127 additions & 0 deletions ai/trainingdatagenerator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#include "ai/trainingdatagenerator.h"
#include "ai/coreai.h"
#include "ai/heavyai/situationevaluator.h"
#include "coreengine/memorymanagement.h"
#include "coreengine/settings.h"
#include "game/gamemap.h"
#include "game/gameaction.h"

TrainingDataGenerator::TrainingDataGenerator(GameMap* pMap)
: m_pMap(pMap)
{
#ifdef GRAPHICSUPPORT
setObjectName("TrainingDataGenerator");
#endif
Interpreter::setCppOwnerShip(this);
init();
}

void TrainingDataGenerator::init()
{
m_evaluators.clear();
m_data.clear();
if (Settings::getInstance()->getCreateAiTrainingData())
{
for (qint32 i = 0; i < m_pMap->getPlayerCount(); ++i)
{
m_evaluators.push_back(MemoryManagement::create<SituationEvaluator>(m_pMap->getPlayer(i)));
m_data.push_back(QStringList());
}
}
}

void TrainingDataGenerator::onActionDone(GameAction* pAction)
{
if (Settings::getInstance()->getCreateAiTrainingData())
{
if (pAction->getActionID() == CoreAI::ACTION_BUILD_UNITS)
{

}
else if (pAction->getMovePathLength() > 0)
{
qint32 player = m_pMap->getCurrentPlayer()->getPlayerID();
m_evaluators[player]->updateInputVector(m_pMap, pAction->getActionTarget(), true);
m_data[player].append(m_evaluators[player]->getInputVector());
}
}
}

void TrainingDataGenerator::saveDataToFile()
{
saveDataToFile(m_pMap->getMapName() + QDateTime::currentDateTime().toString("dd-MM-yyyy-hh-mm-ss") + ".csv");
}

void TrainingDataGenerator::saveDataToFile(const QString & filepath)
{
if (Settings::getInstance()->getCreateAiTrainingData())
{
qint32 winner = m_pMap->getWinnerTeam();
if (winner >= 0 || m_pMap->getGameRules()->getDrawVotingResult() == GameEnums::DrawVoting_Yes)
{
QFile file("situationResults_" + filepath);
file.open(QFile::OpenModeFlag::WriteOnly);
QTextStream stream(&file);
for (qint32 i = 0; i < m_data.size(); ++i)
{
QString output;
if (winner < 0)
{
output = "0";
}
else if (winner == m_pMap->getPlayer(i)->getTeam())
{
output = "1";
}
else
{
output = "-1";
}
for (qint32 i2 = 0; i2 < m_data.size(); ++i2)
{
stream << m_data[i][i2] << ";" << output;
}
}
}
}
}

void TrainingDataGenerator::serializeObject(QDataStream& pStream) const
{
serializeObject(pStream, false);
}

void TrainingDataGenerator::serializeObject(QDataStream& pStream, bool forHash) const
{
pStream << getVersion();
pStream << static_cast<qint32>(m_data.size());
for (auto & data : m_data)
{
pStream << static_cast<qint32>(data.size());
for (auto & item : data)
{
pStream << item;
}
}
}

void TrainingDataGenerator::deserializeObject(QDataStream& pStream)
{
init();
qint32 version = 0;
pStream >> version;
qint32 size = 0;
pStream >> size;
for (qint32 i = 0; i < size; ++i)
{
m_data.push_back(QStringList());
qint32 size2 = 0;
pStream >> size2;
for (qint32 i2 = 0; i2 < size2; ++i2)
{
QString item;
pStream >> item;
m_data[i].push_back(item);
}
}
}
Loading

0 comments on commit b1ff542

Please sign in to comment.