Why does my AI model train but not evolve? - ML Agents
I created a simple Unity game in which a ball should hit targets without hitting the walls. So I started training, and the results were terrible: the ball only collects one of the 4 targets, even though EndEpisode() only fires when the last target is collected.

Screenshot of the scene and the ball's path throughout the training of 1,650,000 steps (if I'm not wrong, since I counted a generation for every 10,000 training steps).

The ball doesn't even try to hit the second target. What is wrong with my code?

I even tried replacing the sphere with a cylinder so that it wouldn't roll over and interfere with the RayPerceptionSensor3D, but that gave even worse results.
using UnityEngine;
using MLAgents;
using MLAgents.Sensors;
using TMPro;

public class MazeRoller : Agent
{
    Rigidbody rBody;
    Vector3 ballpos;

    public TextMeshPro text;    // "generation" counter display
    public TextMeshPro miss;    // wall-hit counter display
    public TextMeshPro hit;     // final-target counter display

    int count = 0, c = 0, h = 0, m = 0;
    int boxescollect = 0;       // sub-targets collected this episode

    public Transform Target;    // final target; reaching it ends the episode
    public Transform st1;       // the three sub-targets
    public Transform st2;
    public Transform st3;

    void Start()
    {
        rBody = GetComponent<Rigidbody>();
        ballpos = rBody.transform.position;   // remember the spawn position
    }

    public override void OnEpisodeBegin()
    {
        rBody.angularVelocity = Vector3.zero;
        rBody.velocity = Vector3.zero;
        rBody.transform.position = ballpos;
        boxescollect = 0;

        // re-enable all three sub-targets
        st1.GetComponent<Renderer>().enabled = true;
        st1.GetComponent<Collider>().enabled = true;
        st2.GetComponent<Renderer>().enabled = true;
        st2.GetComponent<Collider>().enabled = true;
        st3.GetComponent<Renderer>().enabled = true;
        st3.GetComponent<Collider>().enabled = true;
    }

    void OnCollisionEnter(Collision collision)
    {
        if (collision.gameObject.name == "Target")
        {
            if (st1.GetComponent<Renderer>().enabled || st2.GetComponent<Renderer>().enabled || st3.GetComponent<Renderer>().enabled)
            {
                // penalty for finishing with sub-targets still left...
                SetReward(-3.0f + (float)boxescollect);
            }
            SetReward(2.0f);   // ...but this SetReward immediately replaces it
            h++;
            hit.SetText(h + "");
            EndEpisode();
        }
        else if (collision.gameObject.name == "Target1")
        {
            boxescollect++;
            AddReward(0.2f);
            st1.GetComponent<Renderer>().enabled = false;
            st1.GetComponent<Collider>().enabled = false;
        }
        else if (collision.gameObject.name == "Target2")
        {
            boxescollect++;
            AddReward(0.4f);
            st2.GetComponent<Renderer>().enabled = false;
            st2.GetComponent<Collider>().enabled = false;
        }
        else if (collision.gameObject.name == "Target3")
        {
            boxescollect++;
            AddReward(0.6f);
            st3.GetComponent<Renderer>().enabled = false;
            st3.GetComponent<Collider>().enabled = false;
        }
        else if (collision.gameObject.tag == "wall")
        {
            if (st1.GetComponent<Renderer>().enabled || st2.GetComponent<Renderer>().enabled || st3.GetComponent<Renderer>().enabled)
            {
                AddReward(-3.0f + (float)boxescollect);
            }
            SetReward(-1.0f);   // SetReward discards the AddReward above
            m++;
            miss.SetText(m + "");
            EndEpisode();
        }
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // target and agent positions
        sensor.AddObservation(Target.position);
        sensor.AddObservation(this.transform.position);
        sensor.AddObservation(boxescollect);
        sensor.AddObservation(boxescollect - 3);   // sub-targets remaining, as a negative count
        sensor.AddObservation(st1.position);
        sensor.AddObservation(st2.position);
        sensor.AddObservation(st3.position);

        // distances from the agent to the final target and to each sub-target
        sensor.AddObservation(Vector3.Distance(Target.position, this.transform.position));
        sensor.AddObservation(Vector3.Distance(st1.position, this.transform.position));
        sensor.AddObservation(Vector3.Distance(st2.position, this.transform.position));
        sensor.AddObservation(Vector3.Distance(st3.position, this.transform.position));

        // agent velocity
        sensor.AddObservation(rBody.velocity.x);
        sensor.AddObservation(rBody.velocity.z);
    }

    public float speed = 10;

    public override void OnActionReceived(float[] vectorAction)
    {
        Vector3 controlSignal = Vector3.zero;
        controlSignal.x = vectorAction[0];
        controlSignal.z = vectorAction[1];
        rBody.AddForce(controlSignal * speed);

        // every 10,000 steps counts as one "generation":
        // reset the hit/miss counters and bump the generation display
        count++;
        if (count == 10000)
        {
            count = 0;
            h = 0;
            m = 0;
            c++;
            miss.SetText(m + "");
            hit.SetText(h + "");
            text.SetText(c + "");
        }
    }

    public override float[] Heuristic()
    {
        // manual control for testing with the arrow/WASD keys
        var action = new float[2];
        action[0] = Input.GetAxis("Horizontal");
        action[1] = Input.GetAxis("Vertical");
        return action;
    }
}
Weird graph of the training - TensorBoard

This is what I got in TensorBoard after training.
You are ending the episode after completing only one target, not after fully achieving the goal. That is why your graph looks messy: the episode ends too early, and the agent doesn't understand its purpose.
I think you could add some new rules:

- the agent is penalized if it moves backwards;
- the agent is penalized if it hasn't taken all 4 cubes before the episode ends.

The episode should only end when the agent completes the task of taking all 4 cubes (reward), or when it has spent a number of steps without achieving its goal (penalty). A rough sketch of these rules follows below.
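Here is a minimal sketch of what that could look like, treating all four targets alike (so OnEpisodeBegin would re-enable all four). It omits the "moves backwards" penalty, since what counts as backwards depends on your scene; the reward values and maxEpisodeSteps are made-up numbers you would have to tune, and the "Target..." names must match your own objects:

    const int maxEpisodeSteps = 5000;   // assumption: tune this to your maze size
    int steps = 0;

    void OnCollisionEnter(Collision collision)
    {
        if (collision.gameObject.name.StartsWith("Target"))
        {
            boxescollect++;
            AddReward(0.5f);            // same small reward for every cube
            collision.gameObject.GetComponent<Renderer>().enabled = false;
            collision.gameObject.GetComponent<Collider>().enabled = false;

            if (boxescollect == 4)      // end ONLY once all 4 cubes are taken
            {
                SetReward(2.0f);
                EndEpisode();
            }
        }
        else if (collision.gameObject.CompareTag("wall"))
        {
            AddReward(-1.0f);           // punish wall hits, but keep the episode going
        }
    }

    public override void OnActionReceived(float[] vectorAction)
    {
        // ... apply the movement force as before ...

        steps++;
        if (steps >= maxEpisodeSteps)   // out of time without finishing: punish and reset
        {
            steps = 0;                  // (or reset this in OnEpisodeBegin)
            SetReward(-1.0f);
            EndEpisode();
        }
    }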
Hope this helps. Sorry if my English isn't great.
___edit 2:___
Your problem most likely shares characteristics with the one described in this document (page 28 in particular):

https://repositorio.upct.es/bitstream/handle/10317/8094/tfg-san-est.pdf?sequence=1&isAllowed=y

(Sorry, it's in Spanish, but Google Translate will give you a fairly accurate translation.)

The problem in the document is the same as yours: the agent has trouble with corners; when it reaches a corner, it returns to the starting point, and this happens only at corners.
Have you tried changing the scene? Maybe... try it without the walls, to see whether the agent really looks for "all" the targets, and dig deeper into the problem; one quick way to do that is sketched below.

The graph is the least of it; it is only a representation. If the agent doesn't complete its task, you won't get a nice graph.
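For the no-wall experiment, instead of rebuilding the scene you could drop in a small helper like this hypothetical WallToggle script. It assumes every wall carries the same "wall" tag your agent already checks; note that disabled colliders are also invisible to raycast-based sensors:

    using UnityEngine;

    public class WallToggle : MonoBehaviour
    {
        public bool disableWalls = true;   // flip off in the inspector for a normal run

        void Start()
        {
            if (!disableWalls) return;
            foreach (GameObject wall in GameObject.FindGameObjectsWithTag("wall"))
            {
                wall.GetComponent<Collider>().enabled = false;  // no physical blocking
                wall.GetComponent<Renderer>().enabled = false;  // hide it visually
            }
        }
    }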