Commit 2d99f28

Reinforcement learning:
1 parent d8327c6 commit 2d99f28

2 files changed: +798 -0 lines changed

@@ -0,0 +1,399 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import random\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"NODES = 6   # number of states (graph nodes)\n",
"GOAL = 4    # terminal / goal state\n",
"\n",
"ALPHA = 0.01  # learning rate\n",
"GAMMA = 0.9   # discount factor"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"#for ix in range(6):\n",
"#    for iy in range(6):\n",
"#        print \"(\" + str(ix) + \", \" + str(iy) + \"): 0,\"\n",
"\n",
"# Reward function: maps a (state, next_state) transition to its immediate reward\n",
"R = {\n",
"(0, 1): 0,\n",
"(0, 2): 0,\n",
"(0, 5): -4,\n",
"(1, 0): -1,\n",
"(1, 2): 5,\n",
"(1, 3): 2,\n",
"(1, 5): 0,\n",
"(2, 0): -5,\n",
"(2, 1): 0,\n",
"(2, 3): 10,\n",
"(2, 5): 8,\n",
"(3, 0): 5,\n",
"(3, 1): -3,\n",
"(3, 2): 4,\n",
"(3, 4): 50,\n",
"(3, 5): 2,\n",
"(4, 0): -10,\n",
"(4, 1): -5,\n",
"(4, 2): -20,\n",
"(4, 3): 0,\n",
"(4, 4): 100,\n",
"(4, 5): -50,\n",
"(5, 0): -15,\n",
"(5, 1): 2,\n",
"(5, 2): 7,\n",
"(5, 3): 0,\n",
"(5, 4): 70,\n",
"}\n",
"\n",
"# Q-values, keyed by (state, next_state); entries are created lazily with setdefault\n",
"Q = {}"
]
},
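{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of the reward table (an illustrative sketch; `R_mat` is a name introduced here and not used elsewhere): view `R` as a dense `NODES x NODES` matrix, with `nan` marking transitions that are not allowed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"R_mat = np.full((NODES, NODES), np.nan)\n",
"for (s, a), r in R.items():\n",
"    R_mat[s, a] = r\n",
"R_mat"
]
},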
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"def get_actions(current):\n",
"    # All next states rx reachable from `current`, i.e. those with a reward defined for (current, rx)\n",
"    actions = []\n",
"    for rx in xrange(NODES):\n",
"        if (current, rx) in R:\n",
"            actions.append(rx)\n",
"    return actions\n",
"\n",
"# get_actions(4)"
]
},
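{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal check of the helper (illustrative only): given the reward table above, `get_actions(0)` should return `[1, 2, 5]`, the nodes reachable from node 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"get_actions(0)"
]
},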
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode: 0 | Reward: 53\n",
"Episode: 1 | Reward: 33\n",
"Episode: 2 | Reward: 29\n",
"Episode: 3 | Reward: 116\n",
"Episode: 4 | Reward: 84\n",
"Episode: 5 | Reward: 5\n",
"Episode: 6 | Reward: 25\n",
"Episode: 7 | Reward: 182\n",
"Episode: 8 | Reward: 145\n",
"Episode: 9 | Reward: 49\n",
"Episode: 10 | Reward: 165\n",
"Episode: 11 | Reward: 19\n",
"Episode: 12 | Reward: -5\n",
"Episode: 13 | Reward: -11\n",
"Episode: 14 | Reward: 10\n",
"Episode: 15 | Reward: 3\n",
"Episode: 16 | Reward: 59\n",
"Episode: 17 | Reward: -22\n",
"Episode: 18 | Reward: 11\n",
"Episode: 19 | Reward: 15\n",
"Episode: 20 | Reward: 7\n",
"Episode: 21 | Reward: 177\n",
"Episode: 22 | Reward: 70\n",
"Episode: 23 | Reward: 85\n",
"Episode: 24 | Reward: 69\n",
"Episode: 25 | Reward: 214\n",
"Episode: 26 | Reward: 4\n",
"Episode: 27 | Reward: 62\n",
"Episode: 28 | Reward: 18\n",
"Episode: 29 | Reward: 94\n",
"Episode: 30 | Reward: 51\n",
"Episode: 31 | Reward: 146\n",
"Episode: 32 | Reward: 4\n",
"Episode: 33 | Reward: 3\n",
"Episode: 34 | Reward: -5\n",
"Episode: 35 | Reward: 102\n",
"Episode: 36 | Reward: -4\n",
"Episode: 37 | Reward: 52\n",
"Episode: 38 | Reward: 1\n",
"Episode: 39 | Reward: -3\n",
"Episode: 40 | Reward: 203\n",
"Episode: 41 | Reward: 24\n",
"Episode: 42 | Reward: 6\n",
"Episode: 43 | Reward: -1\n",
"Episode: 44 | Reward: 210\n",
"Episode: 45 | Reward: 84\n",
"Episode: 46 | Reward: 127\n",
"Episode: 47 | Reward: 175\n",
"Episode: 48 | Reward: 11\n",
"Episode: 49 | Reward: 144\n",
"Episode: 50 | Reward: 26\n",
"Episode: 51 | Reward: 30\n",
"Episode: 52 | Reward: -19\n",
"Episode: 53 | Reward: 101\n",
"Episode: 54 | Reward: 0\n",
"Episode: 55 | Reward: 12\n",
"Episode: 56 | Reward: -4\n",
"Episode: 57 | Reward: 85\n",
"Episode: 58 | Reward: -5\n",
"Episode: 59 | Reward: -3\n",
"Episode: 60 | Reward: 52\n",
"Episode: 61 | Reward: 34\n",
"Episode: 62 | Reward: 1\n",
"Episode: 63 | Reward: 57\n",
"Episode: 64 | Reward: 69\n",
"Episode: 65 | Reward: 7\n",
"Episode: 66 | Reward: 7\n",
"Episode: 67 | Reward: 1\n",
"Episode: 68 | Reward: 11\n",
"Episode: 69 | Reward: 100\n",
"Episode: 70 | Reward: 47\n",
"Episode: 71 | Reward: 21\n",
"Episode: 72 | Reward: 168\n",
"Episode: 73 | Reward: 13\n",
"Episode: 74 | Reward: 78\n",
"Episode: 75 | Reward: -2\n",
"Episode: 76 | Reward: 301\n",
"Episode: 77 | Reward: 251\n",
"Episode: 78 | Reward: 28\n",
"Episode: 79 | Reward: 21\n",
"Episode: 80 | Reward: -13\n",
"Episode: 81 | Reward: 28\n",
"Episode: 82 | Reward: 31\n",
"Episode: 83 | Reward: -25\n",
"Episode: 84 | Reward: 10\n",
"Episode: 85 | Reward: 8\n",
"Episode: 86 | Reward: 81\n",
"Episode: 87 | Reward: 14\n",
"Episode: 88 | Reward: 15\n",
"Episode: 89 | Reward: 27\n",
"Episode: 90 | Reward: 16\n",
"Episode: 91 | Reward: 117\n",
"Episode: 92 | Reward: 21\n",
"Episode: 93 | Reward: 61\n",
"Episode: 94 | Reward: -12\n",
"Episode: 95 | Reward: 98\n",
"Episode: 96 | Reward: -15\n",
"Episode: 97 | Reward: 63\n",
"Episode: 98 | Reward: 37\n",
"Episode: 99 | Reward: 2\n"
]
}
],
"source": [
"N_ep = 100\n",
"\n",
"for ep in range(N_ep):\n",
"    pos = 0\n",
"    rew = 0\n",
"\n",
"    # while not pos == GOAL:\n",
"    for kx in range(10):\n",
"        # print pos\n",
"        p_act = get_actions(pos)\n",
"\n",
"        # Greedy bookkeeping (not used below: training explores uniformly at random)\n",
"        best_ac = []\n",
"        q_best = None\n",
"\n",
"        for ac in p_act:\n",
"            rq = Q.setdefault((pos, ac), 0)\n",
"            if rq > q_best:\n",
"                q_best = rq\n",
"                best_ac = [ac]\n",
"            elif rq == q_best:\n",
"                best_ac.append(ac)\n",
"\n",
"        # Pure exploration: pick the next state uniformly at random\n",
"        nxt_pos = random.choice(p_act)\n",
"        nxt_p_ac = get_actions(nxt_pos)\n",
"\n",
"        # Best Q-value reachable from the next state\n",
"        nq_best = None\n",
"\n",
"        for ac in nxt_p_ac:\n",
"            rq = Q.setdefault((nxt_pos, ac), 0)\n",
"            nq_best = max(nq_best, rq)\n",
"\n",
"        # Q-learning update\n",
"        Q[(pos, nxt_pos)] = (1 - ALPHA)*Q[(pos, nxt_pos)] + ALPHA*(R[(pos, nxt_pos)] + GAMMA*nq_best)\n",
"        rew += R[(pos, nxt_pos)]\n",
"        pos = nxt_pos\n",
"    print \"Episode:\", ep, \"| Reward:\", rew"
]
},
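{
"cell_type": "markdown",
"metadata": {},
"source": [
"The update in the loop above is the standard Q-learning rule\n",
"\n",
"$$Q(s, a) \\leftarrow (1 - \\alpha)\\,Q(s, a) + \\alpha\\,\\bigl(R(s, a) + \\gamma \\max_{a'} Q(s', a')\\bigr),$$\n",
"\n",
"where an \"action\" is simply the next node, so both Q and R are keyed by (state, next_state), and $\\alpha = 0.01$, $\\gamma = 0.9$. As a worked example, assuming the very first sampled transition happens to be $0 \\to 5$ while every Q-value is still zero: $Q(0, 5) \\leftarrow 0.99 \\cdot 0 + 0.01 \\cdot (-4 + 0.9 \\cdot 0) = -0.04$."
]
},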
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{(0, 1): 1.0527522119424737,\n",
" (0, 2): 1.7765083795300343,\n",
" (0, 5): 5.210843889083548,\n",
" (1, 0): 0.0391664790275139,\n",
" (1, 2): 2.6955696115274415,\n",
" (1, 3): 3.156196226973814,\n",
" (1, 5): 4.963970280138251,\n",
" (2, 0): -1.3609407399152875,\n",
" (2, 1): 0.5589380241903006,\n",
" (2, 3): 5.812972573178521,\n",
" (2, 5): 7.301826243368765,\n",
" (3, 0): 1.4065845289178485,\n",
" (3, 1): -0.25803609254372506,\n",
" (3, 2): 1.7033928867280708,\n",
" (3, 4): 11.208298009811994,\n",
" (3, 5): 3.5555140878750735,\n",
" (4, 0): -1.0076784500262796,\n",
" (4, 1): -0.30127326885017214,\n",
" (4, 2): -2.017800766935019,\n",
" (4, 3): 0.5506222723837432,\n",
" (4, 4): 17.84781294854947,\n",
" (4, 5): -3.771335058692992,\n",
" (5, 0): -4.208275309248176,\n",
" (5, 1): 1.2881067744708923,\n",
" (5, 2): 2.852172356891487,\n",
" (5, 3): 1.9640566727660835,\n",
" (5, 4): 26.060782382076916}"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#plt.figure(0)\n",
"#for ix in range(NODES):\n",
"#    for iy in range(NODES):\n",
"#        plt.subplot(NODES, NODES)\n",
"Q"
]
},
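{
"cell_type": "markdown",
"metadata": {},
"source": [
"An illustrative sketch (the helper `greedy_policy` is hypothetical and not used by the cells below): read the greedy policy out of the learned table by picking, for each state, the neighbour with the highest Q-value."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def greedy_policy():\n",
"    # Map each state to the next node with the largest learned Q-value\n",
"    policy = {}\n",
"    for s in range(NODES):\n",
"        acts = get_actions(s)\n",
"        if acts:\n",
"            policy[s] = max(acts, key=lambda a: Q.get((s, a), 0))\n",
"    return policy\n",
"\n",
"greedy_policy()"
]
},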
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"5\n",
"Episode: 0 | Reward: 66\n",
"0\n",
"5\n",
"Episode: 1 | Reward: 66\n",
"0\n",
"5\n",
"Episode: 2 | Reward: 66\n",
"0\n",
"5\n",
"Episode: 3 | Reward: 66\n",
"0\n",
"5\n",
"Episode: 4 | Reward: 66\n"
]
}
],
"source": [
"N_ep = 5\n",
"\n",
"for ep in range(N_ep):\n",
"    pos = 0\n",
"    rew = 0\n",
"\n",
"    # Follow the learned Q-values greedily until the goal state is reached\n",
"    while not pos == GOAL:\n",
"        print pos\n",
"        p_act = get_actions(pos)\n",
"\n",
"        # Action(s) with the highest learned Q-value for the current state\n",
"        best_ac = []\n",
"        q_best = None\n",
"\n",
"        for ac in p_act:\n",
"            rq = Q.setdefault((pos, ac), 0)\n",
"            if rq > q_best:\n",
"                q_best = rq\n",
"                best_ac = [ac]\n",
"            elif rq == q_best:\n",
"                best_ac.append(ac)\n",
"\n",
"        # Greedy step (ties broken at random)\n",
"        nxt_pos = random.choice(best_ac)\n",
"        nxt_p_ac = get_actions(nxt_pos)\n",
"\n",
"        nq_best = None\n",
"\n",
"        for ac in nxt_p_ac:\n",
"            rq = Q.setdefault((nxt_pos, ac), 0)\n",
"            nq_best = max(nq_best, rq)\n",
"\n",
"        # Keep refining Q along the greedy path\n",
"        Q[(pos, nxt_pos)] = (1 - ALPHA)*Q[(pos, nxt_pos)] + ALPHA*(R[(pos, nxt_pos)] + GAMMA*nq_best)\n",
"        rew += R[(pos, nxt_pos)]\n",
"        pos = nxt_pos\n",
"    print \"Episode:\", ep, \"| Reward:\", rew"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
