-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathDrivingScenarioEnv.m
193 lines (176 loc) · 6.77 KB
/
DrivingScenarioEnv.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
classdef DrivingScenarioEnv < rl.env.MATLABEnvironment
% Copyright 2020 The MathWorks, Inc.
%DRIVINGSCENARIOENV Custom reinforcement-learning environment for signal
% control at a T-junction.
%   Each call to step() maps the agent's action to a signal phase,
%   optionally inserts an all-closed clearing phase when the phase changes,
%   holds the chosen phase for phaseDuration scenario samples, and returns
%   the new observation, the accumulated reward and a termination flag.

    % Parameters for the simulation environment
    properties
        scenario            % scenario object returned by createTJunctionScenario
        network             % road network returned by createTJunctionNetwork
        traffic             % trafficControl.TrafficController over network(7:12)
        cars                % vehicles returned by createVehiclesForTJunction
        state               % last observation returned by reset/step
        driver
        InjectionRate       % vehicle injection rate per road (veh/hour)
        TurnRatio
        N = 3               % number of roads
        phaseDuration = 50  % number of scenario samples each phase is held
        T                   % elapsed simulation time (s)
    end

    % The simulation has no yellow light; enable an all-closed clearing
    % phase here if needed.
    properties
        clearingPhase = false   % insert clearing samples when the phase changes
        clearingPhaseTime = 0   % number of clearing samples to insert
        TrafficSignalDesign     % 1, 2 or 3: selects signalPhaseDesign1/2/3
        ObservationSpaceDesign  % 1 -> observationSpace1, otherwise observationSpace2
    end

    % Parameters for the reward definition
    properties
        rewardForPass = 0
        vehicleEnterJunction        % keeps record of cars that pass the intersection
        hitPenalty = 20             % subtracted from the reward per detected collision
        penaltyForFreqSwitch = 1    % reward granted when the phase is NOT switched
        safeDistance = 2.25         % check collision
        slowSpeedThreshold = 3.5    % check whether a car is waiting
    end

    properties
        recordVid = false
        vid
    end

    properties
        discrete_action = [0 1 2];
        dim = 10;
    end

    properties(Access = protected)
        IsDone = false
    end

    %% Necessary Methods
    methods
        function this = DrivingScenarioEnv()
            %DRIVINGSCENARIOENV Create the observation/action specs and
            % initialize the superclass.

            % Observation: 10-by-1 numeric vector of traffic information
            ObservationInfo = rlNumericSpec([10, 1]);
            ObservationInfo.Name = 'real-time traffic information';
            ObservationInfo.Description = '';

            % Action: one of three signal phases
            ActionInfo = rlFiniteSetSpec([0 1 2]);
            ActionInfo.Name = 'traffic signal phases';

            % Superclass constructor (this@Superclass syntax)
            this = this@rl.env.MATLABEnvironment(ObservationInfo, ActionInfo);
        end

        function [state, Reward, IsDone, LoggedSignals] = step(this, Action)
            %STEP Apply the signal phase selected by Action and simulate it.
            %   Reward = penaltyForFreqSwitch if the phase is unchanged (else 0)
            %            - hitPenalty for every sample with a detected collision
            %            + obtainReward(this, phase) for every green-phase sample
            %              that does not terminate the episode.
            %   The episode ends on a collision or when T + 0.5 >= StopTime.
            Action = getForce(this, Action);

            % Map the action to a signal phase (IsOpen pattern)
            pre_phase = this.traffic.IsOpen;
            if this.TrafficSignalDesign == 1
                cur_phase = signalPhaseDesign1(Action);
            elseif this.TrafficSignalDesign == 2
                cur_phase = signalPhaseDesign2(Action);
            elseif this.TrafficSignalDesign == 3
                cur_phase = signalPhaseDesign3(Action);
            else
                error('TrafficSignalDesign must be 1, 2 or 3.');
            end

            % Discourage frequent switching: the bonus is only granted when
            % the phase is kept.
            changed = ~isequal(pre_phase, cur_phase);
            Reward = this.penaltyForFreqSwitch * (1 - changed);

            % (Yellow-light substitute) clearing phase with every controlled
            % path closed, inserted only when the phase actually changes.
            if changed && this.clearingPhase
                for k = 1:this.clearingPhaseTime
                    this.traffic.IsOpen = [0, 0, 0, 0, 0, 0];
                    Reward = Reward - advanceOneSample(this);
                    if this.IsDone
                        break
                    end
                end
            end

            % (Green-light time) hold the phase chosen by the agent
            this.traffic.IsOpen = cur_phase;
            if ~this.IsDone
                for k = 1:this.phaseDuration
                    Reward = Reward - advanceOneSample(this);
                    if this.IsDone
                        break
                    end
                    Reward = Reward + obtainReward(this, cur_phase);
                end
            end

            state = getObservation(this, Action);
            this.state = state;
            IsDone = this.IsDone;
            LoggedSignals = [];
        end

        function InitialState = reset(this)
            %RESET Build a fresh scenario, network, controller and vehicles.
            %   Returns the initial observation (computed with action 0).

            % Video recording is switched off on every reset
            this.recordVid = false;
            % New episode: clear the termination flag left by the previous
            % episode. Otherwise step() skips the simulation and reports the
            % episode as done right away.
            this.IsDone = false;

            % Initialize scenario (100 s at 0.05 s per sample)
            this.scenario = createTJunctionScenario();
            this.scenario.StopTime = 100;
            this.scenario.SampleTime = 0.05;
            this.T = 0;

            % Initialize network and the signal controller over paths 7..12
            this.network = createTJunctionNetwork(this.scenario);
            this.traffic = trafficControl.TrafficController(this.network(7:12));

            % Car parameters
            this.InjectionRate = [250, 250, 250]; % veh/hour
            this.TurnRatio = [50, 50];
            this.cars = createVehiclesForTJunction(this.scenario, this.network, this.InjectionRate, this.TurnRatio);
            this.vehicleEnterJunction = [];

            % Obtain state from traffic and network
            InitialState = getObservation(this, 0);
            this.state = InitialState;

            % Visualization (T == 0 triggers a fresh plot)
            notifyEnvUpdated(this);
        end
    end

    methods
        function force = getForce(this, action)
            %GETFORCE Validate that action is in ActionInfo.Elements and
            % return it unchanged.
            if ~ismember(action, this.ActionInfo.Elements)
                error('Action must be one of the elements of ActionInfo.Elements.');
            end
            force = action;
        end

        % Set the action set from discrete_action
        function updateActionInfo(this)
            this.ActionInfo.Elements = this.discrete_action;
        end
    end

    methods (Access = private)
        function hitCost = advanceOneSample(this)
            %ADVANCEONESAMPLE Advance the scenario one sample and update IsDone.
            %   Returns the collision penalty for this sample (IsHit * hitPenalty).
            advance(this.scenario);
            this.T = this.T + this.scenario.SampleTime;
            % Update visualization
            notifyEnvUpdated(this);
            IsHit = checkCollision(this);
            hitCost = IsHit * this.hitPenalty;
            % Terminate on collision or within 0.5 s of the scenario stop time
            this.IsDone = IsHit || this.T + 0.5 >= this.scenario.StopTime;
        end

        function obs = getObservation(this, action)
            %GETOBSERVATION Dispatch to the configured observation function.
            if this.ObservationSpaceDesign == 1
                obs = observationSpace1(this, action);
            else
                obs = observationSpace2(this, action);
            end
        end
    end

    methods (Access = protected)
        function envUpdatedCallback(this)
            %ENVUPDATEDCALLBACK Redraw the scenario after notifyEnvUpdated.
            %   At T == 0 (right after reset) all figures are closed and the
            %   scenario is re-plotted; if recordVid is set, a VideoWriter is
            %   opened. Each call optionally appends a frame and redraws the
            %   open signal paths.
            %   NOTE(review): reset() sets recordVid = false before this runs,
            %   and the VideoWriter is never closed in this file. Confirm how
            %   recording is meant to be enabled and finalized.
            if this.T == 0
                close all;
                plot(this.scenario)
                set(gcf, 'Visible', 'On');
                if this.recordVid
                    this.vid = VideoWriter('baseRLlearningProcess33');
                    this.vid.FrameRate = 20;
                    open(this.vid)
                end
            end
            if this.recordVid
                frame = getframe(gcf);
                writeVideo(this.vid, frame);
            end
            this.traffic.plotOpenPaths()
            drawnow
        end
    end
end