
Commit bf4a35a

change to while loop; tighten text
1 parent 7d3cdf4

1 file changed: +24 −12 lines

notebooks_en/2_Logistic_Regression.ipynb

Lines changed: 24 additions & 12 deletions
@@ -105,6 +105,10 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
+"Above, we chose the parameters $w$ and $b$ for the model, and used them to get the intermediate variable $z$, adding some random noise to make our synthetic data look more \"realistic.\"\n",
+"The call to `numpy.random.seed()` makes our added noise reproducible, i.e., you always get the same pseudo-random numbers from the subsequent call to `numpy.random.normal()`. \n",
+"(Let's not get side-tracked into a discussion about pseudo-random number generation, and leave that for another tutorial.)\n",
+"\n",
 "We can apply a decision boundary now to assign the data to the two classes. Be sure to read the documentation of [`numpy.where()`](https://numpy.org/doc/stable/reference/generated/numpy.where.html) to understand the code below, noting that after the logical condition we specify the values to assign if the condition is `True` or `False`."
 ]
 },
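For readers following along outside the notebook, the operations described in the cell above look roughly like this sketch (the seed value, noise scale, and array size are illustrative; only the true parameters w = 2 and b = 1 come from the notebook):

    import numpy as np

    np.random.seed(0)                              # make the pseudo-random noise reproducible
    x_data = np.linspace(-5, 5, 50)
    z = 2*x_data + 1 + np.random.normal(0, 1, 50)  # linear model w*x + b plus Gaussian noise
    y_data = np.where(z > 0, 1, 0)                 # 1 where the condition is True, 0 where it is False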
@@ -158,7 +162,7 @@
 "We'd like to work with a better loss function, one that avoids this problem, and we build one below by integration. (For a more detailed discussion, we recommend Chapter 3 of Michael Nielsen's free ebook [2]).\n",
 "\n",
 "It's also important to note that our prediction model is a nonlinear function composed with the linear model, and the square error would lead to a non-convex loss function that can have local minima and make gradient descent fail. \n",
-"Here's an example posted on Stackoverflow in answer to this very [question](https://math.stackexchange.com/questions/2381724/logistic-regression-when-can-the-cost-function-be-non-convex). Consider just three data points, and a model with no intercept, $z = wx$: $(-1, 2), (-20, -1), (-5, 5)$. What would a square-error mean function look like? We can plot it using SymPy."
+"Here's an example posted on Math Stack Exchange in answer to this very [question](https://math.stackexchange.com/questions/2381724/logistic-regression-when-can-the-cost-function-be-non-convex). Consider just three data points, and a model with no intercept, $z = wx$: $(-1, 2), (-20, -1), (-5, 5)$. What would a square-error loss function look like? We can plot it using SymPy."
 ]
 },
 {
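A sketch of the kind of SymPy plot the cell above refers to, assuming the logistic model 1/(1 + exp(-w*x)) and a mean square error over the three given points (the plotting range for w is an arbitrary choice):

    from sympy import symbols, exp, plot

    w = symbols('w')
    points = [(-1, 2), (-20, -1), (-5, 5)]

    # mean square error of the logistic model over the three points
    loss = sum((1/(1 + exp(-w*x)) - y)**2 for x, y in points) / len(points)
    plot(loss, (w, -2, 2))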
@@ -379,7 +383,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"So far, we can still use SymPy to get derivatives of the loss function with respect to the parameters. But with more complicated models, finding symbolic derivatives will take a long time.\n",
+"So far, we can still use SymPy to get derivatives of the loss function with respect to the parameters. But with more complicated models, finding symbolic derivatives could take a long time.\n",
 "\n",
 "Have a look at the derivative of the logistic loss with respect to the parameter $b$: "
 ]
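The symbolic derivative mentioned above could be obtained along these lines (a sketch, not the notebook's own cell; the symbol names and the single-point loss are assumptions):

    from sympy import symbols, exp, log, diff, simplify

    w, b, x, y = symbols('w b x y')

    y_pred = 1 / (1 + exp(-(w*x + b)))                  # logistic model
    loss = -(y*log(y_pred) + (1 - y)*log(1 - y_pred))   # logistic loss for one data point
    print(simplify(diff(loss, b)))                      # derivative with respect to b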
@@ -397,7 +401,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"We can use symbolic differentiation but it can take a long time to compute for very complicated functions.\n",
+"Although we can use symbolic differentiation, we later need to convert the resulting expression to a Python function that can be called and evaluated at many data inputs. Maybe this is not the best approach.\n",
 "\n",
 "There's a better way! It's called _automatic differentiation_: the idea is to algorithmically obtain derivatives of numeric functions written in computer code. \n",
 "It sounds like magic, but it can be done by a combination of the chain rule, symbolic rules of differentiation for elementary operations, and a numeric evaluation trace of the elementary derivatives. \n",
@@ -432,7 +436,7 @@
 "\n",
 "In addition, `autograd.numpy` is a wrapper to the NumPy library. This allows you to call your favorite NumPy methods with `autograd` keeping track of every operation so it can give you the derivative (via the chain rule).\n",
 "We will import it using the alias (`as np`), consistent with the tutorials and documentation that you will find online.\n",
-"Up to now in the _Engineering Computations_ series of modules, we have refrained from using the aliased form of the import statements, just to have more explicit and readable code. "
+"Up to now in the _Engineering Computations_ series of modules, we had refrained from using the aliased form of the import statements, just to have more explicit and readable code. "
 ]
 },
 {
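As a minimal illustration of `autograd` in use (not a cell from this notebook; the function `f` and the evaluation point are made up), the wrapped NumPy calls are recorded and differentiated automatically:

    import autograd.numpy as np    # NumPy wrapper that records every operation
    from autograd import grad      # grad(f) returns a function computing the derivative of f

    def f(x):
        return np.tanh(x)**2       # any composition of wrapped NumPy functions

    df = grad(f)                   # derivative with respect to the first argument
    print(df(0.5))                 # evaluated numerically; no symbolic expression is built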
@@ -477,6 +481,7 @@
 " \n",
 "def logistic_model(params, x):\n",
 "    '''A prediction model based on the logistic function composed with wx+b\n",
+"    Arguments:\n",
 "    params: array(w,b) of model parameters\n",
 "    x : array of x data'''\n",
 "    w = params[0]\n",
@@ -486,7 +491,11 @@
 "    return y\n",
 "\n",
 "def log_loss(params, model, x, y):\n",
-"    '''The logistic loss function'''\n",
+"    '''The logistic loss function\n",
+"    Arguments:\n",
+"    params: array(w,b) of model parameters\n",
+"    model: the Python function for the logistic model\n",
+"    x, y: arrays of input data to the model'''\n",
 "    y_pred = model(params, x)\n",
 "    return -np.mean(y * np.log(y_pred) + (1-y) * np.log(1 - y_pred))"
 ]
@@ -562,18 +571,21 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"for i in range(3000):\n",
+"max_iter = 3000\n",
+"i = 0\n",
+"descent = np.ones(len(x_data))\n",
+"\n",
+"while np.linalg.norm(descent) > 0.001 and i < max_iter:\n",
+"\n",
 "    descent = gradient(params, logistic_model, x_data, y_data)\n",
-"    oldparams = params\n",
 "    params = params - descent * 0.01\n",
-"    residual = np.abs((params - oldparams) / oldparams)\n",
-"    if np.all(residual < 1e-6):\n",
-"        break\n",
+"    i += 1\n",
+"\n",
 "\n",
 "print(f'Optimized value of w is {params[0]:.3f} vs. true value: 2')\n",
 "print(f'Optimized value of b is {params[1]:.3f} vs. true value: 1')\n",
 "print(f'Exited after {i} iterations')\n",
-"print(f'Residual is {residual}')\n",
+"\n",
 "\n",
 "pyplot.scatter(x_data, y_data, alpha=0.4)\n",
 "pyplot.plot(x_data, logistic_model(params, x_data), '-r');"
@@ -694,7 +706,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.6.13"
+"version": "3.8.5"
 }
 },
 "nbformat": 4,
